1use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::error_or_notice::Severity;
46use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
47use crate::net::AddressRef;
48use crate::pg_extended::ResultCache;
49use crate::pg_message::{
50 BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
51 FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
52 FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
53};
54use crate::pg_server::{Session, SessionManager, UserAuthenticator};
55use crate::types::Format;
56
/// Truncation limit applied to SQL text written to query logs and tracing spans.
///
/// Read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable; falls back to `65536`
/// when the variable is unset, not valid Unicode, or not a valid `usize`.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok())
        .unwrap_or(65536)
});
64
tokio::task_local! {
    /// Session of the connection whose message is currently being handled.
    ///
    /// Stored as a type-erased weak reference, so the task-local never keeps the session alive
    /// on its own. `PgProtocol::do_process` sets it for the duration of one message (only when a
    /// session exists).
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
69
/// Postgres wire-protocol handler for a single client connection.
///
/// `S` is the underlying byte stream and `SM` the session manager that creates, ends and
/// cancels sessions.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Connection to the client (plain or TLS-upgraded).
    stream: PgStream<S>,
    /// Selects how `read_message` frames the next message (startup packet vs. tagged message).
    state: PgProtocolState,
    /// Set by Terminate, cancel and health-check handling; `process` then reports that the
    /// connection should be closed.
    is_terminate: bool,

    session_mgr: Arc<SM>,
    /// Session created by startup; `None` until `process_startup_msg` succeeds.
    session: Option<Arc<SM::Session>>,

    /// Partially consumed results of portals whose `Execute` stopped at the row limit,
    /// keyed by portal name (empty string = unnamed portal).
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    /// Statement created by a `Parse` with an empty statement name.
    unnamed_prepare_statement:
        Option<PreparedStatementData<<SM::Session as Session>::PreparedStatement>>,
    /// Named prepared statements.
    prepare_statement_store:
        HashMap<String, PreparedStatementData<<SM::Session as Session>::PreparedStatement>>,
    /// Portal created by a `Bind` with an empty portal name.
    unnamed_portal: Option<PortalData<<SM::Session as Session>::Portal>>,
    /// Named portals.
    portal_store: HashMap<String, PortalData<<SM::Session as Session>::Portal>>,
    /// Statement name -> names of the portals bound from it; closing a statement closes them.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    /// TLS context built from `tls_config` in `new`; `None` when TLS is not configured or the
    /// context could not be built. SSL requests are declined when it is `None`.
    tls_context: Option<SslContext>,

    /// TLS settings, also consulted at startup to enforce SSL.
    tls_config: Option<TlsConfig>,

    /// Set after an extended-query message fails; every message except `Sync` is then ignored
    /// until the next `Sync` clears it. (The name means "ignore until sync".)
    ignore_util_sync: bool,

    /// Client address, passed to the session manager on connect.
    peer_addr: AddressRef,

    /// Option keywords used to redact SQL before it is logged or recorded in spans.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts for the memory of in-flight messages and throttles oversized/excess ones.
    message_memory_manager: MessageMemoryManagerRef,
}
114
/// TLS settings for client connections.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Certificate, taken from `RW_SSL_CERT` (presumably a file path — confirm in
    /// `build_ssl_ctx_from_config`).
    pub cert: String,
    /// Private key, taken from `RW_SSL_KEY` (presumably a file path — confirm in
    /// `build_ssl_ctx_from_config`).
    pub key: String,
    /// When true, startup is rejected unless the connection has been upgraded to SSL.
    pub enforce_ssl: bool,
}
125
126impl TlsConfig {
127 pub fn new_default() -> anyhow::Result<Option<Self>> {
128 let cert = std::env::var("RW_SSL_CERT").ok();
129 let key = std::env::var("RW_SSL_KEY").ok();
130 let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
131
132 if cert.is_some() ^ key.is_some() {
133 return Err(anyhow::anyhow!(
134 "RW_SSL_CERT and RW_SSL_KEY must be set together"
135 ));
136 }
137
138 if enforce_ssl && cert.is_none() {
139 return Err(anyhow::anyhow!(
140 "RW_SSL_ENFORCE requires RW_SSL_CERT and RW_SSL_KEY to be set"
141 ));
142 }
143
144 let (Some(cert), Some(key)) = (cert, key) else {
145 return Ok(None);
146 };
147
148 tracing::info!(
149 "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
150 cert,
151 key,
152 enforce_ssl
153 );
154 Ok(Some(Self {
155 cert,
156 key,
157 enforce_ssl,
158 }))
159 }
160}
161
162impl<S, SM> Drop for PgProtocol<S, SM>
163where
164 SM: SessionManager,
165{
166 fn drop(&mut self) {
167 if let Some(session) = &self.session {
168 self.session_mgr.end_session(session);
170 }
171 }
172}
173
/// Connection phase; decides how `PgProtocol::read_message` frames the next message.
enum PgProtocolState {
    /// Before startup completes: messages are read with `read_startup` (untagged packets).
    Startup,
    /// After startup: messages are read as a tagged header followed by a body.
    Regular,
}
179
/// A prepared statement together with the SQL text it was parsed from.
#[derive(Clone)]
struct PreparedStatementData<S> {
    statement: S,
    /// Original SQL; copied into every portal bound from this statement.
    sql: Arc<str>,
}
185
/// A bound portal together with the SQL text of the statement it was bound from.
#[derive(Clone)]
struct PortalData<P> {
    portal: P,
    /// SQL of the source statement; recorded in the root span of `Execute` messages.
    sql: Arc<str>,
}
191
192pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
196 let without_null = if b.last() == Some(&0) {
197 &b[..b.len() - 1]
198 } else {
199 &b[..]
200 };
201 std::str::from_utf8(without_null)
202}
203
204fn get_redacted_and_truncated_sql(
205 sql: &str,
206 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
207) -> String {
208 let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
209 && !keywords.is_empty()
210 {
211 redact_sql(sql, keywords)
212 } else {
213 sql.to_owned()
214 };
215 let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
216 truncated.to_string()
217}
218
219fn record_sql_in_span(
221 sql: &str,
222 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
223 span: &mut tracing::Span,
224) {
225 let redacted_and_truncated_sql =
226 get_redacted_and_truncated_sql(sql, redact_sql_option_keywords);
227 span.record("sql", tracing::field::display(&redacted_and_truncated_sql));
228}
229
230fn record_user_in_span(user: &str, span: &mut tracing::Span) {
231 span.record("user", tracing::field::display(user));
232}
233
234fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
236 match Parser::parse_sql(sql) {
237 Ok(sqls) => sqls
238 .into_iter()
239 .map(|sql| sql.to_redacted_string(keywords.clone()))
240 .join(";"),
241 Err(_) => sql.to_owned(),
242 }
243}
244
/// Per-connection settings handed to [`PgProtocol::new`].
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS settings; `None` disables TLS.
    pub tls_config: Option<TlsConfig>,
    /// Option keywords used to redact SQL in logs and spans; `None` disables redaction.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Shared accounting/throttling for in-flight message memory.
    pub message_memory_manager: MessageMemoryManagerRef,
}
251
252impl<S, SM> PgProtocol<S, SM>
253where
254 S: PgByteStream,
255 SM: SessionManager,
256{
257 pub fn new(
258 stream: S,
259 session_mgr: Arc<SM>,
260 peer_addr: AddressRef,
261 context: ConnectionContext,
262 ) -> Self {
263 let ConnectionContext {
264 tls_config,
265 redact_sql_option_keywords,
266 message_memory_manager,
267 } = context;
268 Self {
269 stream: PgStream::new(stream),
270 is_terminate: false,
271 state: PgProtocolState::Startup,
272 session_mgr,
273 session: None,
274 tls_context: tls_config
275 .as_ref()
276 .and_then(|e| build_ssl_ctx_from_config(e).ok()),
277 tls_config,
278 result_cache: Default::default(),
279 unnamed_prepare_statement: Default::default(),
280 prepare_statement_store: Default::default(),
281 unnamed_portal: Default::default(),
282 portal_store: Default::default(),
283 statement_portal_dependency: Default::default(),
284 ignore_util_sync: false,
285 peer_addr,
286 redact_sql_option_keywords,
287 message_memory_manager,
288 }
289 }
290
    /// Drives the connection: reads and processes frontend messages one at a time until a
    /// message requests termination or a read fails.
    ///
    /// Once a session exists, a notice-forwarding future is created. It writes the session's
    /// notices to the client and is polled concurrently with each read-and-process step.
    pub async fn run(&mut self) {
        // Created lazily after startup establishes a session, then reused across iterations.
        let mut notice_fut = None;

        loop {
            if notice_fut.is_none()
                && let Some(session) = self.session.clone()
            {
                // A clone of the stream, so notices can be written while the main future also
                // uses it (PgStream holds its transport in `Arc<Mutex<..>>`; presumably the
                // clone shares it — confirm in `PgStream`'s `Clone`).
                let mut stream = self.stream.clone();
                notice_fut = Some(Box::pin(async move {
                    // Never completes: notices are forwarded for as long as it is polled.
                    loop {
                        let notice = session.next_notice().await;
                        if let Err(e) = stream.write(BeMessage::NoticeResponse(&notice)).await {
                            tracing::error!(error = %e.as_report(), notice, "failed to send notice");
                        }
                    }
                }));
            }

            let process = std::pin::pin!(async {
                // The memory guard stays alive until the message has been fully processed.
                let (msg, _memory_guard) = match self.read_message().await {
                    Ok(msg) => msg,
                    Err(e) => {
                        tracing::error!(error = %e.as_report(), "error when reading message");
                        // A failed read terminates the connection.
                        return true;
                    }
                };
                tracing::trace!(?msg, "received message");
                self.process(msg).await
            });

            let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
                tokio::select! {
                    // The notice loop above never returns.
                    _ = notice_fut => unreachable!(),
                    terminated = process => terminated,
                }
            } else {
                process.await
            };

            if terminated {
                break;
            }
        }
    }
338
339 pub async fn process(&mut self, msg: FeMessage) -> bool {
341 self.do_process(msg).await.is_none() || self.is_terminate
342 }
343
344 fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
352 let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
353 return tracing::Span::none();
354 };
355
356 let mode = match msg {
357 FeMessage::Query(_) => "simple query",
358 FeMessage::Parse(_) => "extended query parse",
359 FeMessage::Execute(_) => "extended query execute",
360 _ => return tracing::Span::none(),
361 };
362
363 let mut span = tracing::info_span!(
364 target: PGWIRE_ROOT_SPAN_TARGET,
365 "handle_query",
366 mode,
367 session_id,
368 sql = tracing::field::Empty,
369 user = tracing::field::Empty,
370 );
371 match msg {
372 FeMessage::Execute(execute_msg) => {
373 if let Ok(portal_name) = cstr_to_str(&execute_msg.portal_name)
374 && let Ok(sql) = self.get_portal_sql(portal_name)
375 {
376 record_sql_in_span(&sql, self.redact_sql_option_keywords.clone(), &mut span);
377 }
378 }
379 _ => {
380 if let Ok(sql) = msg.get_sql()
381 && let Some(sql) = sql
382 {
383 record_sql_in_span(sql, self.redact_sql_option_keywords.clone(), &mut span);
384 }
385 }
386 }
387 if let Some(current_session) = self.session.as_ref() {
388 record_user_in_span(¤t_session.user(), &mut span);
389 }
390 span
391 }
392
    /// Processes one frontend message with per-message instrumentation, then turns any error
    /// into the matching client response.
    ///
    /// `do_process_inner` is wrapped, innermost first, in:
    /// 1. a `CURRENT_SESSION` task-local scope (only when a session exists);
    /// 2. panic catching, which turns a panic into `PsqlError::Panic`;
    /// 3. a "slow query" log emitted every `SLOW_QUERY_LOG_PERIOD` while still running;
    /// 4. query logging (`started`, then `ok`/`err` with elapsed time) and error logging;
    /// 5. the root span from `root_span_for_msg`.
    ///
    /// Returns `None` when the connection should be closed, `Some(())` to keep reading.
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        // Built before `msg` is moved into the handler.
        let span = self.root_span_for_msg(&msg);
        // Weak and type-erased, so the task-local does not keep the session alive.
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        // Boxed, presumably to keep the size of the surrounding future small.
        let fut = Box::pin(self.do_process_inner(msg));

        // Expose the current session to code running inside the handler.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Convert panics into an error so the connection can report it instead of crashing.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Log "slow query" each time another full `period` passes without completion.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            let mut elapsed = Duration::ZERO;

            loop {
                // Timing out only stops waiting; the same future is polled again next round.
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Query log entries are only emitted inside a real (non-`none`) root span.
        let fut = async move {
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = "started",
                );
            }

            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            if let Err(error) = &result {
                // Local debug builds log the report with `Debug` formatting (presumably more
                // detailed); release builds and CI use `Display`.
                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
                    tracing::error!(error = ?error.as_report(), "error when process message");
                } else {
                    tracing::error!(error = %error.as_report(), "error when process message");
                }
            }

            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        let fut = fut.instrument(span);

        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    // The client went away: close. Other I/O errors fall through to the
                    // final flush and the connection keeps reading.
                    PsqlError::IoError(io_err) => {
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                    }

                    PsqlError::SslError(_) => {
                        return None;
                    }

                    // Startup/authentication failures are fatal: report and close.
                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: false,
                                severity: Some(Severity::Fatal),
                            })
                            .ok()?;
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    // Simple-query errors end the request with ReadyForQuery.
                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                                severity: None,
                            })
                            .ok()?;
                        self.ready_for_query().ok()?;
                    }

                    // Report, then close the connection.
                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                                severity: None,
                            })
                            .ok()?;
                        let _ = self.stream.flush().await;

                        return None;
                    }

                    // Extended-query errors: only the error is sent here; ReadyForQuery
                    // follows when the client sends Sync.
                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                                severity: None,
                            })
                            .ok()?;
                    }
                }
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
573
574 async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
575 if self.ignore_util_sync {
577 if let FeMessage::Sync = msg {
578 } else {
579 tracing::trace!("ignore message {:?} until sync.", msg);
580 return Ok(());
581 }
582 }
583
584 match msg {
585 FeMessage::Gss => self.process_gss_msg().await?,
586 FeMessage::Ssl => self.process_ssl_msg().await?,
587 FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
588 FeMessage::Password(msg) => self.process_password_msg(msg).await?,
589 FeMessage::Query(query_msg) => {
590 let sql = Arc::from(query_msg.get_sql()?);
591 drop(query_msg);
593 self.process_query_msg(sql).await?
594 }
595 FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
596 FeMessage::Terminate => self.process_terminate(),
597 FeMessage::Parse(m) => {
598 if let Err(err) = self.process_parse_msg(m).await {
599 self.ignore_util_sync = true;
600 return Err(err);
601 }
602 }
603 FeMessage::Bind(m) => {
604 if let Err(err) = self.process_bind_msg(m) {
605 self.ignore_util_sync = true;
606 return Err(err);
607 }
608 }
609 FeMessage::Execute(m) => {
610 if let Err(err) = self.process_execute_msg(m).await {
611 self.ignore_util_sync = true;
612 return Err(err);
613 }
614 }
615 FeMessage::Describe(m) => {
616 if let Err(err) = self.process_describe_msg(m) {
617 self.ignore_util_sync = true;
618 return Err(err);
619 }
620 }
621 FeMessage::Sync => {
622 self.ignore_util_sync = false;
623 self.ready_for_query()?
624 }
625 FeMessage::Close(m) => {
626 if let Err(err) = self.process_close_msg(m) {
627 self.ignore_util_sync = true;
628 return Err(err);
629 }
630 }
631 FeMessage::Flush => {
632 if let Err(err) = self.stream.flush().await {
633 self.ignore_util_sync = true;
634 return Err(err.into());
635 }
636 }
637 FeMessage::HealthCheck => self.process_health_check(),
638 FeMessage::ServerThrottle(reason) => match reason {
639 ServerThrottleReason::TooLargeMessage => {
640 return Err(PsqlError::ServerThrottle(format!(
641 "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
642 self.message_memory_manager.max_filter_bytes
643 )));
644 }
645 ServerThrottleReason::TooManyMemoryUsage => {
646 return Err(PsqlError::ServerThrottle(format!(
647 "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
648 self.message_memory_manager.max_running_bytes
649 )));
650 }
651 },
652 }
653 self.stream.flush().await?;
654 Ok(())
655 }
656
657 pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
658 match self.state {
659 PgProtocolState::Startup => self
660 .stream
661 .read_startup()
662 .await
663 .map(|message: FeMessage| (message, None)),
664 PgProtocolState::Regular => {
665 self.stream.read_header().await?;
666 let guard = if let Some(ref header) = self.stream.read_header {
667 let payload_len = std::cmp::max(header.payload_len, 0) as u64;
668 let (reason, guard) = self.message_memory_manager.add(payload_len);
669 if let Some(reason) = reason {
670 drop(guard);
672 self.stream.skip_body().await?;
673 return Ok((FeMessage::ServerThrottle(reason), None));
674 }
675 guard
676 } else {
677 None
678 };
679 let message = self.stream.read_body().await?;
680 Ok((message, guard))
681 }
682 }
683 }
684
685 fn ready_for_query(&mut self) -> io::Result<()> {
687 self.stream.write_no_flush(BeMessage::ReadyForQuery(
688 self.session
689 .as_ref()
690 .map(|s| s.transaction_status())
691 .unwrap_or(TransactionStatus::Idle),
692 ))
693 }
694
695 async fn process_gss_msg(&mut self) -> PsqlResult<()> {
696 self.stream.write(BeMessage::EncryptionResponseNo).await?;
698 Ok(())
699 }
700
701 async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
702 if let Some(context) = self.tls_context.as_ref() {
703 self.stream.write(BeMessage::EncryptionResponseSsl).await?;
706 self.stream.upgrade_to_ssl(context).await?;
707 } else {
708 self.stream.write(BeMessage::EncryptionResponseNo).await?;
710 }
711
712 Ok(())
713 }
714
715 async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
716 if let Some(ref tls_config) = self.tls_config
718 && tls_config.enforce_ssl
719 && !self.stream.is_ssl_connection().await
720 {
721 return Err(PsqlError::StartupError(
722 "SSL connection is required but not established".into(),
723 ));
724 }
725
726 let db_name = msg
727 .config
728 .get("database")
729 .cloned()
730 .unwrap_or_else(|| "dev".to_owned());
731 let user_name = msg
732 .config
733 .get("user")
734 .cloned()
735 .unwrap_or_else(|| "root".to_owned());
736
737 let session = self
738 .session_mgr
739 .connect(&db_name, &user_name, self.peer_addr.clone())
740 .map_err(|e| PsqlError::StartupError(e.into()))?;
741
742 if let Some(options) = msg.config.get("options") {
743 for (key, value) in parse_options(options)? {
744 session
745 .set_config(&key, value)
746 .map_err(|e| PsqlError::StartupError(e.into()))?;
747 }
748 }
749 let application_name = msg.config.get("application_name");
751 if let Some(application_name) = application_name {
752 session
753 .set_config("application_name", application_name.clone())
754 .map_err(|e| PsqlError::StartupError(e.into()))?;
755 }
756
757 match session.user_authenticator() {
758 UserAuthenticator::None => {
759 self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
760
761 self.stream
764 .write_no_flush(BeMessage::BackendKeyData(session.id()))?;
765
766 self.stream.write_no_flush(BeMessage::ParameterStatus(
767 BeParameterStatusMessage::TimeZone(
768 &session
769 .get_config("timezone")
770 .map_err(|e| PsqlError::StartupError(e.into()))?,
771 ),
772 ))?;
773 self.stream
774 .write_parameter_status_msg_no_flush(&ParameterStatus {
775 application_name: application_name.cloned(),
776 })?;
777 self.ready_for_query()?;
778 }
779 UserAuthenticator::ClearText(_)
780 | UserAuthenticator::OAuth { .. }
781 | UserAuthenticator::Ldap(..) => {
782 self.stream
783 .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
784 }
785 UserAuthenticator::Md5WithSalt { salt, .. } => {
786 self.stream
787 .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
788 }
789 }
790
791 self.session = Some(session);
792 self.state = PgProtocolState::Regular;
793 Ok(())
794 }
795
796 async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
797 let session = self.session.as_ref().unwrap();
798 let authenticator = session.user_authenticator();
799 authenticator.authenticate(&msg.password).await?;
800 self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
801 let timezone = session
802 .get_config("timezone")
803 .map_err(|e| PsqlError::StartupError(e.into()))?;
804 self.stream.write_no_flush(BeMessage::ParameterStatus(
805 BeParameterStatusMessage::TimeZone(&timezone),
806 ))?;
807 self.stream
808 .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
809 self.ready_for_query()?;
810 self.state = PgProtocolState::Regular;
811 Ok(())
812 }
813
814 fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
815 let session_id = (m.target_process_id, m.target_secret_key);
816 tracing::trace!("cancel query in session: {:?}", session_id);
817 self.session_mgr.cancel_queries_in_session(session_id);
818 self.session_mgr.cancel_creating_jobs_in_session(session_id);
819 self.is_terminate = true;
820 Ok(())
821 }
822
823 async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
824 let truncated_sql =
825 get_redacted_and_truncated_sql(&sql, self.redact_sql_option_keywords.clone());
826 let session = self.session.clone().unwrap();
827
828 session.check_idle_in_transaction_timeout()?;
829 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
831 self.inner_process_query_msg(sql, session.clone()).await
832 }
833
834 async fn inner_process_query_msg(
835 &mut self,
836 sql: Arc<str>,
837 session: Arc<SM::Session>,
838 ) -> PsqlResult<()> {
839 let stmts =
841 Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
842 drop(sql);
844 if stmts.is_empty() {
845 self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
846 }
847
848 for stmt in stmts {
850 self.inner_process_query_msg_one_stmt(stmt, session.clone())
851 .await?;
852 }
853 self.ready_for_query()?;
856 Ok(())
857 }
858
    /// Runs one parsed statement of a simple query and buffers its complete response.
    ///
    /// Pending session notices are written even if the statement failed. The rest depends on
    /// the result:
    /// - `COPY ... TO STDOUT`: `CopyOutResponse`, one `CopyData` per row, then `CopyDone`;
    /// - row-returning queries: `RowDescription` followed by one `DataRow` per row;
    /// - DML without `RETURNING`: the affected-row count is read from the first value of the
    ///   first returned row;
    /// - anything else: no rows.
    ///
    /// Every successful path calls `run_callback` after consuming the result, and then writes
    /// `CommandComplete`. `ReadyForQuery` is not written here.
    async fn inner_process_query_msg_one_stmt(
        &mut self,
        stmt: Statement,
        session: Arc<SM::Session>,
    ) -> PsqlResult<()> {
        let session = session.clone();

        let res = session.clone().run_one_query(stmt, Format::Text).await;

        // Forward notices that are already queued on the session (`now_or_never` does not
        // wait for new ones), before looking at the query result.
        while let Some(notice) = session.next_notice().now_or_never() {
            self.stream
                .write_no_flush(BeMessage::NoticeResponse(&notice))?;
        }

        let mut res = res.map_err(|e| PsqlError::SimpleQueryError(e.into()))?;

        // Notices attached to the result itself.
        for notice in res.notices() {
            self.stream
                .write_no_flush(BeMessage::NoticeResponse(notice))?;
        }

        // Report `application_name` to the client when the result carries one.
        let status = res.status();
        if let Some(ref application_name) = status.application_name {
            self.stream.write_no_flush(BeMessage::ParameterStatus(
                BeParameterStatusMessage::ApplicationName(application_name),
            ))?;
        }

        if res.is_copy_query_to_stdout() {
            self.stream
                .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
            let mut count = 0;
            while let Some(row_set) = res.values_stream().next().await {
                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
                for row in row_set {
                    self.stream.write_no_flush(BeMessage::CopyData(&row))?;
                    count += 1;
                }
            }

            self.stream.write_no_flush(BeMessage::CopyDone)?;

            res.run_callback().await?;

            self.stream
                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
                    stmt_type: res.stmt_type(),
                    rows_cnt: count,
                }))?;
        } else if res.is_query() {
            self.stream
                .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;

            let mut rows_cnt = 0;

            while let Some(row_set) = res.values_stream().next().await {
                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
                for row in row_set {
                    self.stream.write_no_flush(BeMessage::DataRow(&row))?;
                    rows_cnt += 1;
                }
            }

            res.run_callback().await?;

            self.stream
                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
                    stmt_type: res.stmt_type(),
                    rows_cnt,
                }))?;
        } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
            // The affected-row count arrives as a text value in the first column of the
            // first row.
            let first_row_set = res.values_stream().next().await;
            let first_row_set = match first_row_set {
                None => {
                    return Err(PsqlError::Uncategorized(
                        anyhow::anyhow!("no affected rows in output").into(),
                    ));
                }
                Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
            };
            // NOTE(review): the indexing below panics on an empty row set, and the `unwrap`
            // panics on non-UTF-8 data. Both are treated as invariants.
            let affected_rows_str = first_row_set[0].values()[0]
                .as_ref()
                .expect("compute node should return affected rows in output");

            assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
            // A count that does not parse falls back to the default (0).
            let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
                .unwrap()
                .parse()
                .unwrap_or_default();

            res.run_callback().await?;

            self.stream
                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
                    stmt_type: res.stmt_type(),
                    rows_cnt: affected_rows_cnt,
                }))?;
        } else {
            res.run_callback().await?;

            self.stream
                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
                    stmt_type: res.stmt_type(),
                    rows_cnt: 0,
                }))?;
        }

        Ok(())
    }
974
975 fn process_terminate(&mut self) {
976 self.is_terminate = true;
977 }
978
979 fn process_health_check(&mut self) {
980 tracing::debug!("health check");
981 self.is_terminate = true;
982 }
983
984 async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
985 let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
986 let session = self.session.clone().unwrap();
987 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
988 let type_ids = std::mem::take(&mut msg.type_ids);
989 drop(msg);
991 self.inner_process_parse_msg(session, sql, statement_name, type_ids)
992 .await?;
993 Ok(())
994 }
995
996 async fn inner_process_parse_msg(
997 &mut self,
998 session: Arc<SM::Session>,
999 sql: Arc<str>,
1000 statement_name: String,
1001 type_ids: Vec<i32>,
1002 ) -> PsqlResult<()> {
1003 if statement_name.is_empty() {
1004 self.unnamed_prepare_statement.take();
1007 } else if self.prepare_statement_store.contains_key(&statement_name) {
1008 return Err(PsqlError::ExtendedPrepareError(
1009 "Duplicated statement name".into(),
1010 ));
1011 }
1012
1013 let stmt = {
1014 let stmts = Parser::parse_sql(&sql)
1015 .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
1016 if stmts.len() > 1 {
1017 return Err(PsqlError::ExtendedPrepareError(
1018 "Only one statement is allowed in extended query mode".into(),
1019 ));
1020 }
1021
1022 stmts.into_iter().next()
1023 };
1024
1025 let param_types: Vec<Option<DataType>> = type_ids
1026 .iter()
1027 .map(|&id| {
1028 if id == 0 {
1031 Ok(None)
1032 } else {
1033 DataType::from_oid(id)
1034 .map(Some)
1035 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
1036 }
1037 })
1038 .try_collect()?;
1039
1040 let prepare_statement = session
1041 .parse(stmt, param_types)
1042 .await
1043 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))?;
1044 let prepare_statement = PreparedStatementData {
1045 statement: prepare_statement,
1046 sql,
1047 };
1048
1049 if statement_name.is_empty() {
1050 self.unnamed_prepare_statement.replace(prepare_statement);
1051 } else {
1052 self.prepare_statement_store
1053 .insert(statement_name.clone(), prepare_statement);
1054 }
1055
1056 self.statement_portal_dependency
1057 .entry(statement_name)
1058 .or_default()
1059 .clear();
1060
1061 self.stream.write_no_flush(BeMessage::ParseComplete)?;
1062 Ok(())
1063 }
1064
1065 fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
1066 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
1067 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1068 let session = self.session.clone().unwrap();
1069
1070 if self.portal_store.contains_key(&portal_name) {
1071 return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
1072 }
1073
1074 let prepare_statement = self.get_statement_data(&statement_name)?.clone();
1075
1076 let result_formats = msg
1077 .result_format_codes
1078 .iter()
1079 .map(|&format_code| Format::from_i16(format_code))
1080 .try_collect()?;
1081 let param_formats = msg
1082 .param_format_codes
1083 .iter()
1084 .map(|&format_code| Format::from_i16(format_code))
1085 .try_collect()?;
1086
1087 let portal = session
1088 .bind(
1089 prepare_statement.statement,
1090 msg.params,
1091 param_formats,
1092 result_formats,
1093 )
1094 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1095 let portal = PortalData {
1096 portal,
1097 sql: prepare_statement.sql,
1098 };
1099
1100 if portal_name.is_empty() {
1101 self.result_cache.remove(&portal_name);
1102 self.unnamed_portal.replace(portal);
1103 } else {
1104 assert!(
1105 !self.result_cache.contains_key(&portal_name),
1106 "Named portal never can be overridden."
1107 );
1108 self.portal_store.insert(portal_name.clone(), portal);
1109 }
1110
1111 self.statement_portal_dependency
1112 .get_mut(&statement_name)
1113 .unwrap()
1114 .push(portal_name);
1115
1116 self.stream.write_no_flush(BeMessage::BindComplete)?;
1117 Ok(())
1118 }
1119
1120 async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1121 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1122 let row_max = msg.max_rows as usize;
1123 drop(msg);
1124 let session = self.session.clone().unwrap();
1125
1126 match self.result_cache.remove(&portal_name) {
1127 Some(mut result_cache) => {
1128 assert!(self.portal_store.contains_key(&portal_name));
1129
1130 let is_consume_completed =
1131 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1132
1133 if !is_consume_completed {
1134 self.result_cache.insert(portal_name, result_cache);
1135 }
1136 }
1137 _ => {
1138 let portal = self.get_portal_data(&portal_name)?.clone();
1139 let sql = format!("{}", portal.portal);
1140 let truncated_sql =
1141 get_redacted_and_truncated_sql(&sql, self.redact_sql_option_keywords.clone());
1142 drop(sql);
1143
1144 session.check_idle_in_transaction_timeout()?;
1145 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1147 let result = session.clone().execute(portal.portal).await;
1148
1149 let pg_response = result.map_err(|e| PsqlError::ExtendedExecuteError(e.into()))?;
1150 let mut result_cache = ResultCache::new(pg_response);
1151 let is_consume_completed =
1152 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1153 if !is_consume_completed {
1154 self.result_cache.insert(portal_name, result_cache);
1155 }
1156 }
1157 }
1158
1159 Ok(())
1160 }
1161
1162 fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1163 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1164 let session = self.session.clone().unwrap();
1165 assert!(msg.kind == b'S' || msg.kind == b'P');
1169 if msg.kind == b'S' {
1170 let prepare_statement = self.get_statement(&name)?;
1171
1172 let (param_types, row_descriptions) = self
1173 .session
1174 .clone()
1175 .unwrap()
1176 .describe_statement(prepare_statement)
1177 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1178 self.stream.write_no_flush(BeMessage::ParameterDescription(
1179 ¶m_types.iter().map(|t| t.to_oid()).collect_vec(),
1180 ))?;
1181
1182 if row_descriptions.is_empty() {
1183 self.stream.write_no_flush(BeMessage::NoData)?;
1186 } else {
1187 self.stream
1188 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1189 }
1190 } else if msg.kind == b'P' {
1191 let portal = self.get_portal(&name)?;
1192
1193 let row_descriptions = session
1194 .describe_portal(portal)
1195 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1196
1197 if row_descriptions.is_empty() {
1198 self.stream.write_no_flush(BeMessage::NoData)?;
1201 } else {
1202 self.stream
1203 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1204 }
1205 }
1206 Ok(())
1207 }
1208
1209 fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1210 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1211 assert!(msg.kind == b'S' || msg.kind == b'P');
1212 if msg.kind == b'S' {
1213 if name.is_empty() {
1214 self.unnamed_prepare_statement = None;
1215 } else {
1216 self.prepare_statement_store.remove(&name);
1217 }
1218 for portal_name in self
1219 .statement_portal_dependency
1220 .remove(&name)
1221 .unwrap_or_default()
1222 {
1223 self.remove_portal(&portal_name);
1224 }
1225 } else if msg.kind == b'P' {
1226 self.remove_portal(&name);
1227 }
1228 self.stream.write_no_flush(BeMessage::CloseComplete)?;
1229 Ok(())
1230 }
1231
1232 fn remove_portal(&mut self, portal_name: &str) {
1233 if portal_name.is_empty() {
1234 self.unnamed_portal = None;
1235 } else {
1236 self.portal_store.remove(portal_name);
1237 }
1238 self.result_cache.remove(portal_name);
1239 }
1240
1241 fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1242 Ok(self.get_portal_data(portal_name)?.portal.clone())
1243 }
1244
1245 fn get_portal_data(
1246 &self,
1247 portal_name: &str,
1248 ) -> PsqlResult<&PortalData<<SM::Session as Session>::Portal>> {
1249 if portal_name.is_empty() {
1250 self.unnamed_portal
1251 .as_ref()
1252 .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))
1253 } else {
1254 self.portal_store.get(portal_name).ok_or_else(|| {
1255 PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1256 })
1257 }
1258 }
1259
1260 fn get_statement(
1261 &self,
1262 statement_name: &str,
1263 ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1264 Ok(self.get_statement_data(statement_name)?.statement.clone())
1265 }
1266
1267 fn get_statement_data(
1268 &self,
1269 statement_name: &str,
1270 ) -> PsqlResult<&PreparedStatementData<<SM::Session as Session>::PreparedStatement>> {
1271 if statement_name.is_empty() {
1272 self.unnamed_prepare_statement.as_ref().ok_or_else(|| {
1273 PsqlError::Uncategorized("unnamed prepare statement not found".into())
1274 })
1275 } else {
1276 self.prepare_statement_store
1277 .get(statement_name)
1278 .ok_or_else(|| {
1279 PsqlError::Uncategorized(
1280 format!("Prepare statement {} not found", statement_name).into(),
1281 )
1282 })
1283 }
1284 }
1285
1286 fn get_portal_sql(&self, portal_name: &str) -> PsqlResult<Arc<str>> {
1287 Ok(self.get_portal_data(portal_name)?.sql.clone())
1288 }
1289}
1290
/// The connection's transport, which can be upgraded in place from plaintext to TLS
/// (see `upgrade_to_ssl`).
enum PgStreamInner<S> {
    /// Transient value swapped in with `std::mem::replace` while the plaintext stream is being
    /// moved into an `SslStream`. It is also left behind if the TLS handshake fails. All
    /// reader/writer methods treat it as `unreachable!()`.
    Placeholder,
    /// A plaintext connection.
    Unencrypted(S),
    /// A connection wrapped in an OpenSSL TLS session.
    Ssl(SslStream<S>),
}
1299
/// Bound alias for the raw transport a `PgStream` runs over: an owned, `Send`, `Unpin` async
/// reader/writer.
pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
// Blanket impl: every type satisfying the bounds is automatically a `PgByteStream`.
impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1303
/// A frontend connection: the (possibly TLS-upgraded) transport plus per-handle read/write state.
pub struct PgStream<S> {
    /// Shared transport. It is behind `Arc<Mutex<..>>` so that clones of this `PgStream` (see the
    /// `Clone` impl) operate on the same connection.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// Outgoing messages are encoded here by `write_no_flush` and sent by `flush`.
    write_buf: BytesMut,
    /// Header of the message currently being read: set by `read_header`, consumed by
    /// `read_body` / `skip_body`.
    read_header: Option<FeMessageHeader>,
}
1315
1316impl<S> PgStream<S> {
1317 pub fn new(stream: S) -> Self {
1319 const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1320
1321 Self {
1322 stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1323 write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1324 read_header: None,
1325 }
1326 }
1327
1328 async fn is_ssl_connection(&self) -> bool {
1330 let stream = self.stream.lock().await;
1331 matches!(*stream, PgStreamInner::Ssl(_))
1332 }
1333}
1334
1335impl<S> Clone for PgStream<S> {
1336 fn clone(&self) -> Self {
1337 Self {
1338 stream: Arc::clone(&self.stream),
1339 write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1340 read_header: self.read_header.clone(),
1341 }
1342 }
1343}
1344
/// Per-session parameters reported to the client via `ParameterStatus` messages, in addition
/// to the fixed server-wide ones.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// When set, reported as the `ApplicationName` parameter status.
    pub application_name: Option<String>,
}
1365
1366impl<S> PgStream<S>
1367where
1368 S: PgByteStream,
1369{
1370 async fn read_startup(&mut self) -> io::Result<FeMessage> {
1371 let mut stream = self.stream.lock().await;
1372 match &mut *stream {
1373 PgStreamInner::Placeholder => unreachable!(),
1374 PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1375 PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1376 }
1377 }
1378
1379 async fn read_header(&mut self) -> io::Result<()> {
1380 let mut stream = self.stream.lock().await;
1381 match &mut *stream {
1382 PgStreamInner::Placeholder => unreachable!(),
1383 PgStreamInner::Unencrypted(stream) => {
1384 self.read_header = Some(FeMessage::read_header(stream).await?);
1385 Ok(())
1386 }
1387 PgStreamInner::Ssl(ssl_stream) => {
1388 self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1389 Ok(())
1390 }
1391 }
1392 }
1393
1394 async fn read_body(&mut self) -> io::Result<FeMessage> {
1395 let mut stream = self.stream.lock().await;
1396 let header = self
1397 .read_header
1398 .take()
1399 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1400 match &mut *stream {
1401 PgStreamInner::Placeholder => unreachable!(),
1402 PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1403 PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1404 }
1405 }
1406
1407 async fn skip_body(&mut self) -> io::Result<()> {
1408 let mut stream = self.stream.lock().await;
1409 let header = self
1410 .read_header
1411 .take()
1412 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1413 match &mut *stream {
1414 PgStreamInner::Placeholder => unreachable!(),
1415 PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1416 PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1417 }
1418 }
1419
1420 fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1421 self.write_no_flush(BeMessage::ParameterStatus(
1422 BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1423 ))?;
1424 self.write_no_flush(BeMessage::ParameterStatus(
1425 BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1426 ))?;
1427 self.write_no_flush(BeMessage::ParameterStatus(
1428 BeParameterStatusMessage::ServerVersion(PG_VERSION),
1429 ))?;
1430 if let Some(application_name) = &status.application_name {
1431 self.write_no_flush(BeMessage::ParameterStatus(
1432 BeParameterStatusMessage::ApplicationName(application_name),
1433 ))?;
1434 }
1435 Ok(())
1436 }
1437
1438 pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1439 BeMessage::write(&mut self.write_buf, message)
1440 }
1441
1442 async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1443 self.write_no_flush(message)?;
1444 self.flush().await?;
1445 Ok(())
1446 }
1447
1448 async fn flush(&mut self) -> io::Result<()> {
1449 let mut stream = self.stream.lock().await;
1450 match &mut *stream {
1451 PgStreamInner::Placeholder => unreachable!(),
1452 PgStreamInner::Unencrypted(stream) => {
1453 stream.write_all(&self.write_buf).await?;
1454 stream.flush().await?;
1455 }
1456 PgStreamInner::Ssl(ssl_stream) => {
1457 ssl_stream.write_all(&self.write_buf).await?;
1458 ssl_stream.flush().await?;
1459 }
1460 }
1461 self.write_buf.clear();
1462 Ok(())
1463 }
1464}
1465
1466impl<S> PgStream<S>
1467where
1468 S: PgByteStream,
1469{
1470 async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1472 let mut stream = self.stream.lock().await;
1473
1474 match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1475 PgStreamInner::Unencrypted(unencrypted_stream) => {
1476 let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1477 let mut ssl_stream =
1478 tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1479
1480 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1481 tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1482 let _ = ssl_stream.shutdown().await;
1483 return Err(e.into());
1484 }
1485
1486 *stream = PgStreamInner::Ssl(ssl_stream);
1487 }
1488 PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1489 PgStreamInner::Placeholder => unreachable!(),
1490 }
1491
1492 Ok(())
1493 }
1494}
1495
1496fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1497 let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1498
1499 let key_path = &tls_config.key;
1500 let cert_path = &tls_config.cert;
1501
1502 acceptor
1505 .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1506 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1507 acceptor
1508 .set_ca_file(cert_path)
1509 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1510 acceptor
1511 .set_certificate_chain_file(cert_path)
1512 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1513 let acceptor = acceptor.build();
1514
1515 Ok(acceptor.into_context())
1516}
1517
pub mod truncated_fmt {
    use std::fmt::*;

    /// `Write` adapter that forwards at most `remaining` bytes to the wrapped formatter.
    ///
    /// When a chunk does not fit, the largest char-boundary-aligned prefix is written, followed
    /// by `...(truncated,<chunk len>)`. After that, all further output is swallowed.
    struct TruncatedFormatter<'a, 'b> {
        remaining: usize,
        finished: bool,
        f: &'a mut Formatter<'b>,
    }

    impl Write for TruncatedFormatter<'_, '_> {
        fn write_str(&mut self, s: &str) -> Result {
            if self.finished {
                return Ok(());
            }

            if s.len() <= self.remaining {
                self.f.write_str(s)?;
                self.remaining -= s.len();
                return Ok(());
            }

            // Largest cut point <= `remaining` that does not split a UTF-8 sequence. Index 0 is
            // always a boundary, so the search always succeeds.
            let cut = (0..=self.remaining)
                .rev()
                .find(|&idx| s.is_char_boundary(idx))
                .unwrap_or(0);
            let (kept, _) = s.split_at(cut);
            self.f.write_str(kept)?;
            self.remaining -= cut;
            self.f.write_str(&format!("...(truncated,{})", s.len()))?;
            self.finished = true;
            Ok(())
        }
    }

    /// Formats the wrapped value (via `Debug` or `Display`), limiting the output to `.1` bytes.
    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);

    impl<T: Debug> Debug for TruncatedFmt<'_, T> {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            let mut sink = TruncatedFormatter {
                remaining: self.1,
                finished: false,
                f,
            };
            write!(sink, "{:?}", self.0)
        }
    }

    impl<T: Display> Display for TruncatedFmt<'_, T> {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            let mut sink = TruncatedFormatter {
                remaining: self.1,
                finished: false,
                f,
            };
            write!(sink, "{}", self.0)
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            assert_eq!(
                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
                "select '...(truncated,14)",
            );
        }
    }
}
1589
1590fn parse_options(options: &str) -> PsqlResult<Vec<(String, String)>> {
1603 let mut args = Vec::new();
1604 let mut current_arg = String::new();
1605 let mut chars = options.chars().peekable();
1606
1607 while let Some(c) = chars.next() {
1608 if c == '\\' {
1609 if let Some(next_c) = chars.next() {
1610 current_arg.push(next_c);
1611 }
1612 } else if c.is_ascii_whitespace() {
1613 if !current_arg.is_empty() {
1614 args.push(std::mem::take(&mut current_arg));
1615 }
1616 } else {
1617 current_arg.push(c);
1618 }
1619 }
1620 if !current_arg.is_empty() {
1621 args.push(current_arg);
1622 }
1623
1624 let mut args_iter = args.into_iter();
1625 let mut config = Vec::new();
1626
1627 while let Some(arg) = args_iter.next() {
1628 if arg == "-c" {
1629 if let Some(config_str) = args_iter.next() {
1630 if let Some((key, value)) = config_str.split_once('=') {
1631 let key = key.replace("-", "_");
1632 config.push((key, value.to_owned()));
1633 } else {
1634 return Err(PsqlError::StartupError(
1635 format!("invalid config format: {}", config_str).into(),
1636 ));
1637 }
1638 } else {
1639 return Err(PsqlError::StartupError("missing argument for -c".into()));
1640 }
1641 } else if let Some(config_str) = arg.strip_prefix("--") {
1642 if let Some((key, value)) = config_str.split_once('=') {
1643 let key = key.replace("-", "_");
1644 config.push((key, value.to_owned()));
1645 } else {
1646 return Err(PsqlError::StartupError(
1647 format!("invalid config format: {}", config_str).into(),
1648 ));
1649 }
1650 } else {
1651 tracing::warn!(
1652 arg,
1653 "ignoring unrecognized option for backward compatibility"
1654 );
1655 }
1656 }
1657 Ok(config)
1658}
1659
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    #[test]
    fn test_redact_parsable_sql() {
        let keywords = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()]));
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        assert_eq!(
            redact_sql(sql, keywords),
            "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])"
        );
    }

    #[test]
    fn test_redact_user_password_sql() {
        let keywords = Arc::new(HashSet::from(["password".into()]));

        assert_eq!(
            redact_sql("ALTER USER WITH PASSWORD 'rw_password_2'", keywords.clone()),
            "ALTER USER WITH PASSWORD [REDACTED]"
        );
        assert_eq!(
            redact_sql(
                "ALTER USER foo WITH ENCRYPTED PASSWORD 'md5827ccb0eea8a706c4c34a16891f84e7b'",
                keywords.clone(),
            ),
            "ALTER USER foo WITH ENCRYPTED PASSWORD [REDACTED]"
        );
        assert_eq!(
            redact_sql("CREATE USER foo WITH PASSWORD 'rw_password_2'", keywords),
            "CREATE USER foo WITH PASSWORD [REDACTED]"
        );
    }

    #[test]
    fn test_parse_options() {
        assert_eq!(parse_options("").unwrap(), vec![]);
        assert_eq!(
            parse_options("-c a=1 -c b=2").unwrap(),
            vec![("a".into(), "1".into()), ("b".into(), "2".into())]
        );
        assert_eq!(
            parse_options("-c key=value").unwrap(),
            vec![("key".into(), "value".into())]
        );
        // Quotes are not special: they are kept in the value.
        assert_eq!(
            parse_options("-c key='value'").unwrap(),
            vec![("key".into(), "'value'".into())]
        );

        // A backslash escapes the following whitespace.
        assert_eq!(
            parse_options(r#"-c key=value\ with\ spaces"#).unwrap(),
            vec![("key".into(), "value with spaces".into())]
        );
        assert_eq!(
            parse_options(r#"-c search_path=my\ schema"#).unwrap(),
            vec![("search_path".into(), "my schema".into())]
        );
        // An escaped backslash yields a literal backslash.
        assert_eq!(
            parse_options(r"-c a=x\\y").unwrap(),
            vec![("a".into(), r"x\y".into())]
        );

        // Malformed input is rejected.
        assert!(parse_options("-c").is_err());
        assert!(parse_options("-c foo").is_err());
        assert!(parse_options("--foo").is_err());

        assert_eq!(
            parse_options("--foo=bar").unwrap(),
            vec![("foo".into(), "bar".into())]
        );
        assert_eq!(
            parse_options(r#"--foo=bar\ baz"#).unwrap(),
            vec![("foo".into(), "bar baz".into())]
        );
        assert_eq!(
            parse_options("-c a=1 --b=2").unwrap(),
            vec![("a".into(), "1".into()), ("b".into(), "2".into())]
        );
        // A trailing lone backslash is dropped.
        assert_eq!(
            parse_options(r#"-c a=b\"#).unwrap(),
            vec![("a".into(), "b".into())]
        );

        // `-` in option names becomes `_` for both spellings; dashes in values are kept.
        assert_eq!(
            parse_options("-c search-path=foo --application-name=x-y").unwrap(),
            vec![
                ("search_path".into(), "foo".into()),
                ("application_name".into(), "x-y".into())
            ]
        );
        // Only the first `=` separates name from value.
        assert_eq!(
            parse_options("-c a=b=c").unwrap(),
            vec![("a".into(), "b=c".into())]
        );
        // Unrecognized arguments are ignored rather than rejected.
        assert_eq!(
            parse_options("-B 100 -c a=1").unwrap(),
            vec![("a".into(), "1".into())]
        );
        // Repeated whitespace between and around arguments is ignored.
        assert_eq!(
            parse_options("  -c   a=1  ").unwrap(),
            vec![("a".into(), "1".into())]
        );
    }
}