pgwire/
pg_protocol.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::panic::FutureCatchUnwindExt;
32use risingwave_common::util::query_log::*;
33use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
34use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
35use risingwave_sqlparser::parser::Parser;
36use thiserror_ext::AsReport;
37use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
38use tokio::sync::Mutex;
39use tokio_openssl::SslStream;
40use tracing::Instrument;
41
42use crate::error::{PsqlError, PsqlResult};
43use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
44use crate::net::AddressRef;
45use crate::pg_extended::ResultCache;
46use crate::pg_message::{
47    BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
48    FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
49    FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
50};
51use crate::pg_server::{Session, SessionManager, UserAuthenticator};
52use crate::types::Format;
53
/// Maximum length of SQL text recorded in query logs and tracing spans, so that huge
/// statements do not make the log files too large.
///
/// Read once, on first use, from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. When
/// the variable is unset or does not parse as a `usize`, it defaults to 65536 in builds with
/// debug assertions and 1024 otherwise.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    // Parse the value once instead of validating and then re-parsing it.
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        .and_then(|len| len.parse::<usize>().ok())
        .unwrap_or(if cfg!(debug_assertions) { 65536 } else { 1024 })
});
67
tokio::task_local! {
    /// The current session. Concrete type is erased for different session implementations.
    ///
    /// `PgProtocol::do_process` sets this (as a `Weak` reference to the connection's session)
    /// for the duration of processing each message. Code deeper in the call stack can then
    /// reach the session without it being passed explicitly. Being `Weak`, it does not keep
    /// the session alive.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
72
/// The state machine for each psql connection.
/// Reads pg messages from the byte stream and writes results back.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Used for write/read pg messages.
    stream: PgStream<S>,
    /// Current state of the pg connection (`Startup` until the startup message is handled).
    state: PgProtocolState,
    /// Whether the connection is terminated (set by `Terminate` and health-check messages).
    is_terminate: bool,

    /// Creates the session on startup; notified via `end_session` when this value is dropped.
    session_mgr: Arc<SM>,
    /// The session of this connection; `None` until the startup message has been processed.
    session: Option<Arc<SM::Session>>,

    /// Cached result streams of the extended query protocol.
    /// NOTE(review): presumably keyed by portal name — confirm against its users.
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    /// The unnamed prepared statement (created by a `Parse` with an empty statement name).
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    /// Named prepared statements, keyed by statement name.
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    /// The unnamed portal.
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    /// Named portals, keyed by portal name.
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    // Used to store the dependency of portal and prepare statement.
    // When we close a prepare statement, we need to close all the portals that depend on it.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    // Used to upgrade the connection to TLS when the client sends an SSL request.
    // If None, SSL requests are declined with `EncryptionResponseNo`.
    tls_context: Option<SslContext>,

    // Used in the extended query protocol. After an error in an extended query, the following
    // messages are ignored until a `Sync` message arrives.
    ignore_util_sync: bool,

    // Client Address
    peer_addr: AddressRef,

    // Option keywords whose values are redacted before SQL is recorded in logs/spans.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    // Accounts for memory held by incoming messages and reports when a message must be throttled.
    message_memory_manager: MessageMemoryManagerRef,
}
112
/// Configures TLS encryption for connections.
///
/// Holds file paths only; `TlsConfig::new_default` builds one from the `RW_SSL_CERT` and
/// `RW_SSL_KEY` environment variables.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// The path to the TLS certificate.
    pub cert: String,
    /// The path to the TLS key.
    pub key: String,
}
121
122impl TlsConfig {
123    pub fn new_default() -> Option<Self> {
124        let cert = std::env::var("RW_SSL_CERT").ok()?;
125        let key = std::env::var("RW_SSL_KEY").ok()?;
126        tracing::info!("RW_SSL_CERT={}, RW_SSL_KEY={}", cert, key);
127        Some(Self { cert, key })
128    }
129}
130
131impl<S, SM> Drop for PgProtocol<S, SM>
132where
133    SM: SessionManager,
134{
135    fn drop(&mut self) {
136        if let Some(session) = &self.session {
137            // Clear the session in session manager.
138            self.session_mgr.end_session(session);
139        }
140    }
141}
142
/// Connection states. Transitions only go from top to bottom: `Startup` becomes `Regular`
/// once the startup message has been processed.
enum PgProtocolState {
    /// Expecting the startup packet (read with `read_startup`, without a typed header).
    Startup,
    /// Startup done; each message is read as a header followed by a body.
    Regular,
}
148
149/// Truncate 0 from C string in Bytes and stringify it (returns slice, no allocations).
150///
151/// PG protocol strings are always C strings.
152pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
153    let without_null = if b.last() == Some(&0) {
154        &b[..b.len() - 1]
155    } else {
156        &b[..]
157    };
158    std::str::from_utf8(without_null)
159}
160
161/// Record `sql` in the current tracing span.
162fn record_sql_in_span(
163    sql: &str,
164    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
165) -> String {
166    let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
167        && !keywords.is_empty()
168    {
169        redact_sql(sql, keywords)
170    } else {
171        sql.to_owned()
172    };
173    let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
174    tracing::Span::current().record("sql", tracing::field::display(&truncated));
175    truncated.to_string()
176}
177
178/// Redacts SQL options. Data in DML is not redacted.
179fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
180    match Parser::parse_sql(sql) {
181        Ok(sqls) => sqls
182            .into_iter()
183            .map(|sql| sql.to_redacted_string(keywords.clone()))
184            .join(";"),
185        Err(_) => sql.to_owned(),
186    }
187}
188
/// Settings handed to `PgProtocol::new` when a connection is created.
#[derive(Clone)]
pub struct ConnectionContext {
    /// Certificate/key paths used to build the TLS context; when `None`, SSL requests are declined.
    pub tls_config: Option<TlsConfig>,
    /// Option keywords whose values are redacted before SQL is recorded in logs/spans.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts for memory held by incoming messages and reports when to throttle them.
    pub message_memory_manager: MessageMemoryManagerRef,
}
195
196impl<S, SM> PgProtocol<S, SM>
197where
198    S: PgByteStream,
199    SM: SessionManager,
200{
201    pub fn new(
202        stream: S,
203        session_mgr: Arc<SM>,
204        peer_addr: AddressRef,
205        context: ConnectionContext,
206    ) -> Self {
207        let ConnectionContext {
208            tls_config,
209            redact_sql_option_keywords,
210            message_memory_manager,
211        } = context;
212        Self {
213            stream: PgStream::new(stream),
214            is_terminate: false,
215            state: PgProtocolState::Startup,
216            session_mgr,
217            session: None,
218            tls_context: tls_config
219                .as_ref()
220                .and_then(|e| build_ssl_ctx_from_config(e).ok()),
221            result_cache: Default::default(),
222            unnamed_prepare_statement: Default::default(),
223            prepare_statement_store: Default::default(),
224            unnamed_portal: Default::default(),
225            portal_store: Default::default(),
226            statement_portal_dependency: Default::default(),
227            ignore_util_sync: false,
228            peer_addr,
229            redact_sql_option_keywords,
230            message_memory_manager,
231        }
232    }
233
    /// Run the protocol to serve the connection.
    ///
    /// Reads and processes one frontend message at a time. The loop stops when processing
    /// reports that the connection should end, or when a message cannot be read (the read
    /// error is logged). Once a session exists, its notices are forwarded to the client while
    /// messages are being processed.
    pub async fn run(&mut self) {
        // Created lazily once a session exists, then kept across loop iterations.
        let mut notice_fut = None;

        loop {
            // Once a session is present, create a future to subscribe and send notices asynchronously.
            // It writes through a clone of `self.stream`.
            // NOTE(review): assumes clones of `PgStream` share the same underlying connection — confirm.
            if notice_fut.is_none()
                && let Some(session) = self.session.clone()
            {
                let mut stream = self.stream.clone();
                notice_fut = Some(Box::pin(async move {
                    // Never completes: a notice that fails to send is logged and skipped.
                    loop {
                        let notice = session.next_notice().await;
                        if let Err(e) = stream.write(&BeMessage::NoticeResponse(&notice)).await {
                            tracing::error!(error = %e.as_report(), notice, "failed to send notice");
                        }
                    }
                }));
            }

            // Read and process messages.
            let process = std::pin::pin!(async {
                // `_memory_guard` is a named binding (not `_`), so the guard stays alive until
                // the message has been processed.
                let (msg, _memory_guard) = match self.read_message().await {
                    Ok(msg) => msg,
                    Err(e) => {
                        tracing::error!(error = %e.as_report(), "error when reading message");
                        return true; // terminate the connection
                    }
                };
                tracing::trace!(?msg, "received message");
                self.process(msg).await
            });

            // Deliver notices while processing. The notice future loops forever, so only
            // `process` can complete the `select!`.
            let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
                tokio::select! {
                    _ = notice_fut => unreachable!(),
                    terminated = process => terminated,
                }
            } else {
                process.await
            };

            if terminated {
                break;
            }
        }
    }
281
282    /// Processes one message. Returns true if the connection is terminated.
283    pub async fn process(&mut self, msg: FeMessage) -> bool {
284        self.do_process(msg).await.is_none() || self.is_terminate
285    }
286
287    /// The root tracing span for processing a message. The target of the span is
288    /// [`PGWIRE_ROOT_SPAN_TARGET`].
289    ///
290    /// This is used to provide context for the (slow) query logs and traces.
291    ///
292    /// The span is only effective if there's a current session and the message is
293    /// query-related. Otherwise, `Span::none()` is returned.
294    fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
295        let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
296            return tracing::Span::none();
297        };
298
299        let mode = match msg {
300            FeMessage::Query(_) => "simple query",
301            FeMessage::Parse(_) => "extended query parse",
302            FeMessage::Execute(_) => "extended query execute",
303            _ => return tracing::Span::none(),
304        };
305
306        tracing::info_span!(
307            target: PGWIRE_ROOT_SPAN_TARGET,
308            "handle_query",
309            mode,
310            session_id,
311            sql = tracing::field::Empty, // record SQL later in each `process` call
312        )
313    }
314
    /// Return type `Option<()>` is essentially a bool, but allows `?` for early return.
    /// - `None` means to terminate the current connection
    /// - `Some(())` means to continue processing the next message
    ///
    /// Wraps `do_process_inner` in layers, innermost first:
    /// 1. the `CURRENT_SESSION` task-local scope;
    /// 2. panic catching, where a panic becomes `PsqlError::Panic`;
    /// 3. periodic slow-query logging;
    /// 4. error and query logging;
    /// 5. the root span from `root_span_for_msg`.
    ///
    /// Any resulting error is reported to the client as an `ErrorResponse`. Whether the
    /// connection survives depends on the error kind (see the match at the end).
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        // Processing the message itself.
        //
        // Note: pin the future to avoid stack overflow as we'll wrap it multiple times
        // in the following code.
        let fut = Box::pin(self.do_process_inner(msg));

        // Set the current session as the context when processing the message, if exists.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Catch unwind.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Slow query log.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            // Accumulated in whole periods, so the logged time is a multiple of `period`.
            let mut elapsed = Duration::ZERO;

            // Report the SQL in the log periodically if the query is slow.
            loop {
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Query log.
        let fut = async move {
            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            // Always log if an error occurs.
            // Note: all messages will be processed through this code path, making it the
            //       only necessary place to log errors.
            if let Err(error) = &result {
                tracing::error!(error = %error.as_report(), "error when process message");
            }

            // Log to optionally-enabled target `PGWIRE_QUERY_LOG`.
            // Only log if we're currently in a tracing span set in `root_span_for_msg`.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        // Tracing span.
        let fut = fut.instrument(span);

        // Execute the future and handle the error.
        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    // Unexpected EOF means the client went away: terminate. Other I/O errors
                    // fall through to the flush below and keep the connection; nothing is
                    // written to the client for them.
                    PsqlError::IoError(io_err) => {
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                    }

                    PsqlError::SslError(_) => {
                        // For ssl error, because the stream has already been consumed, so there is
                        // no way to write more message.
                        return None;
                    }

                    // Startup/authentication failures: report, then close the connection.
                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    // Simple query / throttling: report, then send `ReadyForQuery` so the
                    // client can issue the next query.
                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                        self.ready_for_query().ok()?;
                    }

                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                        let _ = self.stream.flush().await;

                        // 1. Catching the panic during message processing may leave the session in an
                        // inconsistent state. We forcefully close the connection (then end the
                        // session) here for safety.
                        // 2. Idle in transaction timeout should also close the connection.
                        return None;
                    }

                    // Report only. No `ReadyForQuery` here: in the extended protocol it is sent
                    // when the client's `Sync` arrives.
                    // NOTE(review): an `Uncategorized` error raised on the simple-query path
                    // also gets no `ReadyForQuery` here — confirm no such path exists.
                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                    }
                }
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
461
462    async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
463        // Ignore util sync message.
464        if self.ignore_util_sync {
465            if let FeMessage::Sync = msg {
466            } else {
467                tracing::trace!("ignore message {:?} until sync.", msg);
468                return Ok(());
469            }
470        }
471
472        match msg {
473            FeMessage::Gss => self.process_gss_msg().await?,
474            FeMessage::Ssl => self.process_ssl_msg().await?,
475            FeMessage::Startup(msg) => self.process_startup_msg(msg)?,
476            FeMessage::Password(msg) => self.process_password_msg(msg).await?,
477            FeMessage::Query(query_msg) => {
478                let sql = Arc::from(query_msg.get_sql()?);
479                // The process_query_msg can be slow. Release potential large FeQueryMessage early.
480                drop(query_msg);
481                self.process_query_msg(sql).await?
482            }
483            FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
484            FeMessage::Terminate => self.process_terminate(),
485            FeMessage::Parse(m) => {
486                if let Err(err) = self.process_parse_msg(m).await {
487                    self.ignore_util_sync = true;
488                    return Err(err);
489                }
490            }
491            FeMessage::Bind(m) => {
492                if let Err(err) = self.process_bind_msg(m) {
493                    self.ignore_util_sync = true;
494                    return Err(err);
495                }
496            }
497            FeMessage::Execute(m) => {
498                if let Err(err) = self.process_execute_msg(m).await {
499                    self.ignore_util_sync = true;
500                    return Err(err);
501                }
502            }
503            FeMessage::Describe(m) => {
504                if let Err(err) = self.process_describe_msg(m) {
505                    self.ignore_util_sync = true;
506                    return Err(err);
507                }
508            }
509            FeMessage::Sync => {
510                self.ignore_util_sync = false;
511                self.ready_for_query()?
512            }
513            FeMessage::Close(m) => {
514                if let Err(err) = self.process_close_msg(m) {
515                    self.ignore_util_sync = true;
516                    return Err(err);
517                }
518            }
519            FeMessage::Flush => {
520                if let Err(err) = self.stream.flush().await {
521                    self.ignore_util_sync = true;
522                    return Err(err.into());
523                }
524            }
525            FeMessage::HealthCheck => self.process_health_check(),
526            FeMessage::ServerThrottle(reason) => match reason {
527                ServerThrottleReason::TooLargeMessage => {
528                    return Err(PsqlError::ServerThrottle(anyhow::anyhow!("max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit", self.message_memory_manager.max_filter_bytes).into()));
529                }
530                ServerThrottleReason::TooManyMemoryUsage => {
531                    return Err(PsqlError::ServerThrottle(anyhow::anyhow!("max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit", self.message_memory_manager.max_running_bytes).into()));
532                }
533            },
534        }
535        self.stream.flush().await?;
536        Ok(())
537    }
538
539    pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
540        match self.state {
541            PgProtocolState::Startup => self
542                .stream
543                .read_startup()
544                .await
545                .map(|message: FeMessage| (message, None)),
546            PgProtocolState::Regular => {
547                self.stream.read_header().await?;
548                let guard = if let Some(ref header) = self.stream.read_header {
549                    let payload_len = std::cmp::max(header.payload_len, 0) as u64;
550                    let (reason, guard) = self.message_memory_manager.add(payload_len);
551                    if let Some(reason) = reason {
552                        // Release the memory ASAP.
553                        drop(guard);
554                        self.stream.skip_body().await?;
555                        return Ok((FeMessage::ServerThrottle(reason), None));
556                    }
557                    guard
558                } else {
559                    None
560                };
561                let message = self.stream.read_body().await?;
562                Ok((message, guard))
563            }
564        }
565    }
566
567    /// Writes a `ReadyForQuery` message to the client without flushing.
568    fn ready_for_query(&mut self) -> io::Result<()> {
569        self.stream.write_no_flush(&BeMessage::ReadyForQuery(
570            self.session
571                .as_ref()
572                .map(|s| s.transaction_status())
573                .unwrap_or(TransactionStatus::Idle),
574        ))
575    }
576
577    async fn process_gss_msg(&mut self) -> PsqlResult<()> {
578        // We don't support GSSAPI, so we just say no gracefully.
579        self.stream.write(&BeMessage::EncryptionResponseNo).await?;
580        Ok(())
581    }
582
583    async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
584        if let Some(context) = self.tls_context.as_ref() {
585            // If got and ssl context, say yes for ssl connection.
586            // Construct ssl stream and replace with current one.
587            self.stream.write(&BeMessage::EncryptionResponseSsl).await?;
588            self.stream.upgrade_to_ssl(context).await?;
589        } else {
590            // If no, say no for encryption.
591            self.stream.write(&BeMessage::EncryptionResponseNo).await?;
592        }
593
594        Ok(())
595    }
596
597    fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
598        let db_name = msg
599            .config
600            .get("database")
601            .cloned()
602            .unwrap_or_else(|| "dev".to_owned());
603        let user_name = msg
604            .config
605            .get("user")
606            .cloned()
607            .unwrap_or_else(|| "root".to_owned());
608
609        let session = self
610            .session_mgr
611            .connect(&db_name, &user_name, self.peer_addr.clone())
612            .map_err(PsqlError::StartupError)?;
613
614        let application_name = msg.config.get("application_name");
615        if let Some(application_name) = application_name {
616            session
617                .set_config("application_name", application_name.clone())
618                .map_err(PsqlError::StartupError)?;
619        }
620
621        match session.user_authenticator() {
622            UserAuthenticator::None => {
623                self.stream.write_no_flush(&BeMessage::AuthenticationOk)?;
624
625                // Cancel request need this for identify and verification. According to postgres
626                // doc, it should be written to buffer after receive AuthenticationOk.
627                self.stream
628                    .write_no_flush(&BeMessage::BackendKeyData(session.id()))?;
629
630                self.stream
631                    .write_parameter_status_msg_no_flush(&ParameterStatus {
632                        application_name: application_name.cloned(),
633                    })?;
634                self.ready_for_query()?;
635            }
636            UserAuthenticator::ClearText(_) | UserAuthenticator::OAuth(_) => {
637                self.stream
638                    .write_no_flush(&BeMessage::AuthenticationCleartextPassword)?;
639            }
640            UserAuthenticator::Md5WithSalt { salt, .. } => {
641                self.stream
642                    .write_no_flush(&BeMessage::AuthenticationMd5Password(salt))?;
643            }
644        }
645
646        self.session = Some(session);
647        self.state = PgProtocolState::Regular;
648        Ok(())
649    }
650
651    async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
652        let authenticator = self.session.as_ref().unwrap().user_authenticator();
653        authenticator.authenticate(&msg.password).await?;
654        self.stream.write_no_flush(&BeMessage::AuthenticationOk)?;
655        self.stream
656            .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
657        self.ready_for_query()?;
658        self.state = PgProtocolState::Regular;
659        Ok(())
660    }
661
662    fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
663        let session_id = (m.target_process_id, m.target_secret_key);
664        tracing::trace!("cancel query in session: {:?}", session_id);
665        self.session_mgr.cancel_queries_in_session(session_id);
666        self.session_mgr.cancel_creating_jobs_in_session(session_id);
667        self.stream.write_no_flush(&BeMessage::EmptyQueryResponse)?;
668        Ok(())
669    }
670
671    async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
672        let truncated_sql = record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
673        let session = self.session.clone().unwrap();
674
675        session.check_idle_in_transaction_timeout()?;
676        // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
677        let _exec_context_guard = session.init_exec_context(truncated_sql.into());
678        self.inner_process_query_msg(sql, session.clone()).await
679    }
680
681    async fn inner_process_query_msg(
682        &mut self,
683        sql: Arc<str>,
684        session: Arc<SM::Session>,
685    ) -> PsqlResult<()> {
686        // Parse sql.
687        let stmts =
688            Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
689        // The following inner_process_query_msg_one_stmt can be slow. Release potential large String early.
690        drop(sql);
691        if stmts.is_empty() {
692            self.stream.write_no_flush(&BeMessage::EmptyQueryResponse)?;
693        }
694
695        // Execute multiple statements in simple query. KISS later.
696        for stmt in stmts {
697            self.inner_process_query_msg_one_stmt(stmt, session.clone())
698                .await?;
699        }
700        // Put this line inside the for loop above will lead to unfinished/stuck regress test...Not
701        // sure the reason.
702        self.ready_for_query()?;
703        Ok(())
704    }
705
706    async fn inner_process_query_msg_one_stmt(
707        &mut self,
708        stmt: Statement,
709        session: Arc<SM::Session>,
710    ) -> PsqlResult<()> {
711        let session = session.clone();
712
713        // execute query
714        let res = session.clone().run_one_query(stmt, Format::Text).await;
715
716        // Take all remaining notices (if any) and send them before `CommandComplete`.
717        while let Some(notice) = session.next_notice().now_or_never() {
718            self.stream
719                .write_no_flush(&BeMessage::NoticeResponse(&notice))?;
720        }
721
722        let mut res = res.map_err(PsqlError::SimpleQueryError)?;
723
724        for notice in res.notices() {
725            self.stream
726                .write_no_flush(&BeMessage::NoticeResponse(notice))?;
727        }
728
729        let status = res.status();
730        if let Some(ref application_name) = status.application_name {
731            self.stream.write_no_flush(&BeMessage::ParameterStatus(
732                BeParameterStatusMessage::ApplicationName(application_name),
733            ))?;
734        }
735
736        if res.is_query() {
737            self.stream
738                .write_no_flush(&BeMessage::RowDescription(&res.row_desc()))?;
739
740            let mut rows_cnt = 0;
741
742            while let Some(row_set) = res.values_stream().next().await {
743                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
744                for row in row_set {
745                    self.stream.write_no_flush(&BeMessage::DataRow(&row))?;
746                    rows_cnt += 1;
747                }
748            }
749
750            // Run the callback before sending the `CommandComplete` message.
751            res.run_callback().await?;
752
753            self.stream
754                .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
755                    stmt_type: res.stmt_type(),
756                    rows_cnt,
757                }))?;
758        } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
759            let first_row_set = res.values_stream().next().await;
760            let first_row_set = match first_row_set {
761                None => {
762                    return Err(PsqlError::Uncategorized(
763                        anyhow::anyhow!("no affected rows in output").into(),
764                    ));
765                }
766                Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
767            };
768            let affected_rows_str = first_row_set[0].values()[0]
769                .as_ref()
770                .expect("compute node should return affected rows in output");
771
772            assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
773            let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
774                .unwrap()
775                .parse()
776                .unwrap_or_default();
777
778            // Run the callback before sending the `CommandComplete` message.
779            res.run_callback().await?;
780
781            self.stream
782                .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
783                    stmt_type: res.stmt_type(),
784                    rows_cnt: affected_rows_cnt,
785                }))?;
786        } else {
787            // Run the callback before sending the `CommandComplete` message.
788            res.run_callback().await?;
789
790            self.stream
791                .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
792                    stmt_type: res.stmt_type(),
793                    rows_cnt: 0,
794                }))?;
795        }
796
797        Ok(())
798    }
799
800    fn process_terminate(&mut self) {
801        self.is_terminate = true;
802    }
803
804    fn process_health_check(&mut self) {
805        tracing::debug!("health check");
806        self.is_terminate = true;
807    }
808
809    async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
810        let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
811        record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
812        let session = self.session.clone().unwrap();
813        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
814        let type_ids = std::mem::take(&mut msg.type_ids);
815        // The inner_process_parse_msg can be slow. Release potential large FeParseMessage early.
816        drop(msg);
817        self.inner_process_parse_msg(session, sql, statement_name, type_ids)
818            .await?;
819        Ok(())
820    }
821
822    async fn inner_process_parse_msg(
823        &mut self,
824        session: Arc<SM::Session>,
825        sql: Arc<str>,
826        statement_name: String,
827        type_ids: Vec<i32>,
828    ) -> PsqlResult<()> {
829        if statement_name.is_empty() {
830            // Remove the unnamed prepare statement first, in case the unsupported sql binds a wrong
831            // prepare statement.
832            self.unnamed_prepare_statement.take();
833        } else if self.prepare_statement_store.contains_key(&statement_name) {
834            return Err(PsqlError::ExtendedPrepareError(
835                "Duplicated statement name".into(),
836            ));
837        }
838
839        let stmt = {
840            let stmts = Parser::parse_sql(&sql)
841                .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
842            drop(sql);
843            if stmts.len() > 1 {
844                return Err(PsqlError::ExtendedPrepareError(
845                    "Only one statement is allowed in extended query mode".into(),
846                ));
847            }
848
849            stmts.into_iter().next()
850        };
851
852        let param_types: Vec<Option<DataType>> = type_ids
853            .iter()
854            .map(|&id| {
855                // 0 means unspecified type
856                // ref: https://www.postgresql.org/docs/15/protocol-message-formats.html#:~:text=Placing%20a%20zero%20here%20is%20equivalent%20to%20leaving%20the%20type%20unspecified.
857                if id == 0 {
858                    Ok(None)
859                } else {
860                    DataType::from_oid(id)
861                        .map(Some)
862                        .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
863                }
864            })
865            .try_collect()?;
866
867        let prepare_statement = session
868            .parse(stmt, param_types)
869            .await
870            .map_err(PsqlError::ExtendedPrepareError)?;
871
872        if statement_name.is_empty() {
873            self.unnamed_prepare_statement.replace(prepare_statement);
874        } else {
875            self.prepare_statement_store
876                .insert(statement_name.clone(), prepare_statement);
877        }
878
879        self.statement_portal_dependency
880            .entry(statement_name)
881            .or_default()
882            .clear();
883
884        self.stream.write_no_flush(&BeMessage::ParseComplete)?;
885        Ok(())
886    }
887
888    fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
889        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
890        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
891        let session = self.session.clone().unwrap();
892
893        if self.portal_store.contains_key(&portal_name) {
894            return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
895        }
896
897        let prepare_statement = self.get_statement(&statement_name)?;
898
899        let result_formats = msg
900            .result_format_codes
901            .iter()
902            .map(|&format_code| Format::from_i16(format_code))
903            .try_collect()?;
904        let param_formats = msg
905            .param_format_codes
906            .iter()
907            .map(|&format_code| Format::from_i16(format_code))
908            .try_collect()?;
909
910        let portal = session
911            .bind(prepare_statement, msg.params, param_formats, result_formats)
912            .map_err(PsqlError::Uncategorized)?;
913
914        if portal_name.is_empty() {
915            self.result_cache.remove(&portal_name);
916            self.unnamed_portal.replace(portal);
917        } else {
918            assert!(
919                !self.result_cache.contains_key(&portal_name),
920                "Named portal never can be overridden."
921            );
922            self.portal_store.insert(portal_name.clone(), portal);
923        }
924
925        self.statement_portal_dependency
926            .get_mut(&statement_name)
927            .unwrap()
928            .push(portal_name);
929
930        self.stream.write_no_flush(&BeMessage::BindComplete)?;
931        Ok(())
932    }
933
934    async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
935        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
936        let row_max = msg.max_rows as usize;
937        drop(msg);
938        let session = self.session.clone().unwrap();
939
940        match self.result_cache.remove(&portal_name) {
941            Some(mut result_cache) => {
942                assert!(self.portal_store.contains_key(&portal_name));
943
944                let is_cosume_completed =
945                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
946
947                if !is_cosume_completed {
948                    self.result_cache.insert(portal_name, result_cache);
949                }
950            }
951            _ => {
952                let portal = self.get_portal(&portal_name)?;
953                let sql = format!("{}", portal);
954                let truncated_sql =
955                    record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
956                drop(sql);
957
958                session.check_idle_in_transaction_timeout()?;
959                // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
960                let _exec_context_guard = session.init_exec_context(truncated_sql.into());
961                let result = session.clone().execute(portal).await;
962
963                let pg_response = result.map_err(PsqlError::ExtendedExecuteError)?;
964                let mut result_cache = ResultCache::new(pg_response);
965                let is_consume_completed =
966                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
967                if !is_consume_completed {
968                    self.result_cache.insert(portal_name, result_cache);
969                }
970            }
971        }
972
973        Ok(())
974    }
975
976    fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
977        let name = cstr_to_str(&msg.name).unwrap().to_owned();
978        let session = self.session.clone().unwrap();
979        //  b'S' => Statement
980        //  b'P' => Portal
981
982        assert!(msg.kind == b'S' || msg.kind == b'P');
983        if msg.kind == b'S' {
984            let prepare_statement = self.get_statement(&name)?;
985
986            let (param_types, row_descriptions) = self
987                .session
988                .clone()
989                .unwrap()
990                .describe_statement(prepare_statement)
991                .map_err(PsqlError::Uncategorized)?;
992            self.stream
993                .write_no_flush(&BeMessage::ParameterDescription(
994                    &param_types.iter().map(|t| t.to_oid()).collect_vec(),
995                ))?;
996
997            if row_descriptions.is_empty() {
998                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
999                // return NoData message if the statement is not a query.
1000                self.stream.write_no_flush(&BeMessage::NoData)?;
1001            } else {
1002                self.stream
1003                    .write_no_flush(&BeMessage::RowDescription(&row_descriptions))?;
1004            }
1005        } else if msg.kind == b'P' {
1006            let portal = self.get_portal(&name)?;
1007
1008            let row_descriptions = session
1009                .describe_portal(portal)
1010                .map_err(PsqlError::Uncategorized)?;
1011
1012            if row_descriptions.is_empty() {
1013                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1014                // return NoData message if the statement is not a query.
1015                self.stream.write_no_flush(&BeMessage::NoData)?;
1016            } else {
1017                self.stream
1018                    .write_no_flush(&BeMessage::RowDescription(&row_descriptions))?;
1019            }
1020        }
1021        Ok(())
1022    }
1023
1024    fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1025        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1026        assert!(msg.kind == b'S' || msg.kind == b'P');
1027        if msg.kind == b'S' {
1028            if name.is_empty() {
1029                self.unnamed_prepare_statement = None;
1030            } else {
1031                self.prepare_statement_store.remove(&name);
1032            }
1033            for portal_name in self
1034                .statement_portal_dependency
1035                .remove(&name)
1036                .unwrap_or_default()
1037            {
1038                self.remove_portal(&portal_name);
1039            }
1040        } else if msg.kind == b'P' {
1041            self.remove_portal(&name);
1042        }
1043        self.stream.write_no_flush(&BeMessage::CloseComplete)?;
1044        Ok(())
1045    }
1046
1047    fn remove_portal(&mut self, portal_name: &str) {
1048        if portal_name.is_empty() {
1049            self.unnamed_portal = None;
1050        } else {
1051            self.portal_store.remove(portal_name);
1052        }
1053        self.result_cache.remove(portal_name);
1054    }
1055
1056    fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1057        if portal_name.is_empty() {
1058            Ok(self
1059                .unnamed_portal
1060                .as_ref()
1061                .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1062                .clone())
1063        } else {
1064            Ok(self
1065                .portal_store
1066                .get(portal_name)
1067                .ok_or_else(|| {
1068                    PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1069                })?
1070                .clone())
1071        }
1072    }
1073
1074    fn get_statement(
1075        &self,
1076        statement_name: &str,
1077    ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1078        if statement_name.is_empty() {
1079            Ok(self
1080                .unnamed_prepare_statement
1081                .as_ref()
1082                .ok_or_else(|| {
1083                    PsqlError::Uncategorized("unnamed prepare statement not found".into())
1084                })?
1085                .clone())
1086        } else {
1087            Ok(self
1088                .prepare_statement_store
1089                .get(statement_name)
1090                .ok_or_else(|| {
1091                    PsqlError::Uncategorized(
1092                        format!("Prepare statement {} not found", statement_name).into(),
1093                    )
1094                })?
1095                .clone())
1096        }
1097    }
1098}
1099
/// The transport currently underneath a `PgStream`, which starts out `Unencrypted`.
/// `upgrade_to_ssl` can later swap it for `Ssl`.
enum PgStreamInner<S> {
    /// Temporary value that occupies the slot while the unencrypted stream is moved out to be
    /// wrapped in SSL. Readers and writers treat it as `unreachable!()`.
    Placeholder,
    /// An unencrypted stream.
    Unencrypted(S),
    /// An SSL stream wrapping the original stream.
    Ssl(SslStream<S>),
}
1108
1109/// Trait for a byte stream that can be used for pg protocol.
1110pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1111impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1112
/// Wraps a byte stream and reads/writes pg messages on it.
///
/// Cloning a `PgStream` shares the same underlying stream but creates a fresh, independent
/// write buffer, so clones can write messages concurrently without interfering.
pub struct PgStream<S> {
    /// The underlying stream, behind an async mutex and shared by all clones.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// Outgoing messages are serialized here and only sent to the stream on flush.
    write_buf: BytesMut,
    /// Header returned by the most recent `read_header`. The next `read_body` or `skip_body`
    /// takes it.
    read_header: Option<FeMessageHeader>,
}
1124
1125impl<S> PgStream<S> {
1126    /// Create a new `PgStream` with the given stream and default write buffer capacity.
1127    pub fn new(stream: S) -> Self {
1128        const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1129
1130        Self {
1131            stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1132            write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1133            read_header: None,
1134        }
1135    }
1136}
1137
1138impl<S> Clone for PgStream<S> {
1139    fn clone(&self) -> Self {
1140        Self {
1141            stream: Arc::clone(&self.stream),
1142            write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1143            read_header: self.read_header.clone(),
1144        }
1145    }
1146}
1147
/// Session-specific values reported to the client through `ParameterStatus` messages.
///
/// Per the PostgreSQL protocol documentation (as of 9.2), the server generates
/// `ParameterStatus` for a hard-wired set of parameters:
///
///  * `server_version`
///  * `server_encoding`
///  * `client_encoding`
///  * `application_name`
///  * `is_superuser`
///  * `session_authorization`
///  * `DateStyle`
///  * `IntervalStyle`
///  * `TimeZone`
///  * `integer_datetimes`
///  * `standard_conforming_strings`
///
/// This struct carries only the values that vary per session; currently that is just
/// `application_name`.
///
/// See: <https://www.postgresql.org/docs/9.2/static/protocol-flow.html#PROTOCOL-ASYNC>.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// Value to report as the `application_name` parameter; nothing is reported when `None`.
    pub application_name: Option<String>,
}
1168
1169impl<S> PgStream<S>
1170where
1171    S: PgByteStream,
1172{
1173    async fn read_startup(&mut self) -> io::Result<FeMessage> {
1174        let mut stream = self.stream.lock().await;
1175        match &mut *stream {
1176            PgStreamInner::Placeholder => unreachable!(),
1177            PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1178            PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1179        }
1180    }
1181
1182    async fn read_header(&mut self) -> io::Result<()> {
1183        let mut stream = self.stream.lock().await;
1184        match &mut *stream {
1185            PgStreamInner::Placeholder => unreachable!(),
1186            PgStreamInner::Unencrypted(stream) => {
1187                self.read_header = Some(FeMessage::read_header(stream).await?);
1188                Ok(())
1189            }
1190            PgStreamInner::Ssl(ssl_stream) => {
1191                self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1192                Ok(())
1193            }
1194        }
1195    }
1196
1197    async fn read_body(&mut self) -> io::Result<FeMessage> {
1198        let mut stream = self.stream.lock().await;
1199        let header = self
1200            .read_header
1201            .take()
1202            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1203        match &mut *stream {
1204            PgStreamInner::Placeholder => unreachable!(),
1205            PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1206            PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1207        }
1208    }
1209
1210    async fn skip_body(&mut self) -> io::Result<()> {
1211        let mut stream = self.stream.lock().await;
1212        let header = self
1213            .read_header
1214            .take()
1215            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1216        match &mut *stream {
1217            PgStreamInner::Placeholder => unreachable!(),
1218            PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1219            PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1220        }
1221    }
1222
1223    fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1224        self.write_no_flush(&BeMessage::ParameterStatus(
1225            BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1226        ))?;
1227        self.write_no_flush(&BeMessage::ParameterStatus(
1228            BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1229        ))?;
1230        self.write_no_flush(&BeMessage::ParameterStatus(
1231            BeParameterStatusMessage::ServerVersion(PG_VERSION),
1232        ))?;
1233        if let Some(application_name) = &status.application_name {
1234            self.write_no_flush(&BeMessage::ParameterStatus(
1235                BeParameterStatusMessage::ApplicationName(application_name),
1236            ))?;
1237        }
1238        Ok(())
1239    }
1240
1241    pub fn write_no_flush(&mut self, message: &BeMessage<'_>) -> io::Result<()> {
1242        BeMessage::write(&mut self.write_buf, message)
1243    }
1244
1245    async fn write(&mut self, message: &BeMessage<'_>) -> io::Result<()> {
1246        self.write_no_flush(message)?;
1247        self.flush().await?;
1248        Ok(())
1249    }
1250
1251    async fn flush(&mut self) -> io::Result<()> {
1252        let mut stream = self.stream.lock().await;
1253        match &mut *stream {
1254            PgStreamInner::Placeholder => unreachable!(),
1255            PgStreamInner::Unencrypted(stream) => {
1256                stream.write_all(&self.write_buf).await?;
1257                stream.flush().await?;
1258            }
1259            PgStreamInner::Ssl(ssl_stream) => {
1260                ssl_stream.write_all(&self.write_buf).await?;
1261                ssl_stream.flush().await?;
1262            }
1263        }
1264        self.write_buf.clear();
1265        Ok(())
1266    }
1267}
1268
1269impl<S> PgStream<S>
1270where
1271    S: PgByteStream,
1272{
1273    /// Convert the underlying stream to ssl stream based on the given context.
1274    async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1275        let mut stream = self.stream.lock().await;
1276
1277        match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1278            PgStreamInner::Unencrypted(unencrypted_stream) => {
1279                let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1280                let mut ssl_stream =
1281                    tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1282
1283                if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1284                    tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1285                    let _ = ssl_stream.shutdown().await;
1286                    return Err(e.into());
1287                }
1288
1289                *stream = PgStreamInner::Ssl(ssl_stream);
1290            }
1291            PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1292            PgStreamInner::Placeholder => unreachable!(),
1293        }
1294
1295        Ok(())
1296    }
1297}
1298
1299fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1300    let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1301
1302    let key_path = &tls_config.key;
1303    let cert_path = &tls_config.cert;
1304
1305    // Build ssl acceptor according to the config.
1306    // Now we set every verify to true.
1307    acceptor
1308        .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1309        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1310    acceptor
1311        .set_ca_file(cert_path)
1312        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1313    acceptor
1314        .set_certificate_chain_file(cert_path)
1315        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1316    let acceptor = acceptor.build();
1317
1318    Ok(acceptor.into_context())
1319}
1320
pub mod truncated_fmt {
    use std::fmt::*;

    /// A [`Write`] adaptor with a byte budget. It forwards at most `remaining` bytes to `f` and
    /// keeps counting (but discards) everything after that.
    struct TruncatedFormatter<'a, 'b> {
        /// Bytes that may still be forwarded to `f`.
        remaining: usize,
        /// Total bytes produced by the formatted value, including any discarded tail.
        total: usize,
        /// Whether some output has been discarded.
        truncated: bool,
        f: &'a mut Formatter<'b>,
    }

    impl Write for TruncatedFormatter<'_, '_> {
        fn write_str(&mut self, s: &str) -> Result {
            self.total += s.len();
            if self.truncated {
                return Ok(());
            }

            if self.remaining < s.len() {
                // Cut on a char boundary so the forwarded prefix stays valid UTF-8.
                let actual = floor_char_boundary(s, self.remaining);
                self.f.write_str(&s[..actual])?;
                self.remaining -= actual;
                self.truncated = true;
            } else {
                self.f.write_str(s)?;
                self.remaining -= s.len();
            }
            Ok(())
        }
    }

    /// Returns the largest char boundary of `s` that is `<= index`, or `s.len()` when `index` is
    /// past the end. This is a stable equivalent of the nightly `str::floor_char_boundary`.
    fn floor_char_boundary(s: &str, index: usize) -> usize {
        if index >= s.len() {
            return s.len();
        }
        // Index 0 is always a char boundary, so the fallback is never actually used.
        (0..=index)
            .rev()
            .find(|&i| s.is_char_boundary(i))
            .unwrap_or(0)
    }

    /// Formats `args` into `f`, keeping at most `limit` bytes.
    ///
    /// If anything was cut, a single `...(truncated,N)` marker is appended, where `N` is the
    /// full byte length of the untruncated output. `N` is the same however many `write_str`
    /// calls the value's formatting used.
    fn write_truncated(f: &mut Formatter<'_>, limit: usize, args: Arguments<'_>) -> Result {
        let mut writer = TruncatedFormatter {
            remaining: limit,
            total: 0,
            truncated: false,
            f,
        };
        writer.write_fmt(args)?;
        if writer.truncated {
            write!(writer.f, "...(truncated,{})", writer.total)?;
        }
        Ok(())
    }

    /// Formats the wrapped value (`.0`), truncating the output to at most `.1` bytes.
    ///
    /// When the output is cut, `...(truncated,N)` is appended, where `N` is the full length of
    /// the formatted value in bytes.
    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);

    impl<T> Debug for TruncatedFmt<'_, T>
    where
        T: Debug,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            write_truncated(f, self.1, format_args!("{:?}", self.0))
        }
    }

    impl<T> Display for TruncatedFmt<'_, T>
    where
        T: Display,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            write_truncated(f, self.1, format_args!("{}", self.0))
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            assert_eq!(
                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
                "select '...(truncated,14)",
            );
        }

        /// Produces its output through several separate `write_str` calls.
        struct Chunked;

        impl std::fmt::Display for Chunked {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("abc")?;
                f.write_str("defgh")?;
                f.write_str("ij")
            }
        }

        #[test]
        fn test_trunc_reports_total_length() {
            assert_eq!(
                format!("{}", TruncatedFmt(&Chunked, 4)),
                "abcd...(truncated,10)"
            );
            assert_eq!(
                format!("{:?}", TruncatedFmt(&"hello", 3)),
                "\"he...(truncated,7)"
            );
        }

        #[test]
        fn test_no_trunc() {
            assert_eq!(format!("{}", TruncatedFmt(&"abcd", 4)), "abcd");
            assert_eq!(format!("{}", TruncatedFmt(&"", 0)), "");
        }
    }
}
1392
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Values of the listed option keys (`v2`, `v4`, and the encode option `b`) are replaced by
    /// `[REDACTED]` in the normalized SQL. Other options are kept verbatim.
    #[test]
    fn test_redact_parsable_sql() {
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let redacted_keys = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()]));
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = '[REDACTED]', v3 = false, v4 = '[REDACTED]') FORMAT PLAIN ENCODE JSON (a = '1', b = '[REDACTED]')";

        let redacted = redact_sql(sql, redacted_keys);
        assert_eq!(redacted, expected);
    }
}