pgwire/
pg_protocol.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::error_or_notice::Severity;
46use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
47use crate::net::AddressRef;
48use crate::pg_extended::ResultCache;
49use crate::pg_message::{
50    BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
51    FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
52    FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
53};
54use crate::pg_server::{Session, SessionManager, UserAuthenticator};
55use crate::types::Format;
56
/// Maximum length of the SQL text recorded in the query log. Longer statements are truncated to
/// avoid the log file being too large.
///
/// The value is read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. It falls
/// back to `65536` when the variable is unset or cannot be parsed as a `usize`.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        // Parse exactly once; an unparsable value is treated like a missing one.
        .and_then(|len| len.parse::<usize>().ok())
        .unwrap_or(65536)
});
64
tokio::task_local! {
    /// The session on whose behalf the current message is being processed.
    ///
    /// The concrete session type is erased (`dyn Any`) so that different session
    /// implementations can share this task-local, and only a `Weak` reference is stored.
    /// `PgProtocol::do_process` scopes this value around message processing, and only when the
    /// connection already has a session.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
69
/// The state machine for each psql connection.
/// Reads pg messages from the underlying stream and writes results back.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Used for write/read pg messages.
    stream: PgStream<S>,
    /// Current state of the pg connection. It decides how the next message is read.
    state: PgProtocolState,
    /// Whether the connection is terminated.
    is_terminate: bool,

    /// Creates the session on startup, handles cancel requests and ends the session on drop.
    session_mgr: Arc<SM>,
    /// The session of this connection. `None` until a startup message has been processed.
    session: Option<Arc<SM::Session>>,

    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    // Used to store the dependency of portal and prepare statement.
    // When we close a prepare statement, we need to close all the portals that depend on it.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    // Used for ssl connection.
    // `None` if no TLS config was given or the SSL context could not be built from it. In that
    // case an SSL request from the client is answered with "no encryption".
    tls_context: Option<SslContext>,

    // TLS configuration including SSL enforcement setting
    tls_config: Option<TlsConfig>,

    // Used in extended query protocol. When an error is encountered in an extended query, we
    // need to ignore the following messages until a sync message arrives.
    ignore_util_sync: bool,

    // Client Address
    peer_addr: AddressRef,

    // Keywords of SQL options to redact before SQL text is recorded in tracing spans.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    // Accounts for the payload size of incoming messages; may throttle the connection.
    message_memory_manager: MessageMemoryManagerRef,
}
112
113/// Configures TLS encryption for connections.
114#[derive(Debug, Clone)]
115pub struct TlsConfig {
116    /// The path to the TLS certificate.
117    pub cert: String,
118    /// The path to the TLS key.
119    pub key: String,
120    /// Whether to enforce SSL connections (reject non-SSL clients).
121    pub enforce_ssl: bool,
122}
123
124impl TlsConfig {
125    pub fn new_default() -> Option<Self> {
126        let cert = std::env::var("RW_SSL_CERT").ok()?;
127        let key = std::env::var("RW_SSL_KEY").ok()?;
128        let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
129        tracing::info!(
130            "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
131            cert,
132            key,
133            enforce_ssl
134        );
135        Some(Self {
136            cert,
137            key,
138            enforce_ssl,
139        })
140    }
141}
142
143impl<S, SM> Drop for PgProtocol<S, SM>
144where
145    SM: SessionManager,
146{
147    fn drop(&mut self) {
148        if let Some(session) = &self.session {
149            // Clear the session in session manager.
150            self.session_mgr.end_session(session);
151        }
152    }
153}
154
/// Phase of a connection, which decides how `PgProtocol::read_message` reads the next message.
///
/// States flow from top to bottom.
enum PgProtocolState {
    /// Initial phase: messages are read with `read_startup`.
    Startup,
    /// Entered once a startup message has been processed: messages are read as a header
    /// followed by a body.
    Regular,
}
160
161/// Truncate 0 from C string in Bytes and stringify it (returns slice, no allocations).
162///
163/// PG protocol strings are always C strings.
164pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
165    let without_null = if b.last() == Some(&0) {
166        &b[..b.len() - 1]
167    } else {
168        &b[..]
169    };
170    std::str::from_utf8(without_null)
171}
172
173fn record_sql_in_current_span(
174    sql: &str,
175    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
176) -> String {
177    let mut span = tracing::Span::current();
178    record_sql_in_span(sql, redact_sql_option_keywords, &mut span)
179}
180
181/// Record `sql` in the current tracing span.
182fn record_sql_in_span(
183    sql: &str,
184    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
185    span: &mut tracing::Span,
186) -> String {
187    let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
188        && !keywords.is_empty()
189    {
190        redact_sql(sql, keywords)
191    } else {
192        sql.to_owned()
193    };
194    let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
195    span.record("sql", tracing::field::display(&truncated));
196    truncated.to_string()
197}
198
199/// Redacts SQL options. Data in DML is not redacted.
200fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
201    match Parser::parse_sql(sql) {
202        Ok(sqls) => sqls
203            .into_iter()
204            .map(|sql| sql.to_redacted_string(keywords.clone()))
205            .join(";"),
206        Err(_) => sql.to_owned(),
207    }
208}
209
/// Per-connection settings consumed by `PgProtocol::new`.
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS configuration. When `None`, no SSL context is built for the connection.
    pub tls_config: Option<TlsConfig>,
    /// Keywords of SQL options to redact before SQL text is recorded in tracing spans.
    /// When `None` or empty, the SQL is recorded without redaction.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts for the payload size of incoming messages and reports when a message should
    /// be rejected instead of being read.
    pub message_memory_manager: MessageMemoryManagerRef,
}
216
217impl<S, SM> PgProtocol<S, SM>
218where
219    S: PgByteStream,
220    SM: SessionManager,
221{
222    pub fn new(
223        stream: S,
224        session_mgr: Arc<SM>,
225        peer_addr: AddressRef,
226        context: ConnectionContext,
227    ) -> Self {
228        let ConnectionContext {
229            tls_config,
230            redact_sql_option_keywords,
231            message_memory_manager,
232        } = context;
233        Self {
234            stream: PgStream::new(stream),
235            is_terminate: false,
236            state: PgProtocolState::Startup,
237            session_mgr,
238            session: None,
239            tls_context: tls_config
240                .as_ref()
241                .and_then(|e| build_ssl_ctx_from_config(e).ok()),
242            tls_config,
243            result_cache: Default::default(),
244            unnamed_prepare_statement: Default::default(),
245            prepare_statement_store: Default::default(),
246            unnamed_portal: Default::default(),
247            portal_store: Default::default(),
248            statement_portal_dependency: Default::default(),
249            ignore_util_sync: false,
250            peer_addr,
251            redact_sql_option_keywords,
252            message_memory_manager,
253        }
254    }
255
256    /// Run the protocol to serve the connection.
257    pub async fn run(&mut self) {
258        let mut notice_fut = None;
259
260        loop {
261            // Once a session is present, create a future to subscribe and send notices asynchronously.
262            if notice_fut.is_none()
263                && let Some(session) = self.session.clone()
264            {
265                let mut stream = self.stream.clone();
266                notice_fut = Some(Box::pin(async move {
267                    loop {
268                        let notice = session.next_notice().await;
269                        if let Err(e) = stream.write(BeMessage::NoticeResponse(&notice)).await {
270                            tracing::error!(error = %e.as_report(), notice, "failed to send notice");
271                        }
272                    }
273                }));
274            }
275
276            // Read and process messages.
277            let process = std::pin::pin!(async {
278                let (msg, _memory_guard) = match self.read_message().await {
279                    Ok(msg) => msg,
280                    Err(e) => {
281                        tracing::error!(error = %e.as_report(), "error when reading message");
282                        return true; // terminate the connection
283                    }
284                };
285                tracing::trace!(?msg, "received message");
286                self.process(msg).await
287            });
288
289            let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
290                tokio::select! {
291                    _ = notice_fut => unreachable!(),
292                    terminated = process => terminated,
293                }
294            } else {
295                process.await
296            };
297
298            if terminated {
299                break;
300            }
301        }
302    }
303
304    /// Processes one message. Returns true if the connection is terminated.
305    pub async fn process(&mut self, msg: FeMessage) -> bool {
306        self.do_process(msg).await.is_none() || self.is_terminate
307    }
308
309    /// The root tracing span for processing a message. The target of the span is
310    /// [`PGWIRE_ROOT_SPAN_TARGET`].
311    ///
312    /// This is used to provide context for the (slow) query logs and traces.
313    ///
314    /// The span is only effective if there's a current session and the message is
315    /// query-related. Otherwise, `Span::none()` is returned.
316    fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
317        let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
318            return tracing::Span::none();
319        };
320
321        let mode = match msg {
322            FeMessage::Query(_) => "simple query",
323            FeMessage::Parse(_) => "extended query parse",
324            FeMessage::Execute(_) => "extended query execute",
325            _ => return tracing::Span::none(),
326        };
327
328        let mut span = tracing::info_span!(
329            target: PGWIRE_ROOT_SPAN_TARGET,
330            "handle_query",
331            mode,
332            session_id,
333            sql = tracing::field::Empty,
334        );
335        if let Ok(sql) = msg.get_sql()
336            && let Some(sql) = sql
337        {
338            record_sql_in_span(sql, self.redact_sql_option_keywords.clone(), &mut span);
339        }
340        span
341    }
342
343    /// Return type `Option<()>` is essentially a bool, but allows `?` for early return.
344    /// - `None` means to terminate the current connection
345    /// - `Some(())` means to continue processing the next message
346    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
347        let span = self.root_span_for_msg(&msg);
348        let weak_session = self
349            .session
350            .as_ref()
351            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);
352
353        // Processing the message itself.
354        //
355        // Note: pin the future to avoid stack overflow as we'll wrap it multiple times
356        // in the following code.
357        let fut = Box::pin(self.do_process_inner(msg));
358
359        // Set the current session as the context when processing the message, if exists.
360        let fut = async move {
361            if let Some(session) = weak_session {
362                CURRENT_SESSION.scope(session, fut).await
363            } else {
364                fut.await
365            }
366        };
367
368        // Catch unwind.
369        let fut = async move {
370            AssertUnwindSafe(fut)
371                .rw_catch_unwind()
372                .await
373                .unwrap_or_else(|payload| {
374                    Err(PsqlError::Panic(
375                        panic_message::panic_message(&payload).to_owned(),
376                    ))
377                })
378        };
379
380        // Slow query log.
381        let fut = async move {
382            let period = *SLOW_QUERY_LOG_PERIOD;
383            let mut fut = std::pin::pin!(fut);
384            let mut elapsed = Duration::ZERO;
385
386            // Report the SQL in the log periodically if the query is slow.
387            loop {
388                match tokio::time::timeout(period, &mut fut).await {
389                    Ok(result) => break result,
390                    Err(_) => {
391                        elapsed += period;
392                        tracing::info!(
393                            target: PGWIRE_SLOW_QUERY_LOG,
394                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
395                            "slow query"
396                        );
397                    }
398                }
399            }
400        };
401
402        // Query log.
403        let fut = async move {
404            if !tracing::Span::current().is_none() {
405                tracing::info!(
406                    target: PGWIRE_QUERY_LOG,
407                    status = "started",
408                );
409            }
410
411            let start = Instant::now();
412            let result = fut.await;
413            let elapsed = start.elapsed();
414
415            // Always log if an error occurs.
416            // Note: all messages will be processed through this code path, making it the
417            //       only necessary place to log errors.
418            if let Err(error) = &result {
419                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
420                    // For local debugging, we print the error with backtrace.
421                    // It's useful only when:
422                    // - no additional context is added to the error
423                    // - backtrace is captured in the error
424                    // - backtrace is not printed in the middle
425                    tracing::error!(error = ?error.as_report(), "error when process message");
426                } else {
427                    tracing::error!(error = %error.as_report(), "error when process message");
428                }
429            }
430
431            // Log to optionally-enabled target `PGWIRE_QUERY_LOG`.
432            // Only log if we're currently in a tracing span set in `span_for_msg`.
433            if !tracing::Span::current().is_none() {
434                tracing::info!(
435                    target: PGWIRE_QUERY_LOG,
436                    status = if result.is_ok() { "ok" } else { "err" },
437                    time = %format_args!("{}ms", elapsed.as_millis()),
438                );
439            }
440
441            result
442        };
443
444        // Tracing span.
445        let fut = fut.instrument(span);
446
447        // Execute the future and handle the error.
448        match fut.await {
449            Ok(()) => Some(()),
450            Err(e) => {
451                match e {
452                    PsqlError::IoError(io_err) => {
453                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
454                            return None;
455                        }
456                    }
457
458                    PsqlError::SslError(_) => {
459                        // For ssl error, because the stream has already been consumed, so there is
460                        // no way to write more message.
461                        return None;
462                    }
463
464                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
465                        self.stream
466                            .write_no_flush(BeMessage::ErrorResponse {
467                                error: &e,
468                                // At this time we're not in a session, use compact error message for
469                                // better alignment with Postgres' UI.
470                                pretty: false,
471                                severity: Some(Severity::Fatal),
472                            })
473                            .ok()?;
474                        let _ = self.stream.flush().await;
475                        return None;
476                    }
477
478                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
479                        self.stream
480                            .write_no_flush(BeMessage::ErrorResponse {
481                                error: &e,
482                                pretty: true,
483                                severity: None,
484                            })
485                            .ok()?;
486                        self.ready_for_query().ok()?;
487                    }
488
489                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
490                        self.stream
491                            .write_no_flush(BeMessage::ErrorResponse {
492                                error: &e,
493                                pretty: true,
494                                severity: None,
495                            })
496                            .ok()?;
497                        let _ = self.stream.flush().await;
498
499                        // 1. Catching the panic during message processing may leave the session in an
500                        // inconsistent state. We forcefully close the connection (then end the
501                        // session) here for safety.
502                        // 2. Idle in transaction timeout should also close the connection.
503                        return None;
504                    }
505
506                    PsqlError::Uncategorized(_)
507                    | PsqlError::ExtendedPrepareError(_)
508                    | PsqlError::ExtendedExecuteError(_) => {
509                        self.stream
510                            .write_no_flush(BeMessage::ErrorResponse {
511                                error: &e,
512                                pretty: true,
513                                severity: None,
514                            })
515                            .ok()?;
516                    }
517                }
518                let _ = self.stream.flush().await;
519                Some(())
520            }
521        }
522    }
523
524    async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
525        // Ignore util sync message.
526        if self.ignore_util_sync {
527            if let FeMessage::Sync = msg {
528            } else {
529                tracing::trace!("ignore message {:?} until sync.", msg);
530                return Ok(());
531            }
532        }
533
534        match msg {
535            FeMessage::Gss => self.process_gss_msg().await?,
536            FeMessage::Ssl => self.process_ssl_msg().await?,
537            FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
538            FeMessage::Password(msg) => self.process_password_msg(msg).await?,
539            FeMessage::Query(query_msg) => {
540                let sql = Arc::from(query_msg.get_sql()?);
541                // The process_query_msg can be slow. Release potential large FeQueryMessage early.
542                drop(query_msg);
543                self.process_query_msg(sql).await?
544            }
545            FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
546            FeMessage::Terminate => self.process_terminate(),
547            FeMessage::Parse(m) => {
548                if let Err(err) = self.process_parse_msg(m).await {
549                    self.ignore_util_sync = true;
550                    return Err(err);
551                }
552            }
553            FeMessage::Bind(m) => {
554                if let Err(err) = self.process_bind_msg(m) {
555                    self.ignore_util_sync = true;
556                    return Err(err);
557                }
558            }
559            FeMessage::Execute(m) => {
560                if let Err(err) = self.process_execute_msg(m).await {
561                    self.ignore_util_sync = true;
562                    return Err(err);
563                }
564            }
565            FeMessage::Describe(m) => {
566                if let Err(err) = self.process_describe_msg(m) {
567                    self.ignore_util_sync = true;
568                    return Err(err);
569                }
570            }
571            FeMessage::Sync => {
572                self.ignore_util_sync = false;
573                self.ready_for_query()?
574            }
575            FeMessage::Close(m) => {
576                if let Err(err) = self.process_close_msg(m) {
577                    self.ignore_util_sync = true;
578                    return Err(err);
579                }
580            }
581            FeMessage::Flush => {
582                if let Err(err) = self.stream.flush().await {
583                    self.ignore_util_sync = true;
584                    return Err(err.into());
585                }
586            }
587            FeMessage::HealthCheck => self.process_health_check(),
588            FeMessage::ServerThrottle(reason) => match reason {
589                ServerThrottleReason::TooLargeMessage => {
590                    return Err(PsqlError::ServerThrottle(format!(
591                        "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
592                        self.message_memory_manager.max_filter_bytes
593                    )));
594                }
595                ServerThrottleReason::TooManyMemoryUsage => {
596                    return Err(PsqlError::ServerThrottle(format!(
597                        "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
598                        self.message_memory_manager.max_running_bytes
599                    )));
600                }
601            },
602        }
603        self.stream.flush().await?;
604        Ok(())
605    }
606
607    pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
608        match self.state {
609            PgProtocolState::Startup => self
610                .stream
611                .read_startup()
612                .await
613                .map(|message: FeMessage| (message, None)),
614            PgProtocolState::Regular => {
615                self.stream.read_header().await?;
616                let guard = if let Some(ref header) = self.stream.read_header {
617                    let payload_len = std::cmp::max(header.payload_len, 0) as u64;
618                    let (reason, guard) = self.message_memory_manager.add(payload_len);
619                    if let Some(reason) = reason {
620                        // Release the memory ASAP.
621                        drop(guard);
622                        self.stream.skip_body().await?;
623                        return Ok((FeMessage::ServerThrottle(reason), None));
624                    }
625                    guard
626                } else {
627                    None
628                };
629                let message = self.stream.read_body().await?;
630                Ok((message, guard))
631            }
632        }
633    }
634
635    /// Writes a `ReadyForQuery` message to the client without flushing.
636    fn ready_for_query(&mut self) -> io::Result<()> {
637        self.stream.write_no_flush(BeMessage::ReadyForQuery(
638            self.session
639                .as_ref()
640                .map(|s| s.transaction_status())
641                .unwrap_or(TransactionStatus::Idle),
642        ))
643    }
644
645    async fn process_gss_msg(&mut self) -> PsqlResult<()> {
646        // We don't support GSSAPI, so we just say no gracefully.
647        self.stream.write(BeMessage::EncryptionResponseNo).await?;
648        Ok(())
649    }
650
651    async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
652        if let Some(context) = self.tls_context.as_ref() {
653            // If got and ssl context, say yes for ssl connection.
654            // Construct ssl stream and replace with current one.
655            self.stream.write(BeMessage::EncryptionResponseSsl).await?;
656            self.stream.upgrade_to_ssl(context).await?;
657        } else {
658            // If no, say no for encryption.
659            self.stream.write(BeMessage::EncryptionResponseNo).await?;
660        }
661
662        Ok(())
663    }
664
665    async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
666        // Check SSL enforcement: if SSL is enforced but connection is not using SSL, reject
667        if let Some(ref tls_config) = self.tls_config
668            && tls_config.enforce_ssl
669            && !self.stream.is_ssl_connection().await
670        {
671            return Err(PsqlError::StartupError(
672                "SSL connection is required but not established".into(),
673            ));
674        }
675
676        let db_name = msg
677            .config
678            .get("database")
679            .cloned()
680            .unwrap_or_else(|| "dev".to_owned());
681        let user_name = msg
682            .config
683            .get("user")
684            .cloned()
685            .unwrap_or_else(|| "root".to_owned());
686
687        let session = self
688            .session_mgr
689            .connect(&db_name, &user_name, self.peer_addr.clone())
690            .map_err(|e| PsqlError::StartupError(e.into()))?;
691
692        if let Some(options) = msg.config.get("options") {
693            for (key, value) in parse_options(options)? {
694                session
695                    .set_config(&key, value)
696                    .map_err(|e| PsqlError::StartupError(e.into()))?;
697            }
698        }
699        // dedicated `application_name` has higher priority than `options`
700        let application_name = msg.config.get("application_name");
701        if let Some(application_name) = application_name {
702            session
703                .set_config("application_name", application_name.clone())
704                .map_err(|e| PsqlError::StartupError(e.into()))?;
705        }
706
707        match session.user_authenticator() {
708            UserAuthenticator::None => {
709                self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
710
711                // Cancel request need this for identify and verification. According to postgres
712                // doc, it should be written to buffer after receive AuthenticationOk.
713                self.stream
714                    .write_no_flush(BeMessage::BackendKeyData(session.id()))?;
715
716                self.stream.write_no_flush(BeMessage::ParameterStatus(
717                    BeParameterStatusMessage::TimeZone(
718                        &session
719                            .get_config("timezone")
720                            .map_err(|e| PsqlError::StartupError(e.into()))?,
721                    ),
722                ))?;
723                self.stream
724                    .write_parameter_status_msg_no_flush(&ParameterStatus {
725                        application_name: application_name.cloned(),
726                    })?;
727                self.ready_for_query()?;
728            }
729            UserAuthenticator::ClearText(_)
730            | UserAuthenticator::OAuth { .. }
731            | UserAuthenticator::Ldap(..) => {
732                self.stream
733                    .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
734            }
735            UserAuthenticator::Md5WithSalt { salt, .. } => {
736                self.stream
737                    .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
738            }
739        }
740
741        self.session = Some(session);
742        self.state = PgProtocolState::Regular;
743        Ok(())
744    }
745
746    async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
747        let session = self.session.as_ref().unwrap();
748        let authenticator = session.user_authenticator();
749        authenticator.authenticate(&msg.password).await?;
750        self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
751        let timezone = session
752            .get_config("timezone")
753            .map_err(|e| PsqlError::StartupError(e.into()))?;
754        self.stream.write_no_flush(BeMessage::ParameterStatus(
755            BeParameterStatusMessage::TimeZone(&timezone),
756        ))?;
757        self.stream
758            .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
759        self.ready_for_query()?;
760        self.state = PgProtocolState::Regular;
761        Ok(())
762    }
763
764    fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
765        let session_id = (m.target_process_id, m.target_secret_key);
766        tracing::trace!("cancel query in session: {:?}", session_id);
767        self.session_mgr.cancel_queries_in_session(session_id);
768        self.session_mgr.cancel_creating_jobs_in_session(session_id);
769        self.is_terminate = true;
770        Ok(())
771    }
772
773    async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
774        let truncated_sql =
775            record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
776        let session = self.session.clone().unwrap();
777
778        session.check_idle_in_transaction_timeout()?;
779        // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
780        let _exec_context_guard = session.init_exec_context(truncated_sql.into());
781        self.inner_process_query_msg(sql, session.clone()).await
782    }
783
784    async fn inner_process_query_msg(
785        &mut self,
786        sql: Arc<str>,
787        session: Arc<SM::Session>,
788    ) -> PsqlResult<()> {
789        // Parse sql.
790        let stmts =
791            Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
792        // The following inner_process_query_msg_one_stmt can be slow. Release potential large String early.
793        drop(sql);
794        if stmts.is_empty() {
795            self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
796        }
797
798        // Execute multiple statements in simple query. KISS later.
799        for stmt in stmts {
800            self.inner_process_query_msg_one_stmt(stmt, session.clone())
801                .await?;
802        }
803        // Put this line inside the for loop above will lead to unfinished/stuck regress test...Not
804        // sure the reason.
805        self.ready_for_query()?;
806        Ok(())
807    }
808
809    async fn inner_process_query_msg_one_stmt(
810        &mut self,
811        stmt: Statement,
812        session: Arc<SM::Session>,
813    ) -> PsqlResult<()> {
814        let session = session.clone();
815
816        // execute query
817        let res = session.clone().run_one_query(stmt, Format::Text).await;
818
819        // Take all remaining notices (if any) and send them before `CommandComplete`.
820        while let Some(notice) = session.next_notice().now_or_never() {
821            self.stream
822                .write_no_flush(BeMessage::NoticeResponse(&notice))?;
823        }
824
825        let mut res = res.map_err(|e| PsqlError::SimpleQueryError(e.into()))?;
826
827        for notice in res.notices() {
828            self.stream
829                .write_no_flush(BeMessage::NoticeResponse(notice))?;
830        }
831
832        let status = res.status();
833        if let Some(ref application_name) = status.application_name {
834            self.stream.write_no_flush(BeMessage::ParameterStatus(
835                BeParameterStatusMessage::ApplicationName(application_name),
836            ))?;
837        }
838
839        if res.is_copy_query_to_stdout() {
840            self.stream
841                .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
842            let mut count = 0;
843            while let Some(row_set) = res.values_stream().next().await {
844                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
845                for row in row_set {
846                    self.stream.write_no_flush(BeMessage::CopyData(&row))?;
847                    count += 1;
848                }
849            }
850
851            self.stream.write_no_flush(BeMessage::CopyDone)?;
852
853            // Run the callback before sending the `CommandComplete` message.
854            res.run_callback().await?;
855
856            self.stream
857                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
858                    stmt_type: res.stmt_type(),
859                    rows_cnt: count,
860                }))?;
861        } else if res.is_query() {
862            self.stream
863                .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;
864
865            let mut rows_cnt = 0;
866
867            while let Some(row_set) = res.values_stream().next().await {
868                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
869                for row in row_set {
870                    self.stream.write_no_flush(BeMessage::DataRow(&row))?;
871                    rows_cnt += 1;
872                }
873            }
874
875            // Run the callback before sending the `CommandComplete` message.
876            res.run_callback().await?;
877
878            self.stream
879                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
880                    stmt_type: res.stmt_type(),
881                    rows_cnt,
882                }))?;
883        } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
884            let first_row_set = res.values_stream().next().await;
885            let first_row_set = match first_row_set {
886                None => {
887                    return Err(PsqlError::Uncategorized(
888                        anyhow::anyhow!("no affected rows in output").into(),
889                    ));
890                }
891                Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
892            };
893            let affected_rows_str = first_row_set[0].values()[0]
894                .as_ref()
895                .expect("compute node should return affected rows in output");
896
897            assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
898            let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
899                .unwrap()
900                .parse()
901                .unwrap_or_default();
902
903            // Run the callback before sending the `CommandComplete` message.
904            res.run_callback().await?;
905
906            self.stream
907                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
908                    stmt_type: res.stmt_type(),
909                    rows_cnt: affected_rows_cnt,
910                }))?;
911        } else {
912            // Run the callback before sending the `CommandComplete` message.
913            res.run_callback().await?;
914
915            self.stream
916                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
917                    stmt_type: res.stmt_type(),
918                    rows_cnt: 0,
919                }))?;
920        }
921
922        Ok(())
923    }
924
925    fn process_terminate(&mut self) {
926        self.is_terminate = true;
927    }
928
929    fn process_health_check(&mut self) {
930        tracing::debug!("health check");
931        self.is_terminate = true;
932    }
933
934    async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
935        let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
936        record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
937        let session = self.session.clone().unwrap();
938        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
939        let type_ids = std::mem::take(&mut msg.type_ids);
940        // The inner_process_parse_msg can be slow. Release potential large FeParseMessage early.
941        drop(msg);
942        self.inner_process_parse_msg(session, sql, statement_name, type_ids)
943            .await?;
944        Ok(())
945    }
946
947    async fn inner_process_parse_msg(
948        &mut self,
949        session: Arc<SM::Session>,
950        sql: Arc<str>,
951        statement_name: String,
952        type_ids: Vec<i32>,
953    ) -> PsqlResult<()> {
954        if statement_name.is_empty() {
955            // Remove the unnamed prepare statement first, in case the unsupported sql binds a wrong
956            // prepare statement.
957            self.unnamed_prepare_statement.take();
958        } else if self.prepare_statement_store.contains_key(&statement_name) {
959            return Err(PsqlError::ExtendedPrepareError(
960                "Duplicated statement name".into(),
961            ));
962        }
963
964        let stmt = {
965            let stmts = Parser::parse_sql(&sql)
966                .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
967            drop(sql);
968            if stmts.len() > 1 {
969                return Err(PsqlError::ExtendedPrepareError(
970                    "Only one statement is allowed in extended query mode".into(),
971                ));
972            }
973
974            stmts.into_iter().next()
975        };
976
977        let param_types: Vec<Option<DataType>> = type_ids
978            .iter()
979            .map(|&id| {
980                // 0 means unspecified type
981                // ref: https://www.postgresql.org/docs/15/protocol-message-formats.html#:~:text=Placing%20a%20zero%20here%20is%20equivalent%20to%20leaving%20the%20type%20unspecified.
982                if id == 0 {
983                    Ok(None)
984                } else {
985                    DataType::from_oid(id)
986                        .map(Some)
987                        .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
988                }
989            })
990            .try_collect()?;
991
992        let prepare_statement = session
993            .parse(stmt, param_types)
994            .await
995            .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))?;
996
997        if statement_name.is_empty() {
998            self.unnamed_prepare_statement.replace(prepare_statement);
999        } else {
1000            self.prepare_statement_store
1001                .insert(statement_name.clone(), prepare_statement);
1002        }
1003
1004        self.statement_portal_dependency
1005            .entry(statement_name)
1006            .or_default()
1007            .clear();
1008
1009        self.stream.write_no_flush(BeMessage::ParseComplete)?;
1010        Ok(())
1011    }
1012
1013    fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
1014        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
1015        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1016        let session = self.session.clone().unwrap();
1017
1018        if self.portal_store.contains_key(&portal_name) {
1019            return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
1020        }
1021
1022        let prepare_statement = self.get_statement(&statement_name)?;
1023
1024        let result_formats = msg
1025            .result_format_codes
1026            .iter()
1027            .map(|&format_code| Format::from_i16(format_code))
1028            .try_collect()?;
1029        let param_formats = msg
1030            .param_format_codes
1031            .iter()
1032            .map(|&format_code| Format::from_i16(format_code))
1033            .try_collect()?;
1034
1035        let portal = session
1036            .bind(prepare_statement, msg.params, param_formats, result_formats)
1037            .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1038
1039        if portal_name.is_empty() {
1040            self.result_cache.remove(&portal_name);
1041            self.unnamed_portal.replace(portal);
1042        } else {
1043            assert!(
1044                !self.result_cache.contains_key(&portal_name),
1045                "Named portal never can be overridden."
1046            );
1047            self.portal_store.insert(portal_name.clone(), portal);
1048        }
1049
1050        self.statement_portal_dependency
1051            .get_mut(&statement_name)
1052            .unwrap()
1053            .push(portal_name);
1054
1055        self.stream.write_no_flush(BeMessage::BindComplete)?;
1056        Ok(())
1057    }
1058
1059    async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1060        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1061        let row_max = msg.max_rows as usize;
1062        drop(msg);
1063        let session = self.session.clone().unwrap();
1064
1065        match self.result_cache.remove(&portal_name) {
1066            Some(mut result_cache) => {
1067                assert!(self.portal_store.contains_key(&portal_name));
1068
1069                let is_consume_completed =
1070                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1071
1072                if !is_consume_completed {
1073                    self.result_cache.insert(portal_name, result_cache);
1074                }
1075            }
1076            _ => {
1077                let portal = self.get_portal(&portal_name)?;
1078                let sql = format!("{}", portal);
1079                let truncated_sql =
1080                    record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
1081                drop(sql);
1082
1083                session.check_idle_in_transaction_timeout()?;
1084                // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
1085                let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1086                let result = session.clone().execute(portal).await;
1087
1088                let pg_response = result.map_err(|e| PsqlError::ExtendedExecuteError(e.into()))?;
1089                let mut result_cache = ResultCache::new(pg_response);
1090                let is_consume_completed =
1091                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1092                if !is_consume_completed {
1093                    self.result_cache.insert(portal_name, result_cache);
1094                }
1095            }
1096        }
1097
1098        Ok(())
1099    }
1100
1101    fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1102        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1103        let session = self.session.clone().unwrap();
1104        //  b'S' => Statement
1105        //  b'P' => Portal
1106
1107        assert!(msg.kind == b'S' || msg.kind == b'P');
1108        if msg.kind == b'S' {
1109            let prepare_statement = self.get_statement(&name)?;
1110
1111            let (param_types, row_descriptions) = self
1112                .session
1113                .clone()
1114                .unwrap()
1115                .describe_statement(prepare_statement)
1116                .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1117            self.stream.write_no_flush(BeMessage::ParameterDescription(
1118                &param_types.iter().map(|t| t.to_oid()).collect_vec(),
1119            ))?;
1120
1121            if row_descriptions.is_empty() {
1122                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1123                // return NoData message if the statement is not a query.
1124                self.stream.write_no_flush(BeMessage::NoData)?;
1125            } else {
1126                self.stream
1127                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1128            }
1129        } else if msg.kind == b'P' {
1130            let portal = self.get_portal(&name)?;
1131
1132            let row_descriptions = session
1133                .describe_portal(portal)
1134                .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1135
1136            if row_descriptions.is_empty() {
1137                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1138                // return NoData message if the statement is not a query.
1139                self.stream.write_no_flush(BeMessage::NoData)?;
1140            } else {
1141                self.stream
1142                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1143            }
1144        }
1145        Ok(())
1146    }
1147
1148    fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1149        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1150        assert!(msg.kind == b'S' || msg.kind == b'P');
1151        if msg.kind == b'S' {
1152            if name.is_empty() {
1153                self.unnamed_prepare_statement = None;
1154            } else {
1155                self.prepare_statement_store.remove(&name);
1156            }
1157            for portal_name in self
1158                .statement_portal_dependency
1159                .remove(&name)
1160                .unwrap_or_default()
1161            {
1162                self.remove_portal(&portal_name);
1163            }
1164        } else if msg.kind == b'P' {
1165            self.remove_portal(&name);
1166        }
1167        self.stream.write_no_flush(BeMessage::CloseComplete)?;
1168        Ok(())
1169    }
1170
1171    fn remove_portal(&mut self, portal_name: &str) {
1172        if portal_name.is_empty() {
1173            self.unnamed_portal = None;
1174        } else {
1175            self.portal_store.remove(portal_name);
1176        }
1177        self.result_cache.remove(portal_name);
1178    }
1179
1180    fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1181        if portal_name.is_empty() {
1182            Ok(self
1183                .unnamed_portal
1184                .as_ref()
1185                .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1186                .clone())
1187        } else {
1188            Ok(self
1189                .portal_store
1190                .get(portal_name)
1191                .ok_or_else(|| {
1192                    PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1193                })?
1194                .clone())
1195        }
1196    }
1197
1198    fn get_statement(
1199        &self,
1200        statement_name: &str,
1201    ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1202        if statement_name.is_empty() {
1203            Ok(self
1204                .unnamed_prepare_statement
1205                .as_ref()
1206                .ok_or_else(|| {
1207                    PsqlError::Uncategorized("unnamed prepare statement not found".into())
1208                })?
1209                .clone())
1210        } else {
1211            Ok(self
1212                .prepare_statement_store
1213                .get(statement_name)
1214                .ok_or_else(|| {
1215                    PsqlError::Uncategorized(
1216                        format!("Prepare statement {} not found", statement_name).into(),
1217                    )
1218                })?
1219                .clone())
1220        }
1221    }
1222}
1223
1224enum PgStreamInner<S> {
1225    /// Used for the intermediate state when converting from unencrypted to ssl stream.
1226    Placeholder,
1227    /// An unencrypted stream.
1228    Unencrypted(S),
1229    /// An ssl stream.
1230    Ssl(SslStream<S>),
1231}
1232
1233/// Trait for a byte stream that can be used for pg protocol.
1234pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1235impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1236
1237/// Wraps a byte stream and read/write pg messages.
1238///
1239/// Cloning a `PgStream` will share the same stream but a fresh & independent write buffer,
1240/// so that it can be used to write messages concurrently without interference.
1241pub struct PgStream<S> {
1242    /// The underlying stream.
1243    stream: Arc<Mutex<PgStreamInner<S>>>,
1244    /// Write into buffer before flush to stream.
1245    write_buf: BytesMut,
1246    read_header: Option<FeMessageHeader>,
1247}
1248
1249impl<S> PgStream<S> {
1250    /// Create a new `PgStream` with the given stream and default write buffer capacity.
1251    pub fn new(stream: S) -> Self {
1252        const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1253
1254        Self {
1255            stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1256            write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1257            read_header: None,
1258        }
1259    }
1260
1261    /// Check if the current connection is using SSL
1262    async fn is_ssl_connection(&self) -> bool {
1263        let stream = self.stream.lock().await;
1264        matches!(*stream, PgStreamInner::Ssl(_))
1265    }
1266}
1267
1268impl<S> Clone for PgStream<S> {
1269    fn clone(&self) -> Self {
1270        Self {
1271            stream: Arc::clone(&self.stream),
1272            write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1273            read_header: self.read_header.clone(),
1274        }
1275    }
1276}
1277
/// Parameters reported to the client through `ParameterStatus` messages.
///
/// At present there is a hard-wired set of parameters for which `ParameterStatus` will
/// be generated. They are:
///
///  * `server_version`
///  * `server_encoding`
///  * `client_encoding`
///  * `application_name`
///  * `is_superuser`
///  * `session_authorization`
///  * `DateStyle`
///  * `IntervalStyle`
///  * `TimeZone`
///  * `integer_datetimes`
///  * `standard_conforming_strings`
///
/// See: <https://www.postgresql.org/docs/9.2/static/protocol-flow.html#PROTOCOL-ASYNC>.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// Value to report for `application_name`; nothing is reported when `None`.
    pub application_name: Option<String>,
}
1298
1299impl<S> PgStream<S>
1300where
1301    S: PgByteStream,
1302{
1303    async fn read_startup(&mut self) -> io::Result<FeMessage> {
1304        let mut stream = self.stream.lock().await;
1305        match &mut *stream {
1306            PgStreamInner::Placeholder => unreachable!(),
1307            PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1308            PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1309        }
1310    }
1311
1312    async fn read_header(&mut self) -> io::Result<()> {
1313        let mut stream = self.stream.lock().await;
1314        match &mut *stream {
1315            PgStreamInner::Placeholder => unreachable!(),
1316            PgStreamInner::Unencrypted(stream) => {
1317                self.read_header = Some(FeMessage::read_header(stream).await?);
1318                Ok(())
1319            }
1320            PgStreamInner::Ssl(ssl_stream) => {
1321                self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1322                Ok(())
1323            }
1324        }
1325    }
1326
1327    async fn read_body(&mut self) -> io::Result<FeMessage> {
1328        let mut stream = self.stream.lock().await;
1329        let header = self
1330            .read_header
1331            .take()
1332            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1333        match &mut *stream {
1334            PgStreamInner::Placeholder => unreachable!(),
1335            PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1336            PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1337        }
1338    }
1339
1340    async fn skip_body(&mut self) -> io::Result<()> {
1341        let mut stream = self.stream.lock().await;
1342        let header = self
1343            .read_header
1344            .take()
1345            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1346        match &mut *stream {
1347            PgStreamInner::Placeholder => unreachable!(),
1348            PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1349            PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1350        }
1351    }
1352
1353    fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1354        self.write_no_flush(BeMessage::ParameterStatus(
1355            BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1356        ))?;
1357        self.write_no_flush(BeMessage::ParameterStatus(
1358            BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1359        ))?;
1360        self.write_no_flush(BeMessage::ParameterStatus(
1361            BeParameterStatusMessage::ServerVersion(PG_VERSION),
1362        ))?;
1363        if let Some(application_name) = &status.application_name {
1364            self.write_no_flush(BeMessage::ParameterStatus(
1365                BeParameterStatusMessage::ApplicationName(application_name),
1366            ))?;
1367        }
1368        Ok(())
1369    }
1370
1371    pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1372        BeMessage::write(&mut self.write_buf, message)
1373    }
1374
1375    async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1376        self.write_no_flush(message)?;
1377        self.flush().await?;
1378        Ok(())
1379    }
1380
1381    async fn flush(&mut self) -> io::Result<()> {
1382        let mut stream = self.stream.lock().await;
1383        match &mut *stream {
1384            PgStreamInner::Placeholder => unreachable!(),
1385            PgStreamInner::Unencrypted(stream) => {
1386                stream.write_all(&self.write_buf).await?;
1387                stream.flush().await?;
1388            }
1389            PgStreamInner::Ssl(ssl_stream) => {
1390                ssl_stream.write_all(&self.write_buf).await?;
1391                ssl_stream.flush().await?;
1392            }
1393        }
1394        self.write_buf.clear();
1395        Ok(())
1396    }
1397}
1398
1399impl<S> PgStream<S>
1400where
1401    S: PgByteStream,
1402{
1403    /// Convert the underlying stream to ssl stream based on the given context.
1404    async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1405        let mut stream = self.stream.lock().await;
1406
1407        match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1408            PgStreamInner::Unencrypted(unencrypted_stream) => {
1409                let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1410                let mut ssl_stream =
1411                    tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1412
1413                if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1414                    tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1415                    let _ = ssl_stream.shutdown().await;
1416                    return Err(e.into());
1417                }
1418
1419                *stream = PgStreamInner::Ssl(ssl_stream);
1420            }
1421            PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1422            PgStreamInner::Placeholder => unreachable!(),
1423        }
1424
1425        Ok(())
1426    }
1427}
1428
1429fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1430    let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1431
1432    let key_path = &tls_config.key;
1433    let cert_path = &tls_config.cert;
1434
1435    // Build ssl acceptor according to the config.
1436    // Now we set every verify to true.
1437    acceptor
1438        .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1439        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1440    acceptor
1441        .set_ca_file(cert_path)
1442        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1443    acceptor
1444        .set_certificate_chain_file(cert_path)
1445        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1446    let acceptor = acceptor.build();
1447
1448    Ok(acceptor.into_context())
1449}
1450
/// Formatting helpers that cap the length of `Debug` / `Display` output.
pub mod truncated_fmt {
    use std::fmt::*;

    /// A [`Write`] adapter that forwards at most `remaining` bytes to the wrapped
    /// formatter. The first chunk that does not fit is cut at a char boundary and
    /// followed by a truncation marker; everything after that is dropped.
    struct TruncatedFormatter<'a, 'b> {
        /// Number of bytes that may still be written.
        remaining: usize,
        /// Set once the truncation marker has been written.
        finished: bool,
        f: &'a mut Formatter<'b>,
    }

    impl Write for TruncatedFormatter<'_, '_> {
        fn write_str(&mut self, s: &str) -> Result {
            if self.finished {
                return Ok(());
            }

            if self.remaining < s.len() {
                // Back off to the closest char boundary at or below the budget, so that a
                // multi-byte character is never split. Offset 0 is always a boundary, so
                // the loop terminates.
                let mut cut = self.remaining;
                while !s.is_char_boundary(cut) {
                    cut -= 1;
                }
                self.f.write_str(&s[..cut])?;
                // The marker reports the byte length of the chunk that was cut.
                write!(self.f, "...(truncated,{})", s.len())?;
                self.finished = true; // so that ...(truncated) is printed exactly once
            } else {
                self.f.write_str(s)?;
                self.remaining -= s.len();
            }
            Ok(())
        }
    }

    /// Wrapper that formats the referenced value like the value itself, but writes at
    /// most `self.1` bytes of it before appending a truncation marker.
    ///
    /// `T` may be unsized, so a `&str` can be wrapped directly.
    pub struct TruncatedFmt<'a, T: ?Sized>(pub &'a T, pub usize);

    impl<T> Debug for TruncatedFmt<'_, T>
    where
        T: Debug + ?Sized,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            TruncatedFormatter {
                remaining: self.1,
                finished: false,
                f,
            }
            .write_fmt(format_args!("{:?}", self.0))
        }
    }

    impl<T> Display for TruncatedFmt<'_, T>
    where
        T: Display + ?Sized,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            TruncatedFormatter {
                remaining: self.1,
                finished: false,
                f,
            }
            .write_fmt(format_args!("{}", self.0))
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            assert_eq!(
                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
                "select '...(truncated,14)",
            );
        }

        #[test]
        fn test_trunc_unsized_str() {
            assert_eq!(
                format!("{}", TruncatedFmt("hello world", 5)),
                "hello...(truncated,11)",
            );
        }
    }
}
1522
1523/// Handle `options` in `StartupMessage` from client
1524///
1525/// It is like shell arguments but only respects backslash-escape and space;
1526/// quotes have no special meaning and are handled literally.
1527///
1528/// PostgreSQL allows both `-c key=value` and `--key=value`.
1529///
1530/// `key-name` is normalized as `key_name`.
1531///
1532/// * <https://github.com/postgres/postgres/blob/REL_18_1/src/backend/utils/init/postinit.c#L487>
1533/// * <https://github.com/postgres/postgres/blob/REL_18_1/src/backend/tcop/postgres.c#L3866>
1534/// * <https://github.com/postgres/postgres/blob/REL_18_1/src/backend/utils/misc/guc.c#L6361>
1535fn parse_options(options: &str) -> PsqlResult<Vec<(String, String)>> {
1536    let mut args = Vec::new();
1537    let mut current_arg = String::new();
1538    let mut chars = options.chars().peekable();
1539
1540    while let Some(c) = chars.next() {
1541        if c == '\\' {
1542            if let Some(next_c) = chars.next() {
1543                current_arg.push(next_c);
1544            }
1545        } else if c.is_ascii_whitespace() {
1546            if !current_arg.is_empty() {
1547                args.push(std::mem::take(&mut current_arg));
1548            }
1549        } else {
1550            current_arg.push(c);
1551        }
1552    }
1553    if !current_arg.is_empty() {
1554        args.push(current_arg);
1555    }
1556
1557    let mut args_iter = args.into_iter();
1558    let mut config = Vec::new();
1559
1560    while let Some(arg) = args_iter.next() {
1561        if arg == "-c" {
1562            if let Some(config_str) = args_iter.next() {
1563                if let Some((key, value)) = config_str.split_once('=') {
1564                    let key = key.replace("-", "_");
1565                    config.push((key, value.to_owned()));
1566                } else {
1567                    return Err(PsqlError::StartupError(
1568                        format!("invalid config format: {}", config_str).into(),
1569                    ));
1570                }
1571            } else {
1572                return Err(PsqlError::StartupError("missing argument for -c".into()));
1573            }
1574        } else if let Some(config_str) = arg.strip_prefix("--") {
1575            if let Some((key, value)) = config_str.split_once('=') {
1576                let key = key.replace("-", "_");
1577                config.push((key, value.to_owned()));
1578            } else {
1579                return Err(PsqlError::StartupError(
1580                    format!("invalid config format: {}", config_str).into(),
1581                ));
1582            }
1583        } else {
1584            tracing::warn!(
1585                arg,
1586                "ignoring unrecognized option for backward compatibility"
1587            );
1588        }
1589    }
1590    Ok(config)
1591}
1592
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Values of the redacted keywords must be replaced both in the `WITH` clause and in the
    /// `ENCODE` options, while every other option is printed unchanged.
    #[test]
    fn test_redact_parsable_sql() {
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let keywords = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()]));
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])";

        assert_eq!(redact_sql(sql, keywords), expected);
    }

    /// Table-driven check of `parse_options`: inputs that must parse to an exact list of
    /// pairs, followed by inputs that must be rejected.
    #[test]
    fn test_parse_options() {
        let accepted: Vec<(&str, Vec<(&str, &str)>)> = vec![
            ("", vec![]),
            ("-c a=1 -c b=2", vec![("a", "1"), ("b", "2")]),
            ("-c   key=value", vec![("key", "value")]),
            // Custom parser treats quotes as normal characters, so they are included in value
            ("-c key='value'", vec![("key", "'value'")]),
            // Test backslash escaping for spaces (standard Postgres way)
            (
                r#"-c key=value\ with\ spaces"#,
                vec![("key", "value with spaces")],
            ),
            (
                r#"-c search_path=my\ schema"#,
                vec![("search_path", "my schema")],
            ),
            ("--foo=bar", vec![("foo", "bar")]),
            (r#"--foo=bar\ baz"#, vec![("foo", "bar baz")]),
            ("-c a=1 --b=2", vec![("a", "1"), ("b", "2")]),
            // Unpaired trailing backslash is silently dropped, same as PostgreSQL
            (r#"-c a=b\"#, vec![("a", "b")]),
        ];
        for (input, pairs) in accepted {
            let expected: Vec<(String, String)> = pairs
                .into_iter()
                .map(|(key, value)| (key.to_owned(), value.to_owned()))
                .collect();
            assert_eq!(
                parse_options(input).unwrap(),
                expected,
                "input: {input:?}"
            );
        }

        let rejected = [
            "-c",     // missing argument after -c
            "-c foo", // missing =
            "--foo",  // missing = in -- option
        ];
        for input in rejected {
            assert!(parse_options(input).is_err(), "input: {input:?}");
        }
    }
}