pgwire/
pg_protocol.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::panic::FutureCatchUnwindExt;
33use risingwave_common::util::query_log::*;
34use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
35use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
36use risingwave_sqlparser::parser::Parser;
37use thiserror_ext::AsReport;
38use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
39use tokio::sync::Mutex;
40use tokio_openssl::SslStream;
41use tracing::Instrument;
42
43use crate::error::{PsqlError, PsqlResult};
44use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
45use crate::net::AddressRef;
46use crate::pg_extended::ResultCache;
47use crate::pg_message::{
48    BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
49    FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
50    FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
51};
52use crate::pg_server::{Session, SessionManager, UserAuthenticator};
53use crate::types::Format;
54
/// Truncates query log if it's longer than `RW_QUERY_LOG_TRUNCATE_LEN`, to avoid log file being too
/// large.
///
/// The limit is read once, on first use, from the `RW_QUERY_LOG_TRUNCATE_LEN` environment
/// variable. If the variable is unset, not valid Unicode, or not a valid `usize`, it falls back
/// to 65536 in builds with debug assertions and 1024 otherwise.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        // Parse once; an unparsable value is treated the same as an unset one.
        .and_then(|len| len.parse::<usize>().ok())
        .unwrap_or(if cfg!(debug_assertions) { 65536 } else { 1024 })
});
68
tokio::task_local! {
    /// The current session. Concrete type is erased for different session implementations.
    ///
    /// `PgProtocol::do_process` scopes this to a `Weak` reference to the connection's session
    /// (when one exists) while a message is being processed, so code running within that
    /// processing can reach the session without it being passed explicitly. Holding a `Weak`
    /// presumably avoids the task-local keeping the session alive — NOTE(review): confirm.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
73
/// The state machine for each psql connection.
/// Read pg messages from tcp stream and write results back.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Used for write/read pg messages.
    stream: PgStream<S>,
    /// Current phase of the connection (startup or regular); decides how messages are read.
    state: PgProtocolState,
    /// Set by `Terminate` and health-check messages; `process` reports termination when set.
    is_terminate: bool,

    /// Creates the session on startup, ends it when the protocol is dropped, and serves cancel
    /// requests.
    session_mgr: Arc<SM>,
    /// The session created while handling the startup message; `None` until then.
    session: Option<Arc<SM::Session>>,

    /// Cached result streams of the extended query protocol, keyed by name (presumably portal
    /// names — NOTE(review): confirm against the extended-query handlers).
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    /// The unnamed prepared statement of the extended query protocol, if any.
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    /// Named prepared statements of the extended query protocol.
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    /// The unnamed portal of the extended query protocol, if any.
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    /// Named portals of the extended query protocol.
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    // Used to store the dependency of portal and prepare statement.
    // When we close a prepare statement, we need to close all the portals that depend on it.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    // Used to upgrade the stream when the client sends an SSL request.
    // If None, SSL requests are declined (`EncryptionResponseNo`) and the connection stays
    // unencrypted.
    tls_context: Option<SslContext>,

    // Used in extended query protocol. When an extended-query message fails, all following
    // messages are ignored until the next `Sync` message.
    ignore_util_sync: bool,

    // Client address, passed to the session manager when creating the session.
    peer_addr: AddressRef,

    /// Option keywords whose values are redacted before SQL is recorded in tracing spans.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts memory of incoming messages and throttles oversized ones or excessive total
    /// usage.
    message_memory_manager: MessageMemoryManagerRef,
}
113
114/// Configures TLS encryption for connections.
115#[derive(Debug, Clone)]
116pub struct TlsConfig {
117    /// The path to the TLS certificate.
118    pub cert: String,
119    /// The path to the TLS key.
120    pub key: String,
121}
122
123impl TlsConfig {
124    pub fn new_default() -> Option<Self> {
125        let cert = std::env::var("RW_SSL_CERT").ok()?;
126        let key = std::env::var("RW_SSL_KEY").ok()?;
127        tracing::info!("RW_SSL_CERT={}, RW_SSL_KEY={}", cert, key);
128        Some(Self { cert, key })
129    }
130}
131
132impl<S, SM> Drop for PgProtocol<S, SM>
133where
134    SM: SessionManager,
135{
136    fn drop(&mut self) {
137        if let Some(session) = &self.session {
138            // Clear the session in session manager.
139            self.session_mgr.end_session(session);
140        }
141    }
142}
143
/// Connection phases, traversed from top to bottom.
enum PgProtocolState {
    /// Initial phase: messages are read with `PgStream::read_startup`, without memory
    /// accounting.
    Startup,
    /// Entered once the startup message has been handled: messages are read as header + body,
    /// and their payload is charged to the message memory manager.
    Regular,
}
149
150/// Truncate 0 from C string in Bytes and stringify it (returns slice, no allocations).
151///
152/// PG protocol strings are always C strings.
153pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
154    let without_null = if b.last() == Some(&0) {
155        &b[..b.len() - 1]
156    } else {
157        &b[..]
158    };
159    std::str::from_utf8(without_null)
160}
161
162/// Record `sql` in the current tracing span.
163fn record_sql_in_span(
164    sql: &str,
165    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
166) -> String {
167    let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
168        && !keywords.is_empty()
169    {
170        redact_sql(sql, keywords)
171    } else {
172        sql.to_owned()
173    };
174    let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
175    tracing::Span::current().record("sql", tracing::field::display(&truncated));
176    truncated.to_string()
177}
178
179/// Redacts SQL options. Data in DML is not redacted.
180fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
181    match Parser::parse_sql(sql) {
182        Ok(sqls) => sqls
183            .into_iter()
184            .map(|sql| sql.to_redacted_string(keywords.clone()))
185            .join(";"),
186        Err(_) => sql.to_owned(),
187    }
188}
189
/// Per-connection settings handed to `PgProtocol::new`.
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS certificate/key paths. If `None`, or if building an SSL context from them fails,
    /// SSL requests on the connection are declined.
    pub tls_config: Option<TlsConfig>,
    /// Option keywords whose values are redacted before SQL is recorded in tracing spans.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts memory of incoming messages; used to throttle oversized messages or excessive
    /// total usage.
    pub message_memory_manager: MessageMemoryManagerRef,
}
196
197impl<S, SM> PgProtocol<S, SM>
198where
199    S: PgByteStream,
200    SM: SessionManager,
201{
202    pub fn new(
203        stream: S,
204        session_mgr: Arc<SM>,
205        peer_addr: AddressRef,
206        context: ConnectionContext,
207    ) -> Self {
208        let ConnectionContext {
209            tls_config,
210            redact_sql_option_keywords,
211            message_memory_manager,
212        } = context;
213        Self {
214            stream: PgStream::new(stream),
215            is_terminate: false,
216            state: PgProtocolState::Startup,
217            session_mgr,
218            session: None,
219            tls_context: tls_config
220                .as_ref()
221                .and_then(|e| build_ssl_ctx_from_config(e).ok()),
222            result_cache: Default::default(),
223            unnamed_prepare_statement: Default::default(),
224            prepare_statement_store: Default::default(),
225            unnamed_portal: Default::default(),
226            portal_store: Default::default(),
227            statement_portal_dependency: Default::default(),
228            ignore_util_sync: false,
229            peer_addr,
230            redact_sql_option_keywords,
231            message_memory_manager,
232        }
233    }
234
    /// Run the protocol to serve the connection.
    ///
    /// Each iteration reads one frontend message and processes it. The loop ends when reading
    /// fails or when processing reports that the connection is terminated. Once a session
    /// exists, a notice-forwarding future is polled concurrently with message processing.
    pub async fn run(&mut self) {
        // Created lazily once a session exists, then kept across loop iterations.
        let mut notice_fut = None;

        loop {
            // Once a session is present, create a future to subscribe and send notices asynchronously.
            if notice_fut.is_none()
                && let Some(session) = self.session.clone()
            {
                // The future owns a clone of the stream, so it does not borrow `self`.
                let mut stream = self.stream.clone();
                notice_fut = Some(Box::pin(async move {
                    loop {
                        let notice = session.next_notice().await;
                        // A failed write is only logged; forwarding continues with the next notice.
                        if let Err(e) = stream.write(&BeMessage::NoticeResponse(&notice)).await {
                            tracing::error!(error = %e.as_report(), notice, "failed to send notice");
                        }
                    }
                }));
            }

            // Read and process messages.
            let process = std::pin::pin!(async {
                // Binding the guard to a named variable (not `_`) keeps the message's memory
                // accounted until processing below finishes.
                let (msg, _memory_guard) = match self.read_message().await {
                    Ok(msg) => msg,
                    Err(e) => {
                        tracing::error!(error = %e.as_report(), "error when reading message");
                        return true; // terminate the connection
                    }
                };
                tracing::trace!(?msg, "received message");
                self.process(msg).await
            });

            // The notice future loops forever, so only `process` can complete the `select!`.
            let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
                tokio::select! {
                    _ = notice_fut => unreachable!(),
                    terminated = process => terminated,
                }
            } else {
                process.await
            };

            if terminated {
                break;
            }
        }
    }
282
283    /// Processes one message. Returns true if the connection is terminated.
284    pub async fn process(&mut self, msg: FeMessage) -> bool {
285        self.do_process(msg).await.is_none() || self.is_terminate
286    }
287
288    /// The root tracing span for processing a message. The target of the span is
289    /// [`PGWIRE_ROOT_SPAN_TARGET`].
290    ///
291    /// This is used to provide context for the (slow) query logs and traces.
292    ///
293    /// The span is only effective if there's a current session and the message is
294    /// query-related. Otherwise, `Span::none()` is returned.
295    fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
296        let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
297            return tracing::Span::none();
298        };
299
300        let mode = match msg {
301            FeMessage::Query(_) => "simple query",
302            FeMessage::Parse(_) => "extended query parse",
303            FeMessage::Execute(_) => "extended query execute",
304            _ => return tracing::Span::none(),
305        };
306
307        tracing::info_span!(
308            target: PGWIRE_ROOT_SPAN_TARGET,
309            "handle_query",
310            mode,
311            session_id,
312            sql = tracing::field::Empty, // record SQL later in each `process` call
313        )
314    }
315
    /// Processes one message with session context, panic catching, slow-query and query
    /// logging, and a tracing span layered around `do_process_inner`. Errors are then turned
    /// into client responses.
    ///
    /// Return type `Option<()>` is essentially a bool, but allows `?` for early return.
    /// - `None` means to terminate the current connection
    /// - `Some(())` means to continue processing the next message
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        // Processing the message itself.
        //
        // Note: pin the future to avoid stack overflow as we'll wrap it multiple times
        // in the following code.
        let fut = Box::pin(self.do_process_inner(msg));

        // Set the current session as the context when processing the message, if exists.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Catch unwind: a panic is converted into `PsqlError::Panic` carrying the panic message.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Slow query log.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            let mut elapsed = Duration::ZERO;

            // Report the SQL in the log periodically if the query is slow.
            // Each timeout only logs; the same pinned future is polled again, never cancelled.
            loop {
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Query log.
        let fut = async move {
            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            // Always log if an error occurs.
            // Note: all messages will be processed through this code path, making it the
            //       only necessary place to log errors.
            if let Err(error) = &result {
                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
                    // For local debugging, we print the error with backtrace.
                    // It's useful only when:
                    // - no additional context is added to the error
                    // - backtrace is captured in the error
                    // - backtrace is not printed in the middle
                    tracing::error!(error = ?error.as_report(), "error when process message");
                } else {
                    tracing::error!(error = %error.as_report(), "error when process message");
                }
            }

            // Log to optionally-enabled target `PGWIRE_QUERY_LOG`.
            // Only log if we're currently in a tracing span set in `root_span_for_msg`.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        // Tracing span.
        let fut = fut.instrument(span);

        // Execute the future and handle the error.
        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    // EOF closes the connection. Other I/O errors are not reported to the
                    // client; they fall through to the flush below and the connection stays open.
                    PsqlError::IoError(io_err) => {
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                    }

                    PsqlError::SslError(_) => {
                        // For ssl error, because the stream has already been consumed, so there is
                        // no way to write more message.
                        return None;
                    }

                    // Failed startup/authentication: report the error, then close.
                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    // Simple-query failures end the query, so `ReadyForQuery` follows the error.
                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                        self.ready_for_query().ok()?;
                    }

                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                        let _ = self.stream.flush().await;

                        // 1. Catching the panic during message processing may leave the session in an
                        // inconsistent state. We forcefully close the connection (then end the
                        // session) here for safety.
                        // 2. Idle in transaction timeout should also close the connection.
                        return None;
                    }

                    // Extended-query errors: no `ReadyForQuery` here; it is sent on the next
                    // `Sync`.
                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                    }
                }
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
471
472    async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
473        // Ignore util sync message.
474        if self.ignore_util_sync {
475            if let FeMessage::Sync = msg {
476            } else {
477                tracing::trace!("ignore message {:?} until sync.", msg);
478                return Ok(());
479            }
480        }
481
482        match msg {
483            FeMessage::Gss => self.process_gss_msg().await?,
484            FeMessage::Ssl => self.process_ssl_msg().await?,
485            FeMessage::Startup(msg) => self.process_startup_msg(msg)?,
486            FeMessage::Password(msg) => self.process_password_msg(msg).await?,
487            FeMessage::Query(query_msg) => {
488                let sql = Arc::from(query_msg.get_sql()?);
489                // The process_query_msg can be slow. Release potential large FeQueryMessage early.
490                drop(query_msg);
491                self.process_query_msg(sql).await?
492            }
493            FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
494            FeMessage::Terminate => self.process_terminate(),
495            FeMessage::Parse(m) => {
496                if let Err(err) = self.process_parse_msg(m).await {
497                    self.ignore_util_sync = true;
498                    return Err(err);
499                }
500            }
501            FeMessage::Bind(m) => {
502                if let Err(err) = self.process_bind_msg(m) {
503                    self.ignore_util_sync = true;
504                    return Err(err);
505                }
506            }
507            FeMessage::Execute(m) => {
508                if let Err(err) = self.process_execute_msg(m).await {
509                    self.ignore_util_sync = true;
510                    return Err(err);
511                }
512            }
513            FeMessage::Describe(m) => {
514                if let Err(err) = self.process_describe_msg(m) {
515                    self.ignore_util_sync = true;
516                    return Err(err);
517                }
518            }
519            FeMessage::Sync => {
520                self.ignore_util_sync = false;
521                self.ready_for_query()?
522            }
523            FeMessage::Close(m) => {
524                if let Err(err) = self.process_close_msg(m) {
525                    self.ignore_util_sync = true;
526                    return Err(err);
527                }
528            }
529            FeMessage::Flush => {
530                if let Err(err) = self.stream.flush().await {
531                    self.ignore_util_sync = true;
532                    return Err(err.into());
533                }
534            }
535            FeMessage::HealthCheck => self.process_health_check(),
536            FeMessage::ServerThrottle(reason) => match reason {
537                ServerThrottleReason::TooLargeMessage => {
538                    return Err(PsqlError::ServerThrottle(format!(
539                        "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
540                        self.message_memory_manager.max_filter_bytes
541                    )));
542                }
543                ServerThrottleReason::TooManyMemoryUsage => {
544                    return Err(PsqlError::ServerThrottle(format!(
545                        "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
546                        self.message_memory_manager.max_running_bytes
547                    )));
548                }
549            },
550        }
551        self.stream.flush().await?;
552        Ok(())
553    }
554
555    pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
556        match self.state {
557            PgProtocolState::Startup => self
558                .stream
559                .read_startup()
560                .await
561                .map(|message: FeMessage| (message, None)),
562            PgProtocolState::Regular => {
563                self.stream.read_header().await?;
564                let guard = if let Some(ref header) = self.stream.read_header {
565                    let payload_len = std::cmp::max(header.payload_len, 0) as u64;
566                    let (reason, guard) = self.message_memory_manager.add(payload_len);
567                    if let Some(reason) = reason {
568                        // Release the memory ASAP.
569                        drop(guard);
570                        self.stream.skip_body().await?;
571                        return Ok((FeMessage::ServerThrottle(reason), None));
572                    }
573                    guard
574                } else {
575                    None
576                };
577                let message = self.stream.read_body().await?;
578                Ok((message, guard))
579            }
580        }
581    }
582
583    /// Writes a `ReadyForQuery` message to the client without flushing.
584    fn ready_for_query(&mut self) -> io::Result<()> {
585        self.stream.write_no_flush(&BeMessage::ReadyForQuery(
586            self.session
587                .as_ref()
588                .map(|s| s.transaction_status())
589                .unwrap_or(TransactionStatus::Idle),
590        ))
591    }
592
593    async fn process_gss_msg(&mut self) -> PsqlResult<()> {
594        // We don't support GSSAPI, so we just say no gracefully.
595        self.stream.write(&BeMessage::EncryptionResponseNo).await?;
596        Ok(())
597    }
598
599    async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
600        if let Some(context) = self.tls_context.as_ref() {
601            // If got and ssl context, say yes for ssl connection.
602            // Construct ssl stream and replace with current one.
603            self.stream.write(&BeMessage::EncryptionResponseSsl).await?;
604            self.stream.upgrade_to_ssl(context).await?;
605        } else {
606            // If no, say no for encryption.
607            self.stream.write(&BeMessage::EncryptionResponseNo).await?;
608        }
609
610        Ok(())
611    }
612
613    fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
614        let db_name = msg
615            .config
616            .get("database")
617            .cloned()
618            .unwrap_or_else(|| "dev".to_owned());
619        let user_name = msg
620            .config
621            .get("user")
622            .cloned()
623            .unwrap_or_else(|| "root".to_owned());
624
625        let session = self
626            .session_mgr
627            .connect(&db_name, &user_name, self.peer_addr.clone())
628            .map_err(PsqlError::StartupError)?;
629
630        let application_name = msg.config.get("application_name");
631        if let Some(application_name) = application_name {
632            session
633                .set_config("application_name", application_name.clone())
634                .map_err(PsqlError::StartupError)?;
635        }
636
637        match session.user_authenticator() {
638            UserAuthenticator::None => {
639                self.stream.write_no_flush(&BeMessage::AuthenticationOk)?;
640
641                // Cancel request need this for identify and verification. According to postgres
642                // doc, it should be written to buffer after receive AuthenticationOk.
643                self.stream
644                    .write_no_flush(&BeMessage::BackendKeyData(session.id()))?;
645
646                self.stream.write_no_flush(&BeMessage::ParameterStatus(
647                    BeParameterStatusMessage::TimeZone(&session.get_config("timezone")?),
648                ))?;
649                self.stream
650                    .write_parameter_status_msg_no_flush(&ParameterStatus {
651                        application_name: application_name.cloned(),
652                    })?;
653                self.ready_for_query()?;
654            }
655            UserAuthenticator::ClearText(_) | UserAuthenticator::OAuth(_) => {
656                self.stream
657                    .write_no_flush(&BeMessage::AuthenticationCleartextPassword)?;
658            }
659            UserAuthenticator::Md5WithSalt { salt, .. } => {
660                self.stream
661                    .write_no_flush(&BeMessage::AuthenticationMd5Password(salt))?;
662            }
663        }
664
665        self.session = Some(session);
666        self.state = PgProtocolState::Regular;
667        Ok(())
668    }
669
670    async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
671        let session = self.session.as_ref().unwrap();
672        let authenticator = session.user_authenticator();
673        authenticator.authenticate(&msg.password).await?;
674        self.stream.write_no_flush(&BeMessage::AuthenticationOk)?;
675        self.stream.write_no_flush(&BeMessage::ParameterStatus(
676            BeParameterStatusMessage::TimeZone(&session.get_config("timezone")?),
677        ))?;
678        self.stream
679            .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
680        self.ready_for_query()?;
681        self.state = PgProtocolState::Regular;
682        Ok(())
683    }
684
685    fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
686        let session_id = (m.target_process_id, m.target_secret_key);
687        tracing::trace!("cancel query in session: {:?}", session_id);
688        self.session_mgr.cancel_queries_in_session(session_id);
689        self.session_mgr.cancel_creating_jobs_in_session(session_id);
690        self.stream.write_no_flush(&BeMessage::EmptyQueryResponse)?;
691        Ok(())
692    }
693
694    async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
695        let truncated_sql = record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
696        let session = self.session.clone().unwrap();
697
698        session.check_idle_in_transaction_timeout()?;
699        // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
700        let _exec_context_guard = session.init_exec_context(truncated_sql.into());
701        self.inner_process_query_msg(sql, session.clone()).await
702    }
703
704    async fn inner_process_query_msg(
705        &mut self,
706        sql: Arc<str>,
707        session: Arc<SM::Session>,
708    ) -> PsqlResult<()> {
709        // Parse sql.
710        let stmts =
711            Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
712        // The following inner_process_query_msg_one_stmt can be slow. Release potential large String early.
713        drop(sql);
714        if stmts.is_empty() {
715            self.stream.write_no_flush(&BeMessage::EmptyQueryResponse)?;
716        }
717
718        // Execute multiple statements in simple query. KISS later.
719        for stmt in stmts {
720            self.inner_process_query_msg_one_stmt(stmt, session.clone())
721                .await?;
722        }
723        // Put this line inside the for loop above will lead to unfinished/stuck regress test...Not
724        // sure the reason.
725        self.ready_for_query()?;
726        Ok(())
727    }
728
729    async fn inner_process_query_msg_one_stmt(
730        &mut self,
731        stmt: Statement,
732        session: Arc<SM::Session>,
733    ) -> PsqlResult<()> {
734        let session = session.clone();
735
736        // execute query
737        let res = session.clone().run_one_query(stmt, Format::Text).await;
738
739        // Take all remaining notices (if any) and send them before `CommandComplete`.
740        while let Some(notice) = session.next_notice().now_or_never() {
741            self.stream
742                .write_no_flush(&BeMessage::NoticeResponse(&notice))?;
743        }
744
745        let mut res = res.map_err(PsqlError::SimpleQueryError)?;
746
747        for notice in res.notices() {
748            self.stream
749                .write_no_flush(&BeMessage::NoticeResponse(notice))?;
750        }
751
752        let status = res.status();
753        if let Some(ref application_name) = status.application_name {
754            self.stream.write_no_flush(&BeMessage::ParameterStatus(
755                BeParameterStatusMessage::ApplicationName(application_name),
756            ))?;
757        }
758
759        if res.is_query() {
760            self.stream
761                .write_no_flush(&BeMessage::RowDescription(&res.row_desc()))?;
762
763            let mut rows_cnt = 0;
764
765            while let Some(row_set) = res.values_stream().next().await {
766                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
767                for row in row_set {
768                    self.stream.write_no_flush(&BeMessage::DataRow(&row))?;
769                    rows_cnt += 1;
770                }
771            }
772
773            // Run the callback before sending the `CommandComplete` message.
774            res.run_callback().await?;
775
776            self.stream
777                .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
778                    stmt_type: res.stmt_type(),
779                    rows_cnt,
780                }))?;
781        } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
782            let first_row_set = res.values_stream().next().await;
783            let first_row_set = match first_row_set {
784                None => {
785                    return Err(PsqlError::Uncategorized(
786                        anyhow::anyhow!("no affected rows in output").into(),
787                    ));
788                }
789                Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
790            };
791            let affected_rows_str = first_row_set[0].values()[0]
792                .as_ref()
793                .expect("compute node should return affected rows in output");
794
795            assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
796            let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
797                .unwrap()
798                .parse()
799                .unwrap_or_default();
800
801            // Run the callback before sending the `CommandComplete` message.
802            res.run_callback().await?;
803
804            self.stream
805                .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
806                    stmt_type: res.stmt_type(),
807                    rows_cnt: affected_rows_cnt,
808                }))?;
809        } else {
810            // Run the callback before sending the `CommandComplete` message.
811            res.run_callback().await?;
812
813            self.stream
814                .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
815                    stmt_type: res.stmt_type(),
816                    rows_cnt: 0,
817                }))?;
818        }
819
820        Ok(())
821    }
822
823    fn process_terminate(&mut self) {
824        self.is_terminate = true;
825    }
826
827    fn process_health_check(&mut self) {
828        tracing::debug!("health check");
829        self.is_terminate = true;
830    }
831
832    async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
833        let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
834        record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
835        let session = self.session.clone().unwrap();
836        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
837        let type_ids = std::mem::take(&mut msg.type_ids);
838        // The inner_process_parse_msg can be slow. Release potential large FeParseMessage early.
839        drop(msg);
840        self.inner_process_parse_msg(session, sql, statement_name, type_ids)
841            .await?;
842        Ok(())
843    }
844
845    async fn inner_process_parse_msg(
846        &mut self,
847        session: Arc<SM::Session>,
848        sql: Arc<str>,
849        statement_name: String,
850        type_ids: Vec<i32>,
851    ) -> PsqlResult<()> {
852        if statement_name.is_empty() {
853            // Remove the unnamed prepare statement first, in case the unsupported sql binds a wrong
854            // prepare statement.
855            self.unnamed_prepare_statement.take();
856        } else if self.prepare_statement_store.contains_key(&statement_name) {
857            return Err(PsqlError::ExtendedPrepareError(
858                "Duplicated statement name".into(),
859            ));
860        }
861
862        let stmt = {
863            let stmts = Parser::parse_sql(&sql)
864                .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
865            drop(sql);
866            if stmts.len() > 1 {
867                return Err(PsqlError::ExtendedPrepareError(
868                    "Only one statement is allowed in extended query mode".into(),
869                ));
870            }
871
872            stmts.into_iter().next()
873        };
874
875        let param_types: Vec<Option<DataType>> = type_ids
876            .iter()
877            .map(|&id| {
878                // 0 means unspecified type
879                // ref: https://www.postgresql.org/docs/15/protocol-message-formats.html#:~:text=Placing%20a%20zero%20here%20is%20equivalent%20to%20leaving%20the%20type%20unspecified.
880                if id == 0 {
881                    Ok(None)
882                } else {
883                    DataType::from_oid(id)
884                        .map(Some)
885                        .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
886                }
887            })
888            .try_collect()?;
889
890        let prepare_statement = session
891            .parse(stmt, param_types)
892            .await
893            .map_err(PsqlError::ExtendedPrepareError)?;
894
895        if statement_name.is_empty() {
896            self.unnamed_prepare_statement.replace(prepare_statement);
897        } else {
898            self.prepare_statement_store
899                .insert(statement_name.clone(), prepare_statement);
900        }
901
902        self.statement_portal_dependency
903            .entry(statement_name)
904            .or_default()
905            .clear();
906
907        self.stream.write_no_flush(&BeMessage::ParseComplete)?;
908        Ok(())
909    }
910
911    fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
912        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
913        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
914        let session = self.session.clone().unwrap();
915
916        if self.portal_store.contains_key(&portal_name) {
917            return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
918        }
919
920        let prepare_statement = self.get_statement(&statement_name)?;
921
922        let result_formats = msg
923            .result_format_codes
924            .iter()
925            .map(|&format_code| Format::from_i16(format_code))
926            .try_collect()?;
927        let param_formats = msg
928            .param_format_codes
929            .iter()
930            .map(|&format_code| Format::from_i16(format_code))
931            .try_collect()?;
932
933        let portal = session
934            .bind(prepare_statement, msg.params, param_formats, result_formats)
935            .map_err(PsqlError::Uncategorized)?;
936
937        if portal_name.is_empty() {
938            self.result_cache.remove(&portal_name);
939            self.unnamed_portal.replace(portal);
940        } else {
941            assert!(
942                !self.result_cache.contains_key(&portal_name),
943                "Named portal never can be overridden."
944            );
945            self.portal_store.insert(portal_name.clone(), portal);
946        }
947
948        self.statement_portal_dependency
949            .get_mut(&statement_name)
950            .unwrap()
951            .push(portal_name);
952
953        self.stream.write_no_flush(&BeMessage::BindComplete)?;
954        Ok(())
955    }
956
957    async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
958        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
959        let row_max = msg.max_rows as usize;
960        drop(msg);
961        let session = self.session.clone().unwrap();
962
963        match self.result_cache.remove(&portal_name) {
964            Some(mut result_cache) => {
965                assert!(self.portal_store.contains_key(&portal_name));
966
967                let is_cosume_completed =
968                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
969
970                if !is_cosume_completed {
971                    self.result_cache.insert(portal_name, result_cache);
972                }
973            }
974            _ => {
975                let portal = self.get_portal(&portal_name)?;
976                let sql = format!("{}", portal);
977                let truncated_sql =
978                    record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
979                drop(sql);
980
981                session.check_idle_in_transaction_timeout()?;
982                // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
983                let _exec_context_guard = session.init_exec_context(truncated_sql.into());
984                let result = session.clone().execute(portal).await;
985
986                let pg_response = result.map_err(PsqlError::ExtendedExecuteError)?;
987                let mut result_cache = ResultCache::new(pg_response);
988                let is_consume_completed =
989                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
990                if !is_consume_completed {
991                    self.result_cache.insert(portal_name, result_cache);
992                }
993            }
994        }
995
996        Ok(())
997    }
998
999    fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1000        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1001        let session = self.session.clone().unwrap();
1002        //  b'S' => Statement
1003        //  b'P' => Portal
1004
1005        assert!(msg.kind == b'S' || msg.kind == b'P');
1006        if msg.kind == b'S' {
1007            let prepare_statement = self.get_statement(&name)?;
1008
1009            let (param_types, row_descriptions) = self
1010                .session
1011                .clone()
1012                .unwrap()
1013                .describe_statement(prepare_statement)
1014                .map_err(PsqlError::Uncategorized)?;
1015            self.stream
1016                .write_no_flush(&BeMessage::ParameterDescription(
1017                    &param_types.iter().map(|t| t.to_oid()).collect_vec(),
1018                ))?;
1019
1020            if row_descriptions.is_empty() {
1021                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1022                // return NoData message if the statement is not a query.
1023                self.stream.write_no_flush(&BeMessage::NoData)?;
1024            } else {
1025                self.stream
1026                    .write_no_flush(&BeMessage::RowDescription(&row_descriptions))?;
1027            }
1028        } else if msg.kind == b'P' {
1029            let portal = self.get_portal(&name)?;
1030
1031            let row_descriptions = session
1032                .describe_portal(portal)
1033                .map_err(PsqlError::Uncategorized)?;
1034
1035            if row_descriptions.is_empty() {
1036                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1037                // return NoData message if the statement is not a query.
1038                self.stream.write_no_flush(&BeMessage::NoData)?;
1039            } else {
1040                self.stream
1041                    .write_no_flush(&BeMessage::RowDescription(&row_descriptions))?;
1042            }
1043        }
1044        Ok(())
1045    }
1046
1047    fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1048        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1049        assert!(msg.kind == b'S' || msg.kind == b'P');
1050        if msg.kind == b'S' {
1051            if name.is_empty() {
1052                self.unnamed_prepare_statement = None;
1053            } else {
1054                self.prepare_statement_store.remove(&name);
1055            }
1056            for portal_name in self
1057                .statement_portal_dependency
1058                .remove(&name)
1059                .unwrap_or_default()
1060            {
1061                self.remove_portal(&portal_name);
1062            }
1063        } else if msg.kind == b'P' {
1064            self.remove_portal(&name);
1065        }
1066        self.stream.write_no_flush(&BeMessage::CloseComplete)?;
1067        Ok(())
1068    }
1069
1070    fn remove_portal(&mut self, portal_name: &str) {
1071        if portal_name.is_empty() {
1072            self.unnamed_portal = None;
1073        } else {
1074            self.portal_store.remove(portal_name);
1075        }
1076        self.result_cache.remove(portal_name);
1077    }
1078
1079    fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1080        if portal_name.is_empty() {
1081            Ok(self
1082                .unnamed_portal
1083                .as_ref()
1084                .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1085                .clone())
1086        } else {
1087            Ok(self
1088                .portal_store
1089                .get(portal_name)
1090                .ok_or_else(|| {
1091                    PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1092                })?
1093                .clone())
1094        }
1095    }
1096
1097    fn get_statement(
1098        &self,
1099        statement_name: &str,
1100    ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1101        if statement_name.is_empty() {
1102            Ok(self
1103                .unnamed_prepare_statement
1104                .as_ref()
1105                .ok_or_else(|| {
1106                    PsqlError::Uncategorized("unnamed prepare statement not found".into())
1107                })?
1108                .clone())
1109        } else {
1110            Ok(self
1111                .prepare_statement_store
1112                .get(statement_name)
1113                .ok_or_else(|| {
1114                    PsqlError::Uncategorized(
1115                        format!("Prepare statement {} not found", statement_name).into(),
1116                    )
1117                })?
1118                .clone())
1119        }
1120    }
1121}
1122
1123enum PgStreamInner<S> {
1124    /// Used for the intermediate state when converting from unencrypted to ssl stream.
1125    Placeholder,
1126    /// An unencrypted stream.
1127    Unencrypted(S),
1128    /// An ssl stream.
1129    Ssl(SslStream<S>),
1130}
1131
1132/// Trait for a byte stream that can be used for pg protocol.
1133pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1134impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1135
1136/// Wraps a byte stream and read/write pg messages.
1137///
1138/// Cloning a `PgStream` will share the same stream but a fresh & independent write buffer,
1139/// so that it can be used to write messages concurrently without interference.
1140pub struct PgStream<S> {
1141    /// The underlying stream.
1142    stream: Arc<Mutex<PgStreamInner<S>>>,
1143    /// Write into buffer before flush to stream.
1144    write_buf: BytesMut,
1145    read_header: Option<FeMessageHeader>,
1146}
1147
1148impl<S> PgStream<S> {
1149    /// Create a new `PgStream` with the given stream and default write buffer capacity.
1150    pub fn new(stream: S) -> Self {
1151        const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1152
1153        Self {
1154            stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1155            write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1156            read_header: None,
1157        }
1158    }
1159}
1160
1161impl<S> Clone for PgStream<S> {
1162    fn clone(&self) -> Self {
1163        Self {
1164            stream: Arc::clone(&self.stream),
1165            write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1166            read_header: self.read_header.clone(),
1167        }
1168    }
1169}
1170
/// Session parameters reported to the client through `ParameterStatus` messages.
///
/// PostgreSQL documents a hard-wired set of parameters for which `ParameterStatus` is
/// generated:
///
///  * `server_version`
///  * `server_encoding`
///  * `client_encoding`
///  * `application_name`
///  * `is_superuser`
///  * `session_authorization`
///  * `DateStyle`
///  * `IntervalStyle`
///  * `TimeZone`
///  * `integer_datetimes`
///  * `standard_conforming_strings`
///
/// This struct only carries the optional `application_name`.
///
/// See: <https://www.postgresql.org/docs/9.2/static/protocol-flow.html#PROTOCOL-ASYNC>.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// Value to report as `application_name`, if any.
    pub application_name: Option<String>,
}
1191
1192impl<S> PgStream<S>
1193where
1194    S: PgByteStream,
1195{
1196    async fn read_startup(&mut self) -> io::Result<FeMessage> {
1197        let mut stream = self.stream.lock().await;
1198        match &mut *stream {
1199            PgStreamInner::Placeholder => unreachable!(),
1200            PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1201            PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1202        }
1203    }
1204
1205    async fn read_header(&mut self) -> io::Result<()> {
1206        let mut stream = self.stream.lock().await;
1207        match &mut *stream {
1208            PgStreamInner::Placeholder => unreachable!(),
1209            PgStreamInner::Unencrypted(stream) => {
1210                self.read_header = Some(FeMessage::read_header(stream).await?);
1211                Ok(())
1212            }
1213            PgStreamInner::Ssl(ssl_stream) => {
1214                self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1215                Ok(())
1216            }
1217        }
1218    }
1219
1220    async fn read_body(&mut self) -> io::Result<FeMessage> {
1221        let mut stream = self.stream.lock().await;
1222        let header = self
1223            .read_header
1224            .take()
1225            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1226        match &mut *stream {
1227            PgStreamInner::Placeholder => unreachable!(),
1228            PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1229            PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1230        }
1231    }
1232
1233    async fn skip_body(&mut self) -> io::Result<()> {
1234        let mut stream = self.stream.lock().await;
1235        let header = self
1236            .read_header
1237            .take()
1238            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1239        match &mut *stream {
1240            PgStreamInner::Placeholder => unreachable!(),
1241            PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1242            PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1243        }
1244    }
1245
1246    fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1247        self.write_no_flush(&BeMessage::ParameterStatus(
1248            BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1249        ))?;
1250        self.write_no_flush(&BeMessage::ParameterStatus(
1251            BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1252        ))?;
1253        self.write_no_flush(&BeMessage::ParameterStatus(
1254            BeParameterStatusMessage::ServerVersion(PG_VERSION),
1255        ))?;
1256        if let Some(application_name) = &status.application_name {
1257            self.write_no_flush(&BeMessage::ParameterStatus(
1258                BeParameterStatusMessage::ApplicationName(application_name),
1259            ))?;
1260        }
1261        Ok(())
1262    }
1263
1264    pub fn write_no_flush(&mut self, message: &BeMessage<'_>) -> io::Result<()> {
1265        BeMessage::write(&mut self.write_buf, message)
1266    }
1267
1268    async fn write(&mut self, message: &BeMessage<'_>) -> io::Result<()> {
1269        self.write_no_flush(message)?;
1270        self.flush().await?;
1271        Ok(())
1272    }
1273
1274    async fn flush(&mut self) -> io::Result<()> {
1275        let mut stream = self.stream.lock().await;
1276        match &mut *stream {
1277            PgStreamInner::Placeholder => unreachable!(),
1278            PgStreamInner::Unencrypted(stream) => {
1279                stream.write_all(&self.write_buf).await?;
1280                stream.flush().await?;
1281            }
1282            PgStreamInner::Ssl(ssl_stream) => {
1283                ssl_stream.write_all(&self.write_buf).await?;
1284                ssl_stream.flush().await?;
1285            }
1286        }
1287        self.write_buf.clear();
1288        Ok(())
1289    }
1290}
1291
impl<S> PgStream<S>
where
    S: PgByteStream,
{
    /// Convert the underlying stream to ssl stream based on the given context.
    ///
    /// Runs the server side of the TLS handshake (`accept`) on the current unencrypted stream
    /// and, on success, replaces it with the resulting `SslStream`.
    ///
    /// # Errors
    ///
    /// Returns the handshake error if `accept` fails, after logging a warning and attempting a
    /// best-effort `shutdown`. In that case the inner stream is left as
    /// `PgStreamInner::Placeholder`, because the unencrypted stream has already been moved into
    /// the failed `SslStream`. Later reads/writes/flushes on this `PgStream` hit `unreachable!()`
    /// for `Placeholder`, so the stream must not be used after an error.
    ///
    /// # Panics
    ///
    /// Panics if the stream is already ssl, or if creating the `Ssl` or `SslStream` fails.
    async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
        let mut stream = self.stream.lock().await;

        // Move the current stream out, leaving `Placeholder` behind; it is only replaced
        // (with the ssl stream) once the handshake has succeeded.
        match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
            PgStreamInner::Unencrypted(unencrypted_stream) => {
                let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
                let mut ssl_stream =
                    tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();

                if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
                    tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
                    // Best effort only: the shutdown result is deliberately ignored.
                    let _ = ssl_stream.shutdown().await;
                    return Err(e.into());
                }

                *stream = PgStreamInner::Ssl(ssl_stream);
            }
            PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
            PgStreamInner::Placeholder => unreachable!(),
        }

        Ok(())
    }
}
1321
1322fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1323    let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1324
1325    let key_path = &tls_config.key;
1326    let cert_path = &tls_config.cert;
1327
1328    // Build ssl acceptor according to the config.
1329    // Now we set every verify to true.
1330    acceptor
1331        .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1332        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1333    acceptor
1334        .set_ca_file(cert_path)
1335        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1336    acceptor
1337        .set_certificate_chain_file(cert_path)
1338        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1339    let acceptor = acceptor.build();
1340
1341    Ok(acceptor.into_context())
1342}
1343
pub mod truncated_fmt {
    use std::fmt::*;

    /// Largest char boundary of `s` that is not greater than `index` (the same result as
    /// `str::floor_char_boundary`). Byte 0 is always a boundary, so a boundary always exists.
    fn floor_boundary(s: &str, index: usize) -> usize {
        if index >= s.len() {
            return s.len();
        }
        (0..=index)
            .rev()
            .find(|&i| s.is_char_boundary(i))
            .unwrap_or(0)
    }

    /// `Write` adapter that lets at most `remaining` bytes through to `f`.
    ///
    /// The chunk that would exceed the budget is cut at a char boundary and followed by a single
    /// `...(truncated,N)` marker, where `N` is that chunk's byte length. Everything written
    /// afterwards is discarded.
    struct TruncatedFormatter<'a, 'b> {
        remaining: usize,
        finished: bool,
        f: &'a mut Formatter<'b>,
    }

    impl Write for TruncatedFormatter<'_, '_> {
        fn write_str(&mut self, s: &str) -> Result {
            if self.finished {
                return Ok(());
            }

            if s.len() <= self.remaining {
                self.f.write_str(s)?;
                self.remaining -= s.len();
                return Ok(());
            }

            let cut = floor_boundary(s, self.remaining);
            self.f.write_str(&s[..cut])?;
            self.remaining -= cut;
            write!(self.f, "...(truncated,{})", s.len())?;
            // Only one marker is ever printed; later chunks are dropped.
            self.finished = true;
            Ok(())
        }
    }

    /// Formats the wrapped value (via `Debug` or `Display`) but emits at most `.1` bytes of it.
    /// Longer output is cut on a char boundary and followed by a `...(truncated,N)` marker.
    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);

    impl<T> TruncatedFmt<'_, T> {
        /// Renders `args` into `f` through a fresh byte budget of `self.1`.
        fn write_limited(&self, f: &mut Formatter<'_>, args: Arguments<'_>) -> Result {
            let mut limited = TruncatedFormatter {
                remaining: self.1,
                finished: false,
                f,
            };
            limited.write_fmt(args)
        }
    }

    impl<T: Debug> Debug for TruncatedFmt<'_, T> {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            self.write_limited(f, format_args!("{:?}", self.0))
        }
    }

    impl<T: Display> Display for TruncatedFmt<'_, T> {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            self.write_limited(f, format_args!("{}", self.0))
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            // The 10-byte budget ends inside the 4-byte emoji, so the cut falls back to the
            // boundary before it; 14 is the byte length of the whole input chunk.
            let rendered = format!("{}", TruncatedFmt(&"select '🌊';", 10));
            assert_eq!(rendered, "select '...(truncated,14)");
        }
    }
}
1415
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Options whose keys are listed as sensitive are printed as `[REDACTED]`, both in the
    /// `WITH (...)` clause and in the `ENCODE ... (...)` options; everything else is re-printed
    /// in normalized form.
    #[test]
    fn test_redact_parsable_sql() {
        let sensitive_keys = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()]));
        let input = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])";

        let redacted = redact_sql(input, sensitive_keys);
        assert_eq!(redacted, expected);
    }
}