pgwire/
pg_protocol.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::panic::FutureCatchUnwindExt;
33use risingwave_common::util::query_log::*;
34use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
35use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
36use risingwave_sqlparser::parser::Parser;
37use thiserror_ext::AsReport;
38use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
39use tokio::sync::Mutex;
40use tokio_openssl::SslStream;
41use tracing::Instrument;
42
43use crate::error::{PsqlError, PsqlResult};
44use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
45use crate::net::AddressRef;
46use crate::pg_extended::ResultCache;
47use crate::pg_message::{
48    BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
49    FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
50    FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
51};
52use crate::pg_server::{Session, SessionManager, UserAuthenticator};
53use crate::types::Format;
54
/// Maximum length of SQL recorded in query logs and tracing spans. Longer SQL is truncated to keep
/// the log file from growing too large.
///
/// Read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. If the variable is unset
/// or does not parse as a `usize`, it defaults to 65536 in debug builds and 1024 otherwise.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        // Parse once instead of validating and then re-parsing with `unwrap`.
        .and_then(|len| len.parse::<usize>().ok())
        .unwrap_or(if cfg!(debug_assertions) { 65536 } else { 1024 })
});
68
tokio::task_local! {
    /// The session whose message is currently being processed.
    ///
    /// The concrete session type is erased to `dyn Any` so that different session
    /// implementations can share this slot. Readers presumably downcast it back to their own
    /// session type (TODO confirm at the use sites). It holds a `Weak` reference, so it does not
    /// keep the session alive on its own. `PgProtocol::do_process` sets it with
    /// `CURRENT_SESSION.scope(..)` around the processing of a single message.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
73
/// The protocol state machine for one client connection.
///
/// It reads PostgreSQL frontend messages from the byte stream and writes backend messages back.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Stream used to read frontend messages and write backend messages.
    stream: PgStream<S>,
    /// Current protocol phase of the connection (startup or regular message processing).
    state: PgProtocolState,
    /// Set when the connection should be closed after the current message (e.g. on `Terminate`).
    is_terminate: bool,

    /// Session manager used to create, cancel and end sessions.
    session_mgr: Arc<SM>,
    /// Session created by the startup message; `None` before startup.
    session: Option<Arc<SM::Session>>,

    /// Cached result streams of the extended query protocol, keyed by name (presumably the
    /// portal name — TODO confirm).
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    /// The unnamed prepared statement of the extended query protocol, if any.
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    /// Named prepared statements, keyed by statement name.
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    /// The unnamed portal of the extended query protocol, if any.
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    /// Named portals, keyed by portal name.
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    // Used to store the dependency of portal and prepare statement.
    // When we close a prepare statement, we need to close all the portals that depend on it.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    // TLS context used to upgrade the connection when the client requests SSL.
    // If `None`, SSL requests are declined with `EncryptionResponseNo`.
    tls_context: Option<SslContext>,

    // Used in the extended query protocol. After an error in an extended query message, every
    // following message is ignored until the client sends `Sync`. ("util" in the field name
    // means "until".)
    ignore_util_sync: bool,

    // Client address, passed to the session manager when the session is created.
    peer_addr: AddressRef,

    /// Keywords used to redact SQL options before SQL is recorded in logs and tracing spans.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts for the memory of incoming messages; used to throttle oversized messages or
    /// too much total in-flight message memory.
    message_memory_manager: MessageMemoryManagerRef,
}
113
114/// Configures TLS encryption for connections.
115#[derive(Debug, Clone)]
116pub struct TlsConfig {
117    /// The path to the TLS certificate.
118    pub cert: String,
119    /// The path to the TLS key.
120    pub key: String,
121}
122
123impl TlsConfig {
124    pub fn new_default() -> Option<Self> {
125        let cert = std::env::var("RW_SSL_CERT").ok()?;
126        let key = std::env::var("RW_SSL_KEY").ok()?;
127        tracing::info!("RW_SSL_CERT={}, RW_SSL_KEY={}", cert, key);
128        Some(Self { cert, key })
129    }
130}
131
132impl<S, SM> Drop for PgProtocol<S, SM>
133where
134    SM: SessionManager,
135{
136    fn drop(&mut self) {
137        if let Some(session) = &self.session {
138            // Clear the session in session manager.
139            self.session_mgr.end_session(session);
140        }
141    }
142}
143
/// Phases of a connection. A connection starts in `Startup` and then moves to `Regular`
/// (states flow from top to bottom).
///
/// NOTE(review): `process_startup_msg` switches to `Regular` even when a password is still
/// required, so this state alone does not tell whether the client has authenticated. Confirm
/// that messages other than `Password` sent before authentication completes are rejected elsewhere.
enum PgProtocolState {
    /// Waiting for the initial packet(s); messages are read with `read_startup`.
    Startup,
    /// Startup done; messages are read as a header followed by a body.
    Regular,
}
149
/// Interprets a PG protocol C string as UTF-8, dropping one trailing NUL byte if present.
///
/// PG protocol strings are always NUL-terminated C strings. The result borrows from the input
/// and nothing is allocated. Any byte container works (`&Bytes`, `&[u8]`, `&Vec<u8>`, byte
/// arrays), so existing `&Bytes` callers are unaffected.
///
/// Only a single trailing NUL is removed; interior NUL bytes are kept.
///
/// # Errors
///
/// Returns a [`Utf8Error`] if the bytes (without the trailing NUL) are not valid UTF-8.
pub fn cstr_to_str<B>(b: &B) -> Result<&str, Utf8Error>
where
    B: AsRef<[u8]> + ?Sized,
{
    let bytes = b.as_ref();
    let without_null = match bytes.split_last() {
        Some((&0, rest)) => rest,
        _ => bytes,
    };
    std::str::from_utf8(without_null)
}
161
162/// Record `sql` in the current tracing span.
163fn record_sql_in_span(
164    sql: &str,
165    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
166) -> String {
167    let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
168        && !keywords.is_empty()
169    {
170        redact_sql(sql, keywords)
171    } else {
172        sql.to_owned()
173    };
174    let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
175    tracing::Span::current().record("sql", tracing::field::display(&truncated));
176    truncated.to_string()
177}
178
179/// Redacts SQL options. Data in DML is not redacted.
180fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
181    match Parser::parse_sql(sql) {
182        Ok(sqls) => sqls
183            .into_iter()
184            .map(|sql| sql.to_redacted_string(keywords.clone()))
185            .join(";"),
186        Err(_) => sql.to_owned(),
187    }
188}
189
/// Settings a connection is created with; consumed by `PgProtocol::new`.
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS certificate/key paths. When `None`, or when a TLS context cannot be built from them,
    /// the connection declines SSL requests.
    pub tls_config: Option<TlsConfig>,
    /// Keywords used to redact SQL options before SQL is recorded in logs and tracing spans.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Manager (presumably shared across connections — TODO confirm) that accounts for the
    /// memory of incoming frontend messages.
    pub message_memory_manager: MessageMemoryManagerRef,
}
196
197impl<S, SM> PgProtocol<S, SM>
198where
199    S: PgByteStream,
200    SM: SessionManager,
201{
202    pub fn new(
203        stream: S,
204        session_mgr: Arc<SM>,
205        peer_addr: AddressRef,
206        context: ConnectionContext,
207    ) -> Self {
208        let ConnectionContext {
209            tls_config,
210            redact_sql_option_keywords,
211            message_memory_manager,
212        } = context;
213        Self {
214            stream: PgStream::new(stream),
215            is_terminate: false,
216            state: PgProtocolState::Startup,
217            session_mgr,
218            session: None,
219            tls_context: tls_config
220                .as_ref()
221                .and_then(|e| build_ssl_ctx_from_config(e).ok()),
222            result_cache: Default::default(),
223            unnamed_prepare_statement: Default::default(),
224            prepare_statement_store: Default::default(),
225            unnamed_portal: Default::default(),
226            portal_store: Default::default(),
227            statement_portal_dependency: Default::default(),
228            ignore_util_sync: false,
229            peer_addr,
230            redact_sql_option_keywords,
231            message_memory_manager,
232        }
233    }
234
    /// Runs the protocol to serve the connection until it is terminated.
    ///
    /// Each loop iteration reads one frontend message and processes it. Once a session exists,
    /// a future that forwards the session's notices to the client is created. It is raced
    /// against reading/processing via `select!`, so notices are written while a message is
    /// being read or processed. The loop ends when reading a message fails or `process`
    /// returns `true`.
    pub async fn run(&mut self) {
        // Created once a session exists, then kept and re-polled across iterations.
        let mut notice_fut = None;

        loop {
            // Once a session is present, create a future to subscribe and send notices asynchronously.
            if notice_fut.is_none()
                && let Some(session) = self.session.clone()
            {
                // Notices are written through a clone of the stream handle, separately from the
                // message-processing path below.
                let mut stream = self.stream.clone();
                notice_fut = Some(Box::pin(async move {
                    // Never completes: a failed write is only logged and the loop continues.
                    loop {
                        let notice = session.next_notice().await;
                        if let Err(e) = stream.write(&BeMessage::NoticeResponse(&notice)).await {
                            tracing::error!(error = %e.as_report(), notice, "failed to send notice");
                        }
                    }
                }));
            }

            // Read and process messages.
            let process = std::pin::pin!(async {
                // The memory guard stays alive until processing finishes, so the message's
                // memory remains accounted while it is being handled.
                let (msg, _memory_guard) = match self.read_message().await {
                    Ok(msg) => msg,
                    Err(e) => {
                        tracing::error!(error = %e.as_report(), "error when reading message");
                        return true; // terminate the connection
                    }
                };
                tracing::trace!(?msg, "received message");
                self.process(msg).await
            });

            let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
                tokio::select! {
                    // The notice future loops forever, so only `process` can complete.
                    _ = notice_fut => unreachable!(),
                    terminated = process => terminated,
                }
            } else {
                process.await
            };

            if terminated {
                break;
            }
        }
    }
282
283    /// Processes one message. Returns true if the connection is terminated.
284    pub async fn process(&mut self, msg: FeMessage) -> bool {
285        self.do_process(msg).await.is_none() || self.is_terminate
286    }
287
288    /// The root tracing span for processing a message. The target of the span is
289    /// [`PGWIRE_ROOT_SPAN_TARGET`].
290    ///
291    /// This is used to provide context for the (slow) query logs and traces.
292    ///
293    /// The span is only effective if there's a current session and the message is
294    /// query-related. Otherwise, `Span::none()` is returned.
295    fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
296        let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
297            return tracing::Span::none();
298        };
299
300        let mode = match msg {
301            FeMessage::Query(_) => "simple query",
302            FeMessage::Parse(_) => "extended query parse",
303            FeMessage::Execute(_) => "extended query execute",
304            _ => return tracing::Span::none(),
305        };
306
307        tracing::info_span!(
308            target: PGWIRE_ROOT_SPAN_TARGET,
309            "handle_query",
310            mode,
311            session_id,
312            sql = tracing::field::Empty, // record SQL later in each `process` call
313        )
314    }
315
    /// Processes one message with logging, tracing and panic handling, and turns errors into
    /// client responses.
    ///
    /// Return type `Option<()>` is essentially a bool, but allows `?` for early return.
    /// - `None` means to terminate the current connection
    /// - `Some(())` means to continue processing the next message
    ///
    /// From the inside out, `do_process_inner` is wrapped with:
    /// - the `CURRENT_SESSION` task-local scope;
    /// - panic catching (a panic becomes `PsqlError::Panic`);
    /// - periodic slow-query logging;
    /// - error and query logging;
    /// - the root tracing span.
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        // Processing the message itself.
        //
        // Note: pin the future to avoid stack overflow as we'll wrap it multiple times
        // in the following code.
        let fut = Box::pin(self.do_process_inner(msg));

        // Set the current session as the context when processing the message, if exists.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Catch unwind.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Slow query log.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            let mut elapsed = Duration::ZERO;

            // Log every `period` while the message is still being processed. The SQL itself is
            // attached through the enclosing root span (applied with `instrument` below).
            loop {
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Query log.
        let fut = async move {
            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            // Always log if an error occurs.
            // Note: all messages will be processed through this code path, making it the
            //       only necessary place to log errors.
            if let Err(error) = &result {
                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
                    // For local debugging, we print the error with backtrace.
                    // It's useful only when:
                    // - no additional context is added to the error
                    // - backtrace is captured in the error
                    // - backtrace is not printed in the middle
                    tracing::error!(error = ?error.as_report(), "error when process message");
                } else {
                    tracing::error!(error = %error.as_report(), "error when process message");
                }
            }

            // Log to optionally-enabled target `PGWIRE_QUERY_LOG`.
            // Only log if we're currently in a tracing span set in `root_span_for_msg`.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        // Tracing span.
        let fut = fut.instrument(span);

        // Execute the future and handle the error.
        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    PsqlError::IoError(io_err) => {
                        // EOF: the client went away, so close the connection.
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                        // Other I/O errors are neither reported to the client nor fatal here;
                        // processing continues with the next message.
                    }

                    PsqlError::SslError(_) => {
                        // For ssl error, because the stream has already been consumed, so there is
                        // no way to write more message.
                        return None;
                    }

                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        // Startup or authentication failed: report the error, then close.
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        // Simple query protocol: the error is followed by `ReadyForQuery`.
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                        self.ready_for_query().ok()?;
                    }

                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                        let _ = self.stream.flush().await;

                        // 1. Catching the panic during message processing may leave the session in an
                        // inconsistent state. We forcefully close the connection (then end the
                        // session) here for safety.
                        // 2. Idle in transaction timeout should also close the connection.
                        return None;
                    }

                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        // No `ReadyForQuery` is queued here. For extended-query messages it is
                        // sent when the client's `Sync` arrives.
                        // NOTE(review): an `Uncategorized` error raised while handling a simple
                        // `Query` gets no `ReadyForQuery` at all — confirm that cannot happen.
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                    }
                }
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
471
472    async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
473        // Ignore util sync message.
474        if self.ignore_util_sync {
475            if let FeMessage::Sync = msg {
476            } else {
477                tracing::trace!("ignore message {:?} until sync.", msg);
478                return Ok(());
479            }
480        }
481
482        match msg {
483            FeMessage::Gss => self.process_gss_msg().await?,
484            FeMessage::Ssl => self.process_ssl_msg().await?,
485            FeMessage::Startup(msg) => self.process_startup_msg(msg)?,
486            FeMessage::Password(msg) => self.process_password_msg(msg).await?,
487            FeMessage::Query(query_msg) => {
488                let sql = Arc::from(query_msg.get_sql()?);
489                // The process_query_msg can be slow. Release potential large FeQueryMessage early.
490                drop(query_msg);
491                self.process_query_msg(sql).await?
492            }
493            FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
494            FeMessage::Terminate => self.process_terminate(),
495            FeMessage::Parse(m) => {
496                if let Err(err) = self.process_parse_msg(m).await {
497                    self.ignore_util_sync = true;
498                    return Err(err);
499                }
500            }
501            FeMessage::Bind(m) => {
502                if let Err(err) = self.process_bind_msg(m) {
503                    self.ignore_util_sync = true;
504                    return Err(err);
505                }
506            }
507            FeMessage::Execute(m) => {
508                if let Err(err) = self.process_execute_msg(m).await {
509                    self.ignore_util_sync = true;
510                    return Err(err);
511                }
512            }
513            FeMessage::Describe(m) => {
514                if let Err(err) = self.process_describe_msg(m) {
515                    self.ignore_util_sync = true;
516                    return Err(err);
517                }
518            }
519            FeMessage::Sync => {
520                self.ignore_util_sync = false;
521                self.ready_for_query()?
522            }
523            FeMessage::Close(m) => {
524                if let Err(err) = self.process_close_msg(m) {
525                    self.ignore_util_sync = true;
526                    return Err(err);
527                }
528            }
529            FeMessage::Flush => {
530                if let Err(err) = self.stream.flush().await {
531                    self.ignore_util_sync = true;
532                    return Err(err.into());
533                }
534            }
535            FeMessage::HealthCheck => self.process_health_check(),
536            FeMessage::ServerThrottle(reason) => match reason {
537                ServerThrottleReason::TooLargeMessage => {
538                    return Err(PsqlError::ServerThrottle(format!(
539                        "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
540                        self.message_memory_manager.max_filter_bytes
541                    )));
542                }
543                ServerThrottleReason::TooManyMemoryUsage => {
544                    return Err(PsqlError::ServerThrottle(format!(
545                        "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
546                        self.message_memory_manager.max_running_bytes
547                    )));
548                }
549            },
550        }
551        self.stream.flush().await?;
552        Ok(())
553    }
554
555    pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
556        match self.state {
557            PgProtocolState::Startup => self
558                .stream
559                .read_startup()
560                .await
561                .map(|message: FeMessage| (message, None)),
562            PgProtocolState::Regular => {
563                self.stream.read_header().await?;
564                let guard = if let Some(ref header) = self.stream.read_header {
565                    let payload_len = std::cmp::max(header.payload_len, 0) as u64;
566                    let (reason, guard) = self.message_memory_manager.add(payload_len);
567                    if let Some(reason) = reason {
568                        // Release the memory ASAP.
569                        drop(guard);
570                        self.stream.skip_body().await?;
571                        return Ok((FeMessage::ServerThrottle(reason), None));
572                    }
573                    guard
574                } else {
575                    None
576                };
577                let message = self.stream.read_body().await?;
578                Ok((message, guard))
579            }
580        }
581    }
582
583    /// Writes a `ReadyForQuery` message to the client without flushing.
584    fn ready_for_query(&mut self) -> io::Result<()> {
585        self.stream.write_no_flush(&BeMessage::ReadyForQuery(
586            self.session
587                .as_ref()
588                .map(|s| s.transaction_status())
589                .unwrap_or(TransactionStatus::Idle),
590        ))
591    }
592
593    async fn process_gss_msg(&mut self) -> PsqlResult<()> {
594        // We don't support GSSAPI, so we just say no gracefully.
595        self.stream.write(&BeMessage::EncryptionResponseNo).await?;
596        Ok(())
597    }
598
599    async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
600        if let Some(context) = self.tls_context.as_ref() {
601            // If got and ssl context, say yes for ssl connection.
602            // Construct ssl stream and replace with current one.
603            self.stream.write(&BeMessage::EncryptionResponseSsl).await?;
604            self.stream.upgrade_to_ssl(context).await?;
605        } else {
606            // If no, say no for encryption.
607            self.stream.write(&BeMessage::EncryptionResponseNo).await?;
608        }
609
610        Ok(())
611    }
612
613    fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
614        let db_name = msg
615            .config
616            .get("database")
617            .cloned()
618            .unwrap_or_else(|| "dev".to_owned());
619        let user_name = msg
620            .config
621            .get("user")
622            .cloned()
623            .unwrap_or_else(|| "root".to_owned());
624
625        let session = self
626            .session_mgr
627            .connect(&db_name, &user_name, self.peer_addr.clone())
628            .map_err(PsqlError::StartupError)?;
629
630        let application_name = msg.config.get("application_name");
631        if let Some(application_name) = application_name {
632            session
633                .set_config("application_name", application_name.clone())
634                .map_err(PsqlError::StartupError)?;
635        }
636
637        match session.user_authenticator() {
638            UserAuthenticator::None => {
639                self.stream.write_no_flush(&BeMessage::AuthenticationOk)?;
640
641                // Cancel request need this for identify and verification. According to postgres
642                // doc, it should be written to buffer after receive AuthenticationOk.
643                self.stream
644                    .write_no_flush(&BeMessage::BackendKeyData(session.id()))?;
645
646                self.stream
647                    .write_parameter_status_msg_no_flush(&ParameterStatus {
648                        application_name: application_name.cloned(),
649                    })?;
650                self.ready_for_query()?;
651            }
652            UserAuthenticator::ClearText(_) | UserAuthenticator::OAuth(_) => {
653                self.stream
654                    .write_no_flush(&BeMessage::AuthenticationCleartextPassword)?;
655            }
656            UserAuthenticator::Md5WithSalt { salt, .. } => {
657                self.stream
658                    .write_no_flush(&BeMessage::AuthenticationMd5Password(salt))?;
659            }
660        }
661
662        self.session = Some(session);
663        self.state = PgProtocolState::Regular;
664        Ok(())
665    }
666
667    async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
668        let authenticator = self.session.as_ref().unwrap().user_authenticator();
669        authenticator.authenticate(&msg.password).await?;
670        self.stream.write_no_flush(&BeMessage::AuthenticationOk)?;
671        self.stream
672            .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
673        self.ready_for_query()?;
674        self.state = PgProtocolState::Regular;
675        Ok(())
676    }
677
678    fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
679        let session_id = (m.target_process_id, m.target_secret_key);
680        tracing::trace!("cancel query in session: {:?}", session_id);
681        self.session_mgr.cancel_queries_in_session(session_id);
682        self.session_mgr.cancel_creating_jobs_in_session(session_id);
683        self.stream.write_no_flush(&BeMessage::EmptyQueryResponse)?;
684        Ok(())
685    }
686
687    async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
688        let truncated_sql = record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
689        let session = self.session.clone().unwrap();
690
691        session.check_idle_in_transaction_timeout()?;
692        // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
693        let _exec_context_guard = session.init_exec_context(truncated_sql.into());
694        self.inner_process_query_msg(sql, session.clone()).await
695    }
696
697    async fn inner_process_query_msg(
698        &mut self,
699        sql: Arc<str>,
700        session: Arc<SM::Session>,
701    ) -> PsqlResult<()> {
702        // Parse sql.
703        let stmts =
704            Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
705        // The following inner_process_query_msg_one_stmt can be slow. Release potential large String early.
706        drop(sql);
707        if stmts.is_empty() {
708            self.stream.write_no_flush(&BeMessage::EmptyQueryResponse)?;
709        }
710
711        // Execute multiple statements in simple query. KISS later.
712        for stmt in stmts {
713            self.inner_process_query_msg_one_stmt(stmt, session.clone())
714                .await?;
715        }
716        // Put this line inside the for loop above will lead to unfinished/stuck regress test...Not
717        // sure the reason.
718        self.ready_for_query()?;
719        Ok(())
720    }
721
722    async fn inner_process_query_msg_one_stmt(
723        &mut self,
724        stmt: Statement,
725        session: Arc<SM::Session>,
726    ) -> PsqlResult<()> {
727        let session = session.clone();
728
729        // execute query
730        let res = session.clone().run_one_query(stmt, Format::Text).await;
731
732        // Take all remaining notices (if any) and send them before `CommandComplete`.
733        while let Some(notice) = session.next_notice().now_or_never() {
734            self.stream
735                .write_no_flush(&BeMessage::NoticeResponse(&notice))?;
736        }
737
738        let mut res = res.map_err(PsqlError::SimpleQueryError)?;
739
740        for notice in res.notices() {
741            self.stream
742                .write_no_flush(&BeMessage::NoticeResponse(notice))?;
743        }
744
745        let status = res.status();
746        if let Some(ref application_name) = status.application_name {
747            self.stream.write_no_flush(&BeMessage::ParameterStatus(
748                BeParameterStatusMessage::ApplicationName(application_name),
749            ))?;
750        }
751
752        if res.is_query() {
753            self.stream
754                .write_no_flush(&BeMessage::RowDescription(&res.row_desc()))?;
755
756            let mut rows_cnt = 0;
757
758            while let Some(row_set) = res.values_stream().next().await {
759                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
760                for row in row_set {
761                    self.stream.write_no_flush(&BeMessage::DataRow(&row))?;
762                    rows_cnt += 1;
763                }
764            }
765
766            // Run the callback before sending the `CommandComplete` message.
767            res.run_callback().await?;
768
769            self.stream
770                .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
771                    stmt_type: res.stmt_type(),
772                    rows_cnt,
773                }))?;
774        } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
775            let first_row_set = res.values_stream().next().await;
776            let first_row_set = match first_row_set {
777                None => {
778                    return Err(PsqlError::Uncategorized(
779                        anyhow::anyhow!("no affected rows in output").into(),
780                    ));
781                }
782                Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
783            };
784            let affected_rows_str = first_row_set[0].values()[0]
785                .as_ref()
786                .expect("compute node should return affected rows in output");
787
788            assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
789            let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
790                .unwrap()
791                .parse()
792                .unwrap_or_default();
793
794            // Run the callback before sending the `CommandComplete` message.
795            res.run_callback().await?;
796
797            self.stream
798                .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
799                    stmt_type: res.stmt_type(),
800                    rows_cnt: affected_rows_cnt,
801                }))?;
802        } else {
803            // Run the callback before sending the `CommandComplete` message.
804            res.run_callback().await?;
805
806            self.stream
807                .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
808                    stmt_type: res.stmt_type(),
809                    rows_cnt: 0,
810                }))?;
811        }
812
813        Ok(())
814    }
815
816    fn process_terminate(&mut self) {
817        self.is_terminate = true;
818    }
819
820    fn process_health_check(&mut self) {
821        tracing::debug!("health check");
822        self.is_terminate = true;
823    }
824
825    async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
826        let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
827        record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
828        let session = self.session.clone().unwrap();
829        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
830        let type_ids = std::mem::take(&mut msg.type_ids);
831        // The inner_process_parse_msg can be slow. Release potential large FeParseMessage early.
832        drop(msg);
833        self.inner_process_parse_msg(session, sql, statement_name, type_ids)
834            .await?;
835        Ok(())
836    }
837
838    async fn inner_process_parse_msg(
839        &mut self,
840        session: Arc<SM::Session>,
841        sql: Arc<str>,
842        statement_name: String,
843        type_ids: Vec<i32>,
844    ) -> PsqlResult<()> {
845        if statement_name.is_empty() {
846            // Remove the unnamed prepare statement first, in case the unsupported sql binds a wrong
847            // prepare statement.
848            self.unnamed_prepare_statement.take();
849        } else if self.prepare_statement_store.contains_key(&statement_name) {
850            return Err(PsqlError::ExtendedPrepareError(
851                "Duplicated statement name".into(),
852            ));
853        }
854
855        let stmt = {
856            let stmts = Parser::parse_sql(&sql)
857                .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
858            drop(sql);
859            if stmts.len() > 1 {
860                return Err(PsqlError::ExtendedPrepareError(
861                    "Only one statement is allowed in extended query mode".into(),
862                ));
863            }
864
865            stmts.into_iter().next()
866        };
867
868        let param_types: Vec<Option<DataType>> = type_ids
869            .iter()
870            .map(|&id| {
871                // 0 means unspecified type
872                // ref: https://www.postgresql.org/docs/15/protocol-message-formats.html#:~:text=Placing%20a%20zero%20here%20is%20equivalent%20to%20leaving%20the%20type%20unspecified.
873                if id == 0 {
874                    Ok(None)
875                } else {
876                    DataType::from_oid(id)
877                        .map(Some)
878                        .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
879                }
880            })
881            .try_collect()?;
882
883        let prepare_statement = session
884            .parse(stmt, param_types)
885            .await
886            .map_err(PsqlError::ExtendedPrepareError)?;
887
888        if statement_name.is_empty() {
889            self.unnamed_prepare_statement.replace(prepare_statement);
890        } else {
891            self.prepare_statement_store
892                .insert(statement_name.clone(), prepare_statement);
893        }
894
895        self.statement_portal_dependency
896            .entry(statement_name)
897            .or_default()
898            .clear();
899
900        self.stream.write_no_flush(&BeMessage::ParseComplete)?;
901        Ok(())
902    }
903
904    fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
905        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
906        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
907        let session = self.session.clone().unwrap();
908
909        if self.portal_store.contains_key(&portal_name) {
910            return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
911        }
912
913        let prepare_statement = self.get_statement(&statement_name)?;
914
915        let result_formats = msg
916            .result_format_codes
917            .iter()
918            .map(|&format_code| Format::from_i16(format_code))
919            .try_collect()?;
920        let param_formats = msg
921            .param_format_codes
922            .iter()
923            .map(|&format_code| Format::from_i16(format_code))
924            .try_collect()?;
925
926        let portal = session
927            .bind(prepare_statement, msg.params, param_formats, result_formats)
928            .map_err(PsqlError::Uncategorized)?;
929
930        if portal_name.is_empty() {
931            self.result_cache.remove(&portal_name);
932            self.unnamed_portal.replace(portal);
933        } else {
934            assert!(
935                !self.result_cache.contains_key(&portal_name),
936                "Named portal never can be overridden."
937            );
938            self.portal_store.insert(portal_name.clone(), portal);
939        }
940
941        self.statement_portal_dependency
942            .get_mut(&statement_name)
943            .unwrap()
944            .push(portal_name);
945
946        self.stream.write_no_flush(&BeMessage::BindComplete)?;
947        Ok(())
948    }
949
950    async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
951        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
952        let row_max = msg.max_rows as usize;
953        drop(msg);
954        let session = self.session.clone().unwrap();
955
956        match self.result_cache.remove(&portal_name) {
957            Some(mut result_cache) => {
958                assert!(self.portal_store.contains_key(&portal_name));
959
960                let is_cosume_completed =
961                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
962
963                if !is_cosume_completed {
964                    self.result_cache.insert(portal_name, result_cache);
965                }
966            }
967            _ => {
968                let portal = self.get_portal(&portal_name)?;
969                let sql = format!("{}", portal);
970                let truncated_sql =
971                    record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
972                drop(sql);
973
974                session.check_idle_in_transaction_timeout()?;
975                // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
976                let _exec_context_guard = session.init_exec_context(truncated_sql.into());
977                let result = session.clone().execute(portal).await;
978
979                let pg_response = result.map_err(PsqlError::ExtendedExecuteError)?;
980                let mut result_cache = ResultCache::new(pg_response);
981                let is_consume_completed =
982                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
983                if !is_consume_completed {
984                    self.result_cache.insert(portal_name, result_cache);
985                }
986            }
987        }
988
989        Ok(())
990    }
991
992    fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
993        let name = cstr_to_str(&msg.name).unwrap().to_owned();
994        let session = self.session.clone().unwrap();
995        //  b'S' => Statement
996        //  b'P' => Portal
997
998        assert!(msg.kind == b'S' || msg.kind == b'P');
999        if msg.kind == b'S' {
1000            let prepare_statement = self.get_statement(&name)?;
1001
1002            let (param_types, row_descriptions) = self
1003                .session
1004                .clone()
1005                .unwrap()
1006                .describe_statement(prepare_statement)
1007                .map_err(PsqlError::Uncategorized)?;
1008            self.stream
1009                .write_no_flush(&BeMessage::ParameterDescription(
1010                    &param_types.iter().map(|t| t.to_oid()).collect_vec(),
1011                ))?;
1012
1013            if row_descriptions.is_empty() {
1014                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1015                // return NoData message if the statement is not a query.
1016                self.stream.write_no_flush(&BeMessage::NoData)?;
1017            } else {
1018                self.stream
1019                    .write_no_flush(&BeMessage::RowDescription(&row_descriptions))?;
1020            }
1021        } else if msg.kind == b'P' {
1022            let portal = self.get_portal(&name)?;
1023
1024            let row_descriptions = session
1025                .describe_portal(portal)
1026                .map_err(PsqlError::Uncategorized)?;
1027
1028            if row_descriptions.is_empty() {
1029                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1030                // return NoData message if the statement is not a query.
1031                self.stream.write_no_flush(&BeMessage::NoData)?;
1032            } else {
1033                self.stream
1034                    .write_no_flush(&BeMessage::RowDescription(&row_descriptions))?;
1035            }
1036        }
1037        Ok(())
1038    }
1039
1040    fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1041        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1042        assert!(msg.kind == b'S' || msg.kind == b'P');
1043        if msg.kind == b'S' {
1044            if name.is_empty() {
1045                self.unnamed_prepare_statement = None;
1046            } else {
1047                self.prepare_statement_store.remove(&name);
1048            }
1049            for portal_name in self
1050                .statement_portal_dependency
1051                .remove(&name)
1052                .unwrap_or_default()
1053            {
1054                self.remove_portal(&portal_name);
1055            }
1056        } else if msg.kind == b'P' {
1057            self.remove_portal(&name);
1058        }
1059        self.stream.write_no_flush(&BeMessage::CloseComplete)?;
1060        Ok(())
1061    }
1062
1063    fn remove_portal(&mut self, portal_name: &str) {
1064        if portal_name.is_empty() {
1065            self.unnamed_portal = None;
1066        } else {
1067            self.portal_store.remove(portal_name);
1068        }
1069        self.result_cache.remove(portal_name);
1070    }
1071
1072    fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1073        if portal_name.is_empty() {
1074            Ok(self
1075                .unnamed_portal
1076                .as_ref()
1077                .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1078                .clone())
1079        } else {
1080            Ok(self
1081                .portal_store
1082                .get(portal_name)
1083                .ok_or_else(|| {
1084                    PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1085                })?
1086                .clone())
1087        }
1088    }
1089
1090    fn get_statement(
1091        &self,
1092        statement_name: &str,
1093    ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1094        if statement_name.is_empty() {
1095            Ok(self
1096                .unnamed_prepare_statement
1097                .as_ref()
1098                .ok_or_else(|| {
1099                    PsqlError::Uncategorized("unnamed prepare statement not found".into())
1100                })?
1101                .clone())
1102        } else {
1103            Ok(self
1104                .prepare_statement_store
1105                .get(statement_name)
1106                .ok_or_else(|| {
1107                    PsqlError::Uncategorized(
1108                        format!("Prepare statement {} not found", statement_name).into(),
1109                    )
1110                })?
1111                .clone())
1112        }
1113    }
1114}
1115
1116enum PgStreamInner<S> {
1117    /// Used for the intermediate state when converting from unencrypted to ssl stream.
1118    Placeholder,
1119    /// An unencrypted stream.
1120    Unencrypted(S),
1121    /// An ssl stream.
1122    Ssl(SslStream<S>),
1123}
1124
1125/// Trait for a byte stream that can be used for pg protocol.
1126pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1127impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1128
1129/// Wraps a byte stream and read/write pg messages.
1130///
1131/// Cloning a `PgStream` will share the same stream but a fresh & independent write buffer,
1132/// so that it can be used to write messages concurrently without interference.
1133pub struct PgStream<S> {
1134    /// The underlying stream.
1135    stream: Arc<Mutex<PgStreamInner<S>>>,
1136    /// Write into buffer before flush to stream.
1137    write_buf: BytesMut,
1138    read_header: Option<FeMessageHeader>,
1139}
1140
1141impl<S> PgStream<S> {
1142    /// Create a new `PgStream` with the given stream and default write buffer capacity.
1143    pub fn new(stream: S) -> Self {
1144        const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1145
1146        Self {
1147            stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1148            write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1149            read_header: None,
1150        }
1151    }
1152}
1153
1154impl<S> Clone for PgStream<S> {
1155    fn clone(&self) -> Self {
1156        Self {
1157            stream: Arc::clone(&self.stream),
1158            write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1159            read_header: self.read_header.clone(),
1160        }
1161    }
1162}
1163
/// Session parameters reported to the client through `ParameterStatus` messages.
///
/// At present there is a hard-wired set of parameters for which
/// ParameterStatus will be generated: they are:
///
///  * `server_version`
///  * `server_encoding`
///  * `client_encoding`
///  * `application_name`
///  * `is_superuser`
///  * `session_authorization`
///  * `DateStyle`
///  * `IntervalStyle`
///  * `TimeZone`
///  * `integer_datetimes`
///  * `standard_conforming_strings`
///
/// This struct only carries `application_name`; `None` means it is not reported.
///
/// See: <https://www.postgresql.org/docs/9.2/static/protocol-flow.html#PROTOCOL-ASYNC>.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// Value to report for `application_name`, if any.
    pub application_name: Option<String>,
}
1184
1185impl<S> PgStream<S>
1186where
1187    S: PgByteStream,
1188{
1189    async fn read_startup(&mut self) -> io::Result<FeMessage> {
1190        let mut stream = self.stream.lock().await;
1191        match &mut *stream {
1192            PgStreamInner::Placeholder => unreachable!(),
1193            PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1194            PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1195        }
1196    }
1197
1198    async fn read_header(&mut self) -> io::Result<()> {
1199        let mut stream = self.stream.lock().await;
1200        match &mut *stream {
1201            PgStreamInner::Placeholder => unreachable!(),
1202            PgStreamInner::Unencrypted(stream) => {
1203                self.read_header = Some(FeMessage::read_header(stream).await?);
1204                Ok(())
1205            }
1206            PgStreamInner::Ssl(ssl_stream) => {
1207                self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1208                Ok(())
1209            }
1210        }
1211    }
1212
1213    async fn read_body(&mut self) -> io::Result<FeMessage> {
1214        let mut stream = self.stream.lock().await;
1215        let header = self
1216            .read_header
1217            .take()
1218            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1219        match &mut *stream {
1220            PgStreamInner::Placeholder => unreachable!(),
1221            PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1222            PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1223        }
1224    }
1225
1226    async fn skip_body(&mut self) -> io::Result<()> {
1227        let mut stream = self.stream.lock().await;
1228        let header = self
1229            .read_header
1230            .take()
1231            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1232        match &mut *stream {
1233            PgStreamInner::Placeholder => unreachable!(),
1234            PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1235            PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1236        }
1237    }
1238
1239    fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1240        self.write_no_flush(&BeMessage::ParameterStatus(
1241            BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1242        ))?;
1243        self.write_no_flush(&BeMessage::ParameterStatus(
1244            BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1245        ))?;
1246        self.write_no_flush(&BeMessage::ParameterStatus(
1247            BeParameterStatusMessage::ServerVersion(PG_VERSION),
1248        ))?;
1249        if let Some(application_name) = &status.application_name {
1250            self.write_no_flush(&BeMessage::ParameterStatus(
1251                BeParameterStatusMessage::ApplicationName(application_name),
1252            ))?;
1253        }
1254        Ok(())
1255    }
1256
1257    pub fn write_no_flush(&mut self, message: &BeMessage<'_>) -> io::Result<()> {
1258        BeMessage::write(&mut self.write_buf, message)
1259    }
1260
1261    async fn write(&mut self, message: &BeMessage<'_>) -> io::Result<()> {
1262        self.write_no_flush(message)?;
1263        self.flush().await?;
1264        Ok(())
1265    }
1266
1267    async fn flush(&mut self) -> io::Result<()> {
1268        let mut stream = self.stream.lock().await;
1269        match &mut *stream {
1270            PgStreamInner::Placeholder => unreachable!(),
1271            PgStreamInner::Unencrypted(stream) => {
1272                stream.write_all(&self.write_buf).await?;
1273                stream.flush().await?;
1274            }
1275            PgStreamInner::Ssl(ssl_stream) => {
1276                ssl_stream.write_all(&self.write_buf).await?;
1277                ssl_stream.flush().await?;
1278            }
1279        }
1280        self.write_buf.clear();
1281        Ok(())
1282    }
1283}
1284
1285impl<S> PgStream<S>
1286where
1287    S: PgByteStream,
1288{
1289    /// Convert the underlying stream to ssl stream based on the given context.
1290    async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1291        let mut stream = self.stream.lock().await;
1292
1293        match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1294            PgStreamInner::Unencrypted(unencrypted_stream) => {
1295                let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1296                let mut ssl_stream =
1297                    tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1298
1299                if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1300                    tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1301                    let _ = ssl_stream.shutdown().await;
1302                    return Err(e.into());
1303                }
1304
1305                *stream = PgStreamInner::Ssl(ssl_stream);
1306            }
1307            PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1308            PgStreamInner::Placeholder => unreachable!(),
1309        }
1310
1311        Ok(())
1312    }
1313}
1314
1315fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1316    let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1317
1318    let key_path = &tls_config.key;
1319    let cert_path = &tls_config.cert;
1320
1321    // Build ssl acceptor according to the config.
1322    // Now we set every verify to true.
1323    acceptor
1324        .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1325        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1326    acceptor
1327        .set_ca_file(cert_path)
1328        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1329    acceptor
1330        .set_certificate_chain_file(cert_path)
1331        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1332    let acceptor = acceptor.build();
1333
1334    Ok(acceptor.into_context())
1335}
1336
pub mod truncated_fmt {
    //! Formatting adapters that cap how many bytes of a value's output are emitted.

    use std::fmt::*;

    /// Returns the largest index `<= index` that lies on a UTF-8 char boundary of `s`, or
    /// `s.len()` when `index` is past the end. Stable equivalent of `str::floor_char_boundary`.
    fn floor_char_boundary(s: &str, index: usize) -> usize {
        if index >= s.len() {
            return s.len();
        }
        // Index 0 is always a boundary, so the search always succeeds; since a UTF-8 char is
        // at most 4 bytes long, at most 4 positions are inspected.
        (0..=index)
            .rev()
            .find(|&i| s.is_char_boundary(i))
            .unwrap_or(0)
    }

    /// A [`Write`] adapter that forwards at most `remaining` bytes to `f` (cut on a char
    /// boundary) while counting every byte offered to it.
    struct TruncatedFormatter<'a, 'b> {
        /// Output budget left, in bytes.
        remaining: usize,
        /// Total bytes the formatted value produced, including the ones that were cut.
        total: usize,
        /// Set once the budget has been exceeded; later writes are only counted.
        truncated: bool,
        f: &'a mut Formatter<'b>,
    }

    impl<'a, 'b> TruncatedFormatter<'a, 'b> {
        fn new(f: &'a mut Formatter<'b>, limit: usize) -> Self {
            Self {
                remaining: limit,
                total: 0,
                truncated: false,
                f,
            }
        }

        /// Appends the `...(truncated,<total bytes>)` marker if anything was cut. The marker
        /// is written only after the whole value has been formatted, so the reported length
        /// covers the complete output and not just the chunk where the cut happened.
        fn finish(self) -> Result {
            let Self {
                total,
                truncated,
                f,
                ..
            } = self;
            if truncated {
                write!(f, "...(truncated,{})", total)
            } else {
                Ok(())
            }
        }
    }

    impl Write for TruncatedFormatter<'_, '_> {
        fn write_str(&mut self, s: &str) -> Result {
            self.total = self.total.saturating_add(s.len());
            if self.truncated {
                return Ok(());
            }
            if s.len() <= self.remaining {
                self.remaining -= s.len();
                self.f.write_str(s)
            } else {
                let keep = floor_char_boundary(s, self.remaining);
                self.remaining -= keep;
                // So that the marker is emitted exactly once, by `finish`.
                self.truncated = true;
                self.f.write_str(&s[..keep])
            }
        }
    }

    /// Formats `self.0` through its `Debug`/`Display` impl but emits at most `self.1` bytes,
    /// rounded down to a UTF-8 char boundary. When output is cut, `...(truncated,N)` is
    /// appended, where `N` is the byte length of the complete, untruncated output.
    pub struct TruncatedFmt<'a, T: ?Sized>(pub &'a T, pub usize);

    impl<T> Debug for TruncatedFmt<'_, T>
    where
        T: Debug + ?Sized,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            let mut out = TruncatedFormatter::new(f, self.1);
            out.write_fmt(format_args!("{:?}", self.0))?;
            out.finish()
        }
    }

    impl<T> Display for TruncatedFmt<'_, T>
    where
        T: Display + ?Sized,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            let mut out = TruncatedFormatter::new(f, self.1);
            out.write_fmt(format_args!("{}", self.0))?;
            out.finish()
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            assert_eq!(
                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
                "select '...(truncated,14)",
            );
        }

        #[test]
        fn test_trunc_reports_total_len_across_writes() {
            struct Pieces;
            impl std::fmt::Display for Pieces {
                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    f.write_str("abc")?;
                    f.write_str("def")
                }
            }
            assert_eq!(
                format!("{}", TruncatedFmt(&Pieces, 4)),
                "abcd...(truncated,6)"
            );
            assert_eq!(format!("{}", TruncatedFmt(&Pieces, 6)), "abcdef");
        }
    }
}
1408
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Option values whose key is in the keyword set are replaced by `[REDACTED]`, both in the
    /// `WITH (...)` clause and in the `FORMAT ... ENCODE ...` options; other values are kept.
    #[test]
    fn test_redact_parsable_sql() {
        let redacted_keys = ["v2", "v4", "b"];
        let keywords = Arc::new(HashSet::from(redacted_keys.map(Into::into)));
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])";
        assert_eq!(redact_sql(sql, keywords), expected);
    }
}