1use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::error_or_notice::Severity;
46use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
47use crate::net::AddressRef;
48use crate::pg_extended::ResultCache;
49use crate::pg_message::{
50 BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
51 FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
52 FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
53};
54use crate::pg_server::{Session, SessionManager, UserAuthenticator};
55use crate::types::Format;
56
/// Length limit handed to the query-log truncation formatter when SQL text is logged or
/// recorded in tracing spans.
///
/// Read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable; falls back to
/// `65536` when the variable is unset, not valid Unicode, or not a valid `usize`.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        // Parse exactly once; an unparsable value is treated the same as an unset one.
        .and_then(|len| len.parse::<usize>().ok())
        .unwrap_or(65536)
});
64
65tokio::task_local! {
66 pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
68}
69
/// Server-side handler for a single Postgres wire-protocol connection.
///
/// It reads frontend messages from `stream`, dispatches them to the session obtained from
/// `session_mgr`, and writes the backend responses.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// The client connection (possibly upgraded to TLS).
    stream: PgStream<S>,
    /// Connection phase; selects how the next message is read in `read_message`.
    state: PgProtocolState,
    /// Set by Terminate, CancelQuery and health-check messages; makes `process` report that
    /// the connection should be closed.
    is_terminate: bool,

    session_mgr: Arc<SM>,
    /// Session created while handling the startup message; `None` before that.
    session: Option<Arc<SM::Session>>,

    /// Results of portals whose execution stopped at a row limit, keyed by portal name, so a
    /// later Execute can continue consuming them.
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    /// The unnamed prepared statement, if one has been parsed.
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    /// Named prepared statements.
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    /// The unnamed portal, if one has been bound.
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    /// Named portals.
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    /// Statement name -> names of the portals bound from it; closing a statement also closes
    /// these portals.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    /// SSL context built from `tls_config` in `new`; when `None`, SSL requests are answered
    /// with "no encryption".
    tls_context: Option<SslContext>,

    /// TLS settings; `enforce_ssl` is checked while handling the startup message.
    tls_config: Option<TlsConfig>,

    /// Set after an extended-query error; every message other than Sync is then skipped
    /// until Sync arrives. (The field name keeps the historical "util" spelling.)
    ignore_util_sync: bool,

    /// Client address, passed to the session manager when the session is created.
    peer_addr: AddressRef,

    /// Keywords passed to SQL redaction before SQL text is logged or recorded in spans.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts for memory held by incoming messages; used to throttle oversized messages.
    message_memory_manager: MessageMemoryManagerRef,
}
112
/// TLS settings for Postgres wire connections, normally created by `TlsConfig::new_default`
/// from the `RW_SSL_*` environment variables.
#[derive(Clone, Debug)]
pub struct TlsConfig {
    /// Value taken from `RW_SSL_CERT` (presumably a certificate file path — TODO confirm).
    pub cert: String,
    /// Value taken from `RW_SSL_KEY` (presumably a private-key file path — TODO confirm).
    pub key: String,
    /// When true, a startup message arriving over a non-SSL connection is rejected.
    pub enforce_ssl: bool,
}
123
124impl TlsConfig {
125 pub fn new_default() -> anyhow::Result<Option<Self>> {
126 let cert = std::env::var("RW_SSL_CERT").ok();
127 let key = std::env::var("RW_SSL_KEY").ok();
128 let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
129
130 if cert.is_some() ^ key.is_some() {
131 return Err(anyhow::anyhow!(
132 "RW_SSL_CERT and RW_SSL_KEY must be set together"
133 ));
134 }
135
136 if enforce_ssl && cert.is_none() {
137 return Err(anyhow::anyhow!(
138 "RW_SSL_ENFORCE requires RW_SSL_CERT and RW_SSL_KEY to be set"
139 ));
140 }
141
142 let (Some(cert), Some(key)) = (cert, key) else {
143 return Ok(None);
144 };
145
146 tracing::info!(
147 "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
148 cert,
149 key,
150 enforce_ssl
151 );
152 Ok(Some(Self {
153 cert,
154 key,
155 enforce_ssl,
156 }))
157 }
158}
159
160impl<S, SM> Drop for PgProtocol<S, SM>
161where
162 SM: SessionManager,
163{
164 fn drop(&mut self) {
165 if let Some(session) = &self.session {
166 self.session_mgr.end_session(session);
168 }
169 }
170}
171
/// Connection phase; `PgProtocol::read_message` reads messages differently in each phase.
enum PgProtocolState {
    /// Before startup completes: messages are read with `read_startup`.
    Startup,
    /// After startup: messages are read as header + body and are subject to the message
    /// memory manager's throttling.
    Regular,
}
177
178pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
182 let without_null = if b.last() == Some(&0) {
183 &b[..b.len() - 1]
184 } else {
185 &b[..]
186 };
187 std::str::from_utf8(without_null)
188}
189
190fn get_redacted_and_truncated_sql(
191 sql: &str,
192 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
193) -> String {
194 let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
195 && !keywords.is_empty()
196 {
197 redact_sql(sql, keywords)
198 } else {
199 sql.to_owned()
200 };
201 let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
202 truncated.to_string()
203}
204
205fn record_sql_in_span(
207 sql: &str,
208 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
209 span: &mut tracing::Span,
210) {
211 let redacted_and_truncated_sql =
212 get_redacted_and_truncated_sql(sql, redact_sql_option_keywords);
213 span.record("sql", tracing::field::display(&redacted_and_truncated_sql));
214}
215
216fn record_user_in_span(user: &str, span: &mut tracing::Span) {
217 span.record("user", tracing::field::display(user));
218}
219
220fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
222 match Parser::parse_sql(sql) {
223 Ok(sqls) => sqls
224 .into_iter()
225 .map(|sql| sql.to_redacted_string(keywords.clone()))
226 .join(";"),
227 Err(_) => sql.to_owned(),
228 }
229}
230
/// Settings handed to `PgProtocol::new` for a connection (destructured there into the
/// corresponding `PgProtocol` fields).
#[derive(Clone)]
pub struct ConnectionContext {
    /// Optional TLS settings: used to build the SSL context and to enforce SSL at startup.
    pub tls_config: Option<TlsConfig>,
    /// Keywords passed to SQL redaction before SQL text is logged or recorded in spans.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts for memory held by incoming messages; used to throttle oversized messages.
    pub message_memory_manager: MessageMemoryManagerRef,
}
237
238impl<S, SM> PgProtocol<S, SM>
239where
240 S: PgByteStream,
241 SM: SessionManager,
242{
243 pub fn new(
244 stream: S,
245 session_mgr: Arc<SM>,
246 peer_addr: AddressRef,
247 context: ConnectionContext,
248 ) -> Self {
249 let ConnectionContext {
250 tls_config,
251 redact_sql_option_keywords,
252 message_memory_manager,
253 } = context;
254 Self {
255 stream: PgStream::new(stream),
256 is_terminate: false,
257 state: PgProtocolState::Startup,
258 session_mgr,
259 session: None,
260 tls_context: tls_config
261 .as_ref()
262 .and_then(|e| build_ssl_ctx_from_config(e).ok()),
263 tls_config,
264 result_cache: Default::default(),
265 unnamed_prepare_statement: Default::default(),
266 prepare_statement_store: Default::default(),
267 unnamed_portal: Default::default(),
268 portal_store: Default::default(),
269 statement_portal_dependency: Default::default(),
270 ignore_util_sync: false,
271 peer_addr,
272 redact_sql_option_keywords,
273 message_memory_manager,
274 }
275 }
276
277 pub async fn run(&mut self) {
279 let mut notice_fut = None;
280
281 loop {
282 if notice_fut.is_none()
284 && let Some(session) = self.session.clone()
285 {
286 let mut stream = self.stream.clone();
287 notice_fut = Some(Box::pin(async move {
288 loop {
289 let notice = session.next_notice().await;
290 if let Err(e) = stream.write(BeMessage::NoticeResponse(¬ice)).await {
291 tracing::error!(error = %e.as_report(), notice, "failed to send notice");
292 }
293 }
294 }));
295 }
296
297 let process = std::pin::pin!(async {
299 let (msg, _memory_guard) = match self.read_message().await {
300 Ok(msg) => msg,
301 Err(e) => {
302 tracing::error!(error = %e.as_report(), "error when reading message");
303 return true; }
305 };
306 tracing::trace!(?msg, "received message");
307 self.process(msg).await
308 });
309
310 let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
311 tokio::select! {
312 _ = notice_fut => unreachable!(),
313 terminated = process => terminated,
314 }
315 } else {
316 process.await
317 };
318
319 if terminated {
320 break;
321 }
322 }
323 }
324
325 pub async fn process(&mut self, msg: FeMessage) -> bool {
327 self.do_process(msg).await.is_none() || self.is_terminate
328 }
329
330 fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
338 let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
339 return tracing::Span::none();
340 };
341
342 let mode = match msg {
343 FeMessage::Query(_) => "simple query",
344 FeMessage::Parse(_) => "extended query parse",
345 FeMessage::Execute(_) => "extended query execute",
346 _ => return tracing::Span::none(),
347 };
348
349 let mut span = tracing::info_span!(
350 target: PGWIRE_ROOT_SPAN_TARGET,
351 "handle_query",
352 mode,
353 session_id,
354 sql = tracing::field::Empty,
355 user = tracing::field::Empty,
356 );
357 if let Ok(sql) = msg.get_sql()
358 && let Some(sql) = sql
359 {
360 record_sql_in_span(sql, self.redact_sql_option_keywords.clone(), &mut span);
361 }
362 if let Some(current_session) = self.session.as_ref() {
363 record_user_in_span(¤t_session.user(), &mut span);
364 }
365 span
366 }
367
368 async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
372 let span = self.root_span_for_msg(&msg);
373 let weak_session = self
374 .session
375 .as_ref()
376 .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);
377
378 let fut = Box::pin(self.do_process_inner(msg));
383
384 let fut = async move {
386 if let Some(session) = weak_session {
387 CURRENT_SESSION.scope(session, fut).await
388 } else {
389 fut.await
390 }
391 };
392
393 let fut = async move {
395 AssertUnwindSafe(fut)
396 .rw_catch_unwind()
397 .await
398 .unwrap_or_else(|payload| {
399 Err(PsqlError::Panic(
400 panic_message::panic_message(&payload).to_owned(),
401 ))
402 })
403 };
404
405 let fut = async move {
407 let period = *SLOW_QUERY_LOG_PERIOD;
408 let mut fut = std::pin::pin!(fut);
409 let mut elapsed = Duration::ZERO;
410
411 loop {
413 match tokio::time::timeout(period, &mut fut).await {
414 Ok(result) => break result,
415 Err(_) => {
416 elapsed += period;
417 tracing::info!(
418 target: PGWIRE_SLOW_QUERY_LOG,
419 elapsed = %format_args!("{}ms", elapsed.as_millis()),
420 "slow query"
421 );
422 }
423 }
424 }
425 };
426
427 let fut = async move {
429 if !tracing::Span::current().is_none() {
430 tracing::info!(
431 target: PGWIRE_QUERY_LOG,
432 status = "started",
433 );
434 }
435
436 let start = Instant::now();
437 let result = fut.await;
438 let elapsed = start.elapsed();
439
440 if let Err(error) = &result {
444 if cfg!(debug_assertions) && !Deployment::current().is_ci() {
445 tracing::error!(error = ?error.as_report(), "error when process message");
451 } else {
452 tracing::error!(error = %error.as_report(), "error when process message");
453 }
454 }
455
456 if !tracing::Span::current().is_none() {
459 tracing::info!(
460 target: PGWIRE_QUERY_LOG,
461 status = if result.is_ok() { "ok" } else { "err" },
462 time = %format_args!("{}ms", elapsed.as_millis()),
463 );
464 }
465
466 result
467 };
468
469 let fut = fut.instrument(span);
471
472 match fut.await {
474 Ok(()) => Some(()),
475 Err(e) => {
476 match e {
477 PsqlError::IoError(io_err) => {
478 if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
479 return None;
480 }
481 }
482
483 PsqlError::SslError(_) => {
484 return None;
487 }
488
489 PsqlError::StartupError(_) | PsqlError::PasswordError => {
490 self.stream
491 .write_no_flush(BeMessage::ErrorResponse {
492 error: &e,
493 pretty: false,
496 severity: Some(Severity::Fatal),
497 })
498 .ok()?;
499 let _ = self.stream.flush().await;
500 return None;
501 }
502
503 PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
504 self.stream
505 .write_no_flush(BeMessage::ErrorResponse {
506 error: &e,
507 pretty: true,
508 severity: None,
509 })
510 .ok()?;
511 self.ready_for_query().ok()?;
512 }
513
514 PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
515 self.stream
516 .write_no_flush(BeMessage::ErrorResponse {
517 error: &e,
518 pretty: true,
519 severity: None,
520 })
521 .ok()?;
522 let _ = self.stream.flush().await;
523
524 return None;
529 }
530
531 PsqlError::Uncategorized(_)
532 | PsqlError::ExtendedPrepareError(_)
533 | PsqlError::ExtendedExecuteError(_) => {
534 self.stream
535 .write_no_flush(BeMessage::ErrorResponse {
536 error: &e,
537 pretty: true,
538 severity: None,
539 })
540 .ok()?;
541 }
542 }
543 let _ = self.stream.flush().await;
544 Some(())
545 }
546 }
547 }
548
549 async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
550 if self.ignore_util_sync {
552 if let FeMessage::Sync = msg {
553 } else {
554 tracing::trace!("ignore message {:?} until sync.", msg);
555 return Ok(());
556 }
557 }
558
559 match msg {
560 FeMessage::Gss => self.process_gss_msg().await?,
561 FeMessage::Ssl => self.process_ssl_msg().await?,
562 FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
563 FeMessage::Password(msg) => self.process_password_msg(msg).await?,
564 FeMessage::Query(query_msg) => {
565 let sql = Arc::from(query_msg.get_sql()?);
566 drop(query_msg);
568 self.process_query_msg(sql).await?
569 }
570 FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
571 FeMessage::Terminate => self.process_terminate(),
572 FeMessage::Parse(m) => {
573 if let Err(err) = self.process_parse_msg(m).await {
574 self.ignore_util_sync = true;
575 return Err(err);
576 }
577 }
578 FeMessage::Bind(m) => {
579 if let Err(err) = self.process_bind_msg(m) {
580 self.ignore_util_sync = true;
581 return Err(err);
582 }
583 }
584 FeMessage::Execute(m) => {
585 if let Err(err) = self.process_execute_msg(m).await {
586 self.ignore_util_sync = true;
587 return Err(err);
588 }
589 }
590 FeMessage::Describe(m) => {
591 if let Err(err) = self.process_describe_msg(m) {
592 self.ignore_util_sync = true;
593 return Err(err);
594 }
595 }
596 FeMessage::Sync => {
597 self.ignore_util_sync = false;
598 self.ready_for_query()?
599 }
600 FeMessage::Close(m) => {
601 if let Err(err) = self.process_close_msg(m) {
602 self.ignore_util_sync = true;
603 return Err(err);
604 }
605 }
606 FeMessage::Flush => {
607 if let Err(err) = self.stream.flush().await {
608 self.ignore_util_sync = true;
609 return Err(err.into());
610 }
611 }
612 FeMessage::HealthCheck => self.process_health_check(),
613 FeMessage::ServerThrottle(reason) => match reason {
614 ServerThrottleReason::TooLargeMessage => {
615 return Err(PsqlError::ServerThrottle(format!(
616 "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
617 self.message_memory_manager.max_filter_bytes
618 )));
619 }
620 ServerThrottleReason::TooManyMemoryUsage => {
621 return Err(PsqlError::ServerThrottle(format!(
622 "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
623 self.message_memory_manager.max_running_bytes
624 )));
625 }
626 },
627 }
628 self.stream.flush().await?;
629 Ok(())
630 }
631
632 pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
633 match self.state {
634 PgProtocolState::Startup => self
635 .stream
636 .read_startup()
637 .await
638 .map(|message: FeMessage| (message, None)),
639 PgProtocolState::Regular => {
640 self.stream.read_header().await?;
641 let guard = if let Some(ref header) = self.stream.read_header {
642 let payload_len = std::cmp::max(header.payload_len, 0) as u64;
643 let (reason, guard) = self.message_memory_manager.add(payload_len);
644 if let Some(reason) = reason {
645 drop(guard);
647 self.stream.skip_body().await?;
648 return Ok((FeMessage::ServerThrottle(reason), None));
649 }
650 guard
651 } else {
652 None
653 };
654 let message = self.stream.read_body().await?;
655 Ok((message, guard))
656 }
657 }
658 }
659
660 fn ready_for_query(&mut self) -> io::Result<()> {
662 self.stream.write_no_flush(BeMessage::ReadyForQuery(
663 self.session
664 .as_ref()
665 .map(|s| s.transaction_status())
666 .unwrap_or(TransactionStatus::Idle),
667 ))
668 }
669
670 async fn process_gss_msg(&mut self) -> PsqlResult<()> {
671 self.stream.write(BeMessage::EncryptionResponseNo).await?;
673 Ok(())
674 }
675
676 async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
677 if let Some(context) = self.tls_context.as_ref() {
678 self.stream.write(BeMessage::EncryptionResponseSsl).await?;
681 self.stream.upgrade_to_ssl(context).await?;
682 } else {
683 self.stream.write(BeMessage::EncryptionResponseNo).await?;
685 }
686
687 Ok(())
688 }
689
690 async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
691 if let Some(ref tls_config) = self.tls_config
693 && tls_config.enforce_ssl
694 && !self.stream.is_ssl_connection().await
695 {
696 return Err(PsqlError::StartupError(
697 "SSL connection is required but not established".into(),
698 ));
699 }
700
701 let db_name = msg
702 .config
703 .get("database")
704 .cloned()
705 .unwrap_or_else(|| "dev".to_owned());
706 let user_name = msg
707 .config
708 .get("user")
709 .cloned()
710 .unwrap_or_else(|| "root".to_owned());
711
712 let session = self
713 .session_mgr
714 .connect(&db_name, &user_name, self.peer_addr.clone())
715 .map_err(|e| PsqlError::StartupError(e.into()))?;
716
717 if let Some(options) = msg.config.get("options") {
718 for (key, value) in parse_options(options)? {
719 session
720 .set_config(&key, value)
721 .map_err(|e| PsqlError::StartupError(e.into()))?;
722 }
723 }
724 let application_name = msg.config.get("application_name");
726 if let Some(application_name) = application_name {
727 session
728 .set_config("application_name", application_name.clone())
729 .map_err(|e| PsqlError::StartupError(e.into()))?;
730 }
731
732 match session.user_authenticator() {
733 UserAuthenticator::None => {
734 self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
735
736 self.stream
739 .write_no_flush(BeMessage::BackendKeyData(session.id()))?;
740
741 self.stream.write_no_flush(BeMessage::ParameterStatus(
742 BeParameterStatusMessage::TimeZone(
743 &session
744 .get_config("timezone")
745 .map_err(|e| PsqlError::StartupError(e.into()))?,
746 ),
747 ))?;
748 self.stream
749 .write_parameter_status_msg_no_flush(&ParameterStatus {
750 application_name: application_name.cloned(),
751 })?;
752 self.ready_for_query()?;
753 }
754 UserAuthenticator::ClearText(_)
755 | UserAuthenticator::OAuth { .. }
756 | UserAuthenticator::Ldap(..) => {
757 self.stream
758 .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
759 }
760 UserAuthenticator::Md5WithSalt { salt, .. } => {
761 self.stream
762 .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
763 }
764 }
765
766 self.session = Some(session);
767 self.state = PgProtocolState::Regular;
768 Ok(())
769 }
770
771 async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
772 let session = self.session.as_ref().unwrap();
773 let authenticator = session.user_authenticator();
774 authenticator.authenticate(&msg.password).await?;
775 self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
776 let timezone = session
777 .get_config("timezone")
778 .map_err(|e| PsqlError::StartupError(e.into()))?;
779 self.stream.write_no_flush(BeMessage::ParameterStatus(
780 BeParameterStatusMessage::TimeZone(&timezone),
781 ))?;
782 self.stream
783 .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
784 self.ready_for_query()?;
785 self.state = PgProtocolState::Regular;
786 Ok(())
787 }
788
789 fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
790 let session_id = (m.target_process_id, m.target_secret_key);
791 tracing::trace!("cancel query in session: {:?}", session_id);
792 self.session_mgr.cancel_queries_in_session(session_id);
793 self.session_mgr.cancel_creating_jobs_in_session(session_id);
794 self.is_terminate = true;
795 Ok(())
796 }
797
798 async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
799 let truncated_sql =
800 get_redacted_and_truncated_sql(&sql, self.redact_sql_option_keywords.clone());
801 let session = self.session.clone().unwrap();
802
803 session.check_idle_in_transaction_timeout()?;
804 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
806 self.inner_process_query_msg(sql, session.clone()).await
807 }
808
809 async fn inner_process_query_msg(
810 &mut self,
811 sql: Arc<str>,
812 session: Arc<SM::Session>,
813 ) -> PsqlResult<()> {
814 let stmts =
816 Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
817 drop(sql);
819 if stmts.is_empty() {
820 self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
821 }
822
823 for stmt in stmts {
825 self.inner_process_query_msg_one_stmt(stmt, session.clone())
826 .await?;
827 }
828 self.ready_for_query()?;
831 Ok(())
832 }
833
834 async fn inner_process_query_msg_one_stmt(
835 &mut self,
836 stmt: Statement,
837 session: Arc<SM::Session>,
838 ) -> PsqlResult<()> {
839 let session = session.clone();
840
841 let res = session.clone().run_one_query(stmt, Format::Text).await;
843
844 while let Some(notice) = session.next_notice().now_or_never() {
846 self.stream
847 .write_no_flush(BeMessage::NoticeResponse(¬ice))?;
848 }
849
850 let mut res = res.map_err(|e| PsqlError::SimpleQueryError(e.into()))?;
851
852 for notice in res.notices() {
853 self.stream
854 .write_no_flush(BeMessage::NoticeResponse(notice))?;
855 }
856
857 let status = res.status();
858 if let Some(ref application_name) = status.application_name {
859 self.stream.write_no_flush(BeMessage::ParameterStatus(
860 BeParameterStatusMessage::ApplicationName(application_name),
861 ))?;
862 }
863
864 if res.is_copy_query_to_stdout() {
865 self.stream
866 .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
867 let mut count = 0;
868 while let Some(row_set) = res.values_stream().next().await {
869 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
870 for row in row_set {
871 self.stream.write_no_flush(BeMessage::CopyData(&row))?;
872 count += 1;
873 }
874 }
875
876 self.stream.write_no_flush(BeMessage::CopyDone)?;
877
878 res.run_callback().await?;
880
881 self.stream
882 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
883 stmt_type: res.stmt_type(),
884 rows_cnt: count,
885 }))?;
886 } else if res.is_query() {
887 self.stream
888 .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;
889
890 let mut rows_cnt = 0;
891
892 while let Some(row_set) = res.values_stream().next().await {
893 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
894 for row in row_set {
895 self.stream.write_no_flush(BeMessage::DataRow(&row))?;
896 rows_cnt += 1;
897 }
898 }
899
900 res.run_callback().await?;
902
903 self.stream
904 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
905 stmt_type: res.stmt_type(),
906 rows_cnt,
907 }))?;
908 } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
909 let first_row_set = res.values_stream().next().await;
910 let first_row_set = match first_row_set {
911 None => {
912 return Err(PsqlError::Uncategorized(
913 anyhow::anyhow!("no affected rows in output").into(),
914 ));
915 }
916 Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
917 };
918 let affected_rows_str = first_row_set[0].values()[0]
919 .as_ref()
920 .expect("compute node should return affected rows in output");
921
922 assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
923 let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
924 .unwrap()
925 .parse()
926 .unwrap_or_default();
927
928 res.run_callback().await?;
930
931 self.stream
932 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
933 stmt_type: res.stmt_type(),
934 rows_cnt: affected_rows_cnt,
935 }))?;
936 } else {
937 res.run_callback().await?;
939
940 self.stream
941 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
942 stmt_type: res.stmt_type(),
943 rows_cnt: 0,
944 }))?;
945 }
946
947 Ok(())
948 }
949
950 fn process_terminate(&mut self) {
951 self.is_terminate = true;
952 }
953
954 fn process_health_check(&mut self) {
955 tracing::debug!("health check");
956 self.is_terminate = true;
957 }
958
959 async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
960 let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
961 let session = self.session.clone().unwrap();
962 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
963 let type_ids = std::mem::take(&mut msg.type_ids);
964 drop(msg);
966 self.inner_process_parse_msg(session, sql, statement_name, type_ids)
967 .await?;
968 Ok(())
969 }
970
971 async fn inner_process_parse_msg(
972 &mut self,
973 session: Arc<SM::Session>,
974 sql: Arc<str>,
975 statement_name: String,
976 type_ids: Vec<i32>,
977 ) -> PsqlResult<()> {
978 if statement_name.is_empty() {
979 self.unnamed_prepare_statement.take();
982 } else if self.prepare_statement_store.contains_key(&statement_name) {
983 return Err(PsqlError::ExtendedPrepareError(
984 "Duplicated statement name".into(),
985 ));
986 }
987
988 let stmt = {
989 let stmts = Parser::parse_sql(&sql)
990 .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
991 drop(sql);
992 if stmts.len() > 1 {
993 return Err(PsqlError::ExtendedPrepareError(
994 "Only one statement is allowed in extended query mode".into(),
995 ));
996 }
997
998 stmts.into_iter().next()
999 };
1000
1001 let param_types: Vec<Option<DataType>> = type_ids
1002 .iter()
1003 .map(|&id| {
1004 if id == 0 {
1007 Ok(None)
1008 } else {
1009 DataType::from_oid(id)
1010 .map(Some)
1011 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
1012 }
1013 })
1014 .try_collect()?;
1015
1016 let prepare_statement = session
1017 .parse(stmt, param_types)
1018 .await
1019 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))?;
1020
1021 if statement_name.is_empty() {
1022 self.unnamed_prepare_statement.replace(prepare_statement);
1023 } else {
1024 self.prepare_statement_store
1025 .insert(statement_name.clone(), prepare_statement);
1026 }
1027
1028 self.statement_portal_dependency
1029 .entry(statement_name)
1030 .or_default()
1031 .clear();
1032
1033 self.stream.write_no_flush(BeMessage::ParseComplete)?;
1034 Ok(())
1035 }
1036
1037 fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
1038 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
1039 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1040 let session = self.session.clone().unwrap();
1041
1042 if self.portal_store.contains_key(&portal_name) {
1043 return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
1044 }
1045
1046 let prepare_statement = self.get_statement(&statement_name)?;
1047
1048 let result_formats = msg
1049 .result_format_codes
1050 .iter()
1051 .map(|&format_code| Format::from_i16(format_code))
1052 .try_collect()?;
1053 let param_formats = msg
1054 .param_format_codes
1055 .iter()
1056 .map(|&format_code| Format::from_i16(format_code))
1057 .try_collect()?;
1058
1059 let portal = session
1060 .bind(prepare_statement, msg.params, param_formats, result_formats)
1061 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1062
1063 if portal_name.is_empty() {
1064 self.result_cache.remove(&portal_name);
1065 self.unnamed_portal.replace(portal);
1066 } else {
1067 assert!(
1068 !self.result_cache.contains_key(&portal_name),
1069 "Named portal never can be overridden."
1070 );
1071 self.portal_store.insert(portal_name.clone(), portal);
1072 }
1073
1074 self.statement_portal_dependency
1075 .get_mut(&statement_name)
1076 .unwrap()
1077 .push(portal_name);
1078
1079 self.stream.write_no_flush(BeMessage::BindComplete)?;
1080 Ok(())
1081 }
1082
1083 async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1084 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1085 let row_max = msg.max_rows as usize;
1086 drop(msg);
1087 let session = self.session.clone().unwrap();
1088
1089 match self.result_cache.remove(&portal_name) {
1090 Some(mut result_cache) => {
1091 assert!(self.portal_store.contains_key(&portal_name));
1092
1093 let is_consume_completed =
1094 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1095
1096 if !is_consume_completed {
1097 self.result_cache.insert(portal_name, result_cache);
1098 }
1099 }
1100 _ => {
1101 let portal = self.get_portal(&portal_name)?;
1102 let sql = format!("{}", portal);
1103 let truncated_sql =
1104 get_redacted_and_truncated_sql(&sql, self.redact_sql_option_keywords.clone());
1105 drop(sql);
1106
1107 session.check_idle_in_transaction_timeout()?;
1108 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1110 let result = session.clone().execute(portal).await;
1111
1112 let pg_response = result.map_err(|e| PsqlError::ExtendedExecuteError(e.into()))?;
1113 let mut result_cache = ResultCache::new(pg_response);
1114 let is_consume_completed =
1115 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1116 if !is_consume_completed {
1117 self.result_cache.insert(portal_name, result_cache);
1118 }
1119 }
1120 }
1121
1122 Ok(())
1123 }
1124
1125 fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1126 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1127 let session = self.session.clone().unwrap();
1128 assert!(msg.kind == b'S' || msg.kind == b'P');
1132 if msg.kind == b'S' {
1133 let prepare_statement = self.get_statement(&name)?;
1134
1135 let (param_types, row_descriptions) = self
1136 .session
1137 .clone()
1138 .unwrap()
1139 .describe_statement(prepare_statement)
1140 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1141 self.stream.write_no_flush(BeMessage::ParameterDescription(
1142 ¶m_types.iter().map(|t| t.to_oid()).collect_vec(),
1143 ))?;
1144
1145 if row_descriptions.is_empty() {
1146 self.stream.write_no_flush(BeMessage::NoData)?;
1149 } else {
1150 self.stream
1151 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1152 }
1153 } else if msg.kind == b'P' {
1154 let portal = self.get_portal(&name)?;
1155
1156 let row_descriptions = session
1157 .describe_portal(portal)
1158 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1159
1160 if row_descriptions.is_empty() {
1161 self.stream.write_no_flush(BeMessage::NoData)?;
1164 } else {
1165 self.stream
1166 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1167 }
1168 }
1169 Ok(())
1170 }
1171
1172 fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1173 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1174 assert!(msg.kind == b'S' || msg.kind == b'P');
1175 if msg.kind == b'S' {
1176 if name.is_empty() {
1177 self.unnamed_prepare_statement = None;
1178 } else {
1179 self.prepare_statement_store.remove(&name);
1180 }
1181 for portal_name in self
1182 .statement_portal_dependency
1183 .remove(&name)
1184 .unwrap_or_default()
1185 {
1186 self.remove_portal(&portal_name);
1187 }
1188 } else if msg.kind == b'P' {
1189 self.remove_portal(&name);
1190 }
1191 self.stream.write_no_flush(BeMessage::CloseComplete)?;
1192 Ok(())
1193 }
1194
1195 fn remove_portal(&mut self, portal_name: &str) {
1196 if portal_name.is_empty() {
1197 self.unnamed_portal = None;
1198 } else {
1199 self.portal_store.remove(portal_name);
1200 }
1201 self.result_cache.remove(portal_name);
1202 }
1203
1204 fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1205 if portal_name.is_empty() {
1206 Ok(self
1207 .unnamed_portal
1208 .as_ref()
1209 .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1210 .clone())
1211 } else {
1212 Ok(self
1213 .portal_store
1214 .get(portal_name)
1215 .ok_or_else(|| {
1216 PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1217 })?
1218 .clone())
1219 }
1220 }
1221
1222 fn get_statement(
1223 &self,
1224 statement_name: &str,
1225 ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1226 if statement_name.is_empty() {
1227 Ok(self
1228 .unnamed_prepare_statement
1229 .as_ref()
1230 .ok_or_else(|| {
1231 PsqlError::Uncategorized("unnamed prepare statement not found".into())
1232 })?
1233 .clone())
1234 } else {
1235 Ok(self
1236 .prepare_statement_store
1237 .get(statement_name)
1238 .ok_or_else(|| {
1239 PsqlError::Uncategorized(
1240 format!("Prepare statement {} not found", statement_name).into(),
1241 )
1242 })?
1243 .clone())
1244 }
1245 }
1246}
1247
/// The transport underneath a `PgStream`: either the plain byte stream or its TLS wrapper.
enum PgStreamInner<S> {
    /// Temporary value swapped in (via `std::mem::replace`) while `upgrade_to_ssl` moves the
    /// plain stream out. Every I/O path treats it as unreachable.
    /// NOTE(review): it is also left in place if the TLS handshake fails, so later I/O on
    /// that `PgStream` would hit `unreachable!()`.
    Placeholder,
    /// Plain transport; this is the state `PgStream::new` starts in.
    Unencrypted(S),
    /// TLS transport installed by `upgrade_to_ssl` after a successful handshake.
    Ssl(SslStream<S>),
}
1256
/// Bound alias for the byte streams a Postgres connection can be served over: any
/// `AsyncRead + AsyncWrite + Unpin + Send + 'static` I/O object.
pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
// Blanket impl: every type meeting the bounds is a `PgByteStream`; nothing implements it manually.
impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1260
/// Postgres wire-protocol stream over a plain or TLS transport, with buffered writes.
///
/// Clones share the same transport but each keeps its own write buffer.
pub struct PgStream<S> {
    /// Shared transport. All clones point at the same one and serialize access through the
    /// (tokio) async mutex.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// Outgoing bytes serialized by `write_no_flush`; sent and cleared by `flush`.
    write_buf: BytesMut,
    /// Header read by `read_header`, waiting to be consumed by `read_body` or `skip_body`.
    read_header: Option<FeMessageHeader>,
}
1272
1273impl<S> PgStream<S> {
1274 pub fn new(stream: S) -> Self {
1276 const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1277
1278 Self {
1279 stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1280 write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1281 read_header: None,
1282 }
1283 }
1284
1285 async fn is_ssl_connection(&self) -> bool {
1287 let stream = self.stream.lock().await;
1288 matches!(*stream, PgStreamInner::Ssl(_))
1289 }
1290}
1291
1292impl<S> Clone for PgStream<S> {
1293 fn clone(&self) -> Self {
1294 Self {
1295 stream: Arc::clone(&self.stream),
1296 write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1297 read_header: self.read_header.clone(),
1298 }
1299 }
1300}
1301
/// Per-connection values reported in `ParameterStatus` messages, on top of the fixed server
/// parameters written by `write_parameter_status_msg_no_flush`.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// When set, sent as `BeParameterStatusMessage::ApplicationName`.
    /// Presumably taken from the client's startup parameters; confirm at the construction site.
    pub application_name: Option<String>,
}
1322
1323impl<S> PgStream<S>
1324where
1325 S: PgByteStream,
1326{
1327 async fn read_startup(&mut self) -> io::Result<FeMessage> {
1328 let mut stream = self.stream.lock().await;
1329 match &mut *stream {
1330 PgStreamInner::Placeholder => unreachable!(),
1331 PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1332 PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1333 }
1334 }
1335
1336 async fn read_header(&mut self) -> io::Result<()> {
1337 let mut stream = self.stream.lock().await;
1338 match &mut *stream {
1339 PgStreamInner::Placeholder => unreachable!(),
1340 PgStreamInner::Unencrypted(stream) => {
1341 self.read_header = Some(FeMessage::read_header(stream).await?);
1342 Ok(())
1343 }
1344 PgStreamInner::Ssl(ssl_stream) => {
1345 self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1346 Ok(())
1347 }
1348 }
1349 }
1350
1351 async fn read_body(&mut self) -> io::Result<FeMessage> {
1352 let mut stream = self.stream.lock().await;
1353 let header = self
1354 .read_header
1355 .take()
1356 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1357 match &mut *stream {
1358 PgStreamInner::Placeholder => unreachable!(),
1359 PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1360 PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1361 }
1362 }
1363
1364 async fn skip_body(&mut self) -> io::Result<()> {
1365 let mut stream = self.stream.lock().await;
1366 let header = self
1367 .read_header
1368 .take()
1369 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1370 match &mut *stream {
1371 PgStreamInner::Placeholder => unreachable!(),
1372 PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1373 PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1374 }
1375 }
1376
1377 fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1378 self.write_no_flush(BeMessage::ParameterStatus(
1379 BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1380 ))?;
1381 self.write_no_flush(BeMessage::ParameterStatus(
1382 BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1383 ))?;
1384 self.write_no_flush(BeMessage::ParameterStatus(
1385 BeParameterStatusMessage::ServerVersion(PG_VERSION),
1386 ))?;
1387 if let Some(application_name) = &status.application_name {
1388 self.write_no_flush(BeMessage::ParameterStatus(
1389 BeParameterStatusMessage::ApplicationName(application_name),
1390 ))?;
1391 }
1392 Ok(())
1393 }
1394
1395 pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1396 BeMessage::write(&mut self.write_buf, message)
1397 }
1398
1399 async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1400 self.write_no_flush(message)?;
1401 self.flush().await?;
1402 Ok(())
1403 }
1404
1405 async fn flush(&mut self) -> io::Result<()> {
1406 let mut stream = self.stream.lock().await;
1407 match &mut *stream {
1408 PgStreamInner::Placeholder => unreachable!(),
1409 PgStreamInner::Unencrypted(stream) => {
1410 stream.write_all(&self.write_buf).await?;
1411 stream.flush().await?;
1412 }
1413 PgStreamInner::Ssl(ssl_stream) => {
1414 ssl_stream.write_all(&self.write_buf).await?;
1415 ssl_stream.flush().await?;
1416 }
1417 }
1418 self.write_buf.clear();
1419 Ok(())
1420 }
1421}
1422
1423impl<S> PgStream<S>
1424where
1425 S: PgByteStream,
1426{
1427 async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1429 let mut stream = self.stream.lock().await;
1430
1431 match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1432 PgStreamInner::Unencrypted(unencrypted_stream) => {
1433 let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1434 let mut ssl_stream =
1435 tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1436
1437 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1438 tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1439 let _ = ssl_stream.shutdown().await;
1440 return Err(e.into());
1441 }
1442
1443 *stream = PgStreamInner::Ssl(ssl_stream);
1444 }
1445 PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1446 PgStreamInner::Placeholder => unreachable!(),
1447 }
1448
1449 Ok(())
1450 }
1451}
1452
1453fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1454 let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1455
1456 let key_path = &tls_config.key;
1457 let cert_path = &tls_config.cert;
1458
1459 acceptor
1462 .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1463 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1464 acceptor
1465 .set_ca_file(cert_path)
1466 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1467 acceptor
1468 .set_certificate_chain_file(cert_path)
1469 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1470 let acceptor = acceptor.build();
1471
1472 Ok(acceptor.into_context())
1473}
1474
pub mod truncated_fmt {
    //! Formatting adapters that cap how many bytes of a value's `Debug`/`Display` output are
    //! written, so very long values (such as SQL text) can be printed with a bounded size.
    use std::fmt::*;

    /// Largest index `<= index` that lies on a UTF-8 char boundary of `s`.
    ///
    /// Stable equivalent of `str::floor_char_boundary`; callers pass `index <= s.len()`.
    fn floor_char_boundary(s: &str, index: usize) -> usize {
        // Index 0 is always a boundary, so the search cannot come up empty.
        (0..=index)
            .rev()
            .find(|&i| s.is_char_boundary(i))
            .unwrap_or(0)
    }

    /// `Write` adapter that forwards at most `remaining` bytes to `f`.
    ///
    /// The first chunk that does not fit is cut at a char boundary and followed by
    /// `...(truncated,<len>)`, where `<len>` is that chunk's byte length. Everything after that
    /// is dropped.
    struct TruncatedFormatter<'a, 'b> {
        /// Bytes that may still be forwarded before truncating.
        remaining: usize,
        /// Set once the truncation marker has been written.
        finished: bool,
        f: &'a mut Formatter<'b>,
    }

    impl Write for TruncatedFormatter<'_, '_> {
        fn write_str(&mut self, s: &str) -> Result {
            if self.finished {
                return Ok(());
            }

            if s.len() <= self.remaining {
                self.f.write_str(s)?;
                self.remaining -= s.len();
                return Ok(());
            }

            let cut = floor_char_boundary(s, self.remaining);
            self.f.write_str(&s[..cut])?;
            self.remaining -= cut;
            // `write!` formats straight into `f`, avoiding a temporary `String`.
            write!(self.f, "...(truncated,{})", s.len())?;
            self.finished = true;
            Ok(())
        }
    }

    /// Formats `self.0` (via `Debug` or `Display`), writing at most `self.1` bytes of it.
    ///
    /// `T` may be unsized, so `TruncatedFmt(s.as_str(), n)` works as well as
    /// `TruncatedFmt(&s, n)`.
    pub struct TruncatedFmt<'a, T: ?Sized>(pub &'a T, pub usize);

    impl<T> Debug for TruncatedFmt<'_, T>
    where
        T: Debug + ?Sized,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            TruncatedFormatter {
                remaining: self.1,
                finished: false,
                f,
            }
            .write_fmt(format_args!("{:?}", self.0))
        }
    }

    impl<T> Display for TruncatedFmt<'_, T>
    where
        T: Display + ?Sized,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            TruncatedFormatter {
                remaining: self.1,
                finished: false,
                f,
            }
            .write_fmt(format_args!("{}", self.0))
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            assert_eq!(
                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
                "select '...(truncated,14)",
            );
        }
    }
}
1546
1547fn parse_options(options: &str) -> PsqlResult<Vec<(String, String)>> {
1560 let mut args = Vec::new();
1561 let mut current_arg = String::new();
1562 let mut chars = options.chars().peekable();
1563
1564 while let Some(c) = chars.next() {
1565 if c == '\\' {
1566 if let Some(next_c) = chars.next() {
1567 current_arg.push(next_c);
1568 }
1569 } else if c.is_ascii_whitespace() {
1570 if !current_arg.is_empty() {
1571 args.push(std::mem::take(&mut current_arg));
1572 }
1573 } else {
1574 current_arg.push(c);
1575 }
1576 }
1577 if !current_arg.is_empty() {
1578 args.push(current_arg);
1579 }
1580
1581 let mut args_iter = args.into_iter();
1582 let mut config = Vec::new();
1583
1584 while let Some(arg) = args_iter.next() {
1585 if arg == "-c" {
1586 if let Some(config_str) = args_iter.next() {
1587 if let Some((key, value)) = config_str.split_once('=') {
1588 let key = key.replace("-", "_");
1589 config.push((key, value.to_owned()));
1590 } else {
1591 return Err(PsqlError::StartupError(
1592 format!("invalid config format: {}", config_str).into(),
1593 ));
1594 }
1595 } else {
1596 return Err(PsqlError::StartupError("missing argument for -c".into()));
1597 }
1598 } else if let Some(config_str) = arg.strip_prefix("--") {
1599 if let Some((key, value)) = config_str.split_once('=') {
1600 let key = key.replace("-", "_");
1601 config.push((key, value.to_owned()));
1602 } else {
1603 return Err(PsqlError::StartupError(
1604 format!("invalid config format: {}", config_str).into(),
1605 ));
1606 }
1607 } else {
1608 tracing::warn!(
1609 arg,
1610 "ignoring unrecognized option for backward compatibility"
1611 );
1612 }
1613 }
1614 Ok(config)
1615}
1616
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    #[test]
    fn test_redact_parsable_sql() {
        let keywords = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()]));
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        assert_eq!(
            redact_sql(sql, keywords),
            "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])"
        );
    }

    #[test]
    fn test_parse_options() {
        assert_eq!(parse_options("").unwrap(), vec![]);
        assert_eq!(
            parse_options("-c a=1 -c b=2").unwrap(),
            vec![("a".into(), "1".into()), ("b".into(), "2".into())]
        );
        assert_eq!(
            parse_options("-c key=value").unwrap(),
            vec![("key".into(), "value".into())]
        );
        // Quotes are not special: they stay part of the value.
        assert_eq!(
            parse_options("-c key='value'").unwrap(),
            vec![("key".into(), "'value'".into())]
        );

        // Backslash-escaped spaces stay inside a single token.
        assert_eq!(
            parse_options(r#"-c key=value\ with\ spaces"#).unwrap(),
            vec![("key".into(), "value with spaces".into())]
        );
        assert_eq!(
            parse_options(r#"-c search_path=my\ schema"#).unwrap(),
            vec![("search_path".into(), "my schema".into())]
        );

        assert!(parse_options("-c").is_err());
        assert!(parse_options("-c foo").is_err());
        assert!(parse_options("--foo").is_err());
        assert_eq!(
            parse_options("--foo=bar").unwrap(),
            vec![("foo".into(), "bar".into())]
        );
        assert_eq!(
            parse_options(r#"--foo=bar\ baz"#).unwrap(),
            vec![("foo".into(), "bar baz".into())]
        );
        assert_eq!(
            parse_options("-c a=1 --b=2").unwrap(),
            vec![("a".into(), "1".into()), ("b".into(), "2".into())]
        );
        // A trailing lone backslash is dropped.
        assert_eq!(
            parse_options(r#"-c a=b\"#).unwrap(),
            vec![("a".into(), "b".into())]
        );
    }

    #[test]
    fn test_parse_options_normalization() {
        // Dashes in the key become underscores; the value is left untouched.
        assert_eq!(
            parse_options("-c statement-timeout=1-2 --search-path=a-b").unwrap(),
            vec![
                ("statement_timeout".into(), "1-2".into()),
                ("search_path".into(), "a-b".into()),
            ]
        );
        // Unrecognized tokens are ignored; runs of whitespace separate tokens.
        assert_eq!(
            parse_options("  foo   -c a=1\t-x ").unwrap(),
            vec![("a".into(), "1".into())]
        );
        // Only the first `=` separates key from value.
        assert_eq!(
            parse_options("-c a=b=c").unwrap(),
            vec![("a".into(), "b=c".into())]
        );
    }
}