1use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::error_or_notice::Severity;
46use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
47use crate::net::AddressRef;
48use crate::pg_extended::ResultCache;
49use crate::pg_message::{
50 BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
51 FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
52 FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
53};
54use crate::pg_server::{Session, SessionManager, UserAuthenticator};
55use crate::types::Format;
56
/// Truncation limit passed to `TruncatedFmt` when SQL text is written to the query
/// log or recorded in a tracing span.
///
/// Read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable; falls back
/// to 65536 when the variable is unset, not valid Unicode, or not a valid `usize`.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        // Parse once instead of validating and then parsing a second time.
        .and_then(|len| len.parse::<usize>().ok())
        .unwrap_or(65536)
});
64
65tokio::task_local! {
66 pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
68}
69
/// Server side of one PostgreSQL wire-protocol connection.
///
/// Owns the client byte stream, the session created during startup, and the
/// per-connection extended-query state (prepared statements, portals, and results of
/// Execute messages that stopped early).
///
/// NOTE: fields are dropped in declaration order (after `Drop::drop` runs).
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Buffered reader/writer over the (possibly TLS-upgraded) client connection.
    stream: PgStream<S>,
    /// Whether the next message is read as a startup-phase message or as a regular
    /// header + body message (see `read_message`).
    state: PgProtocolState,
    /// Set when the connection should be closed after the current message
    /// (Terminate, CancelRequest, health check).
    is_terminate: bool,

    session_mgr: Arc<SM>,
    /// `None` until a Startup message has been processed successfully.
    session: Option<Arc<SM::Session>>,

    /// Results of Execute messages that stopped at `max_rows`, keyed by portal name.
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    /// The unnamed prepared statement (empty statement name).
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    /// Named prepared statements.
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    /// The unnamed portal (empty portal name).
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    /// Named portals.
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    /// Statement name -> names of portals bound from it; those portals are removed
    /// when the statement is closed.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    /// TLS context built from `tls_config`; `None` means SSL requests are declined.
    tls_context: Option<SslContext>,

    /// TLS settings; with `enforce_ssl`, a non-TLS startup is rejected.
    tls_config: Option<TlsConfig>,

    /// Set after an extended-query error: every message except Sync is skipped until
    /// the next Sync (the name means "until").
    ignore_util_sync: bool,

    /// Client address, passed to `SessionManager::connect`.
    peer_addr: AddressRef,

    /// Option keywords whose values are redacted before SQL is logged.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts for incoming message payload sizes and reports throttling.
    message_memory_manager: MessageMemoryManagerRef,
}
112
/// TLS settings for client connections.
#[derive(Clone, Debug)]
pub struct TlsConfig {
    /// Server certificate location as given by `RW_SSL_CERT`
    /// (presumably a PEM file path — see `build_ssl_ctx_from_config`).
    pub cert: String,
    /// Private key location as given by `RW_SSL_KEY`
    /// (presumably a PEM file path — see `build_ssl_ctx_from_config`).
    pub key: String,
    /// When true, a Startup on a connection that was not upgraded to TLS is rejected.
    pub enforce_ssl: bool,
}
123
124impl TlsConfig {
125 pub fn new_default() -> Option<Self> {
126 let cert = std::env::var("RW_SSL_CERT").ok()?;
127 let key = std::env::var("RW_SSL_KEY").ok()?;
128 let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
129 tracing::info!(
130 "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
131 cert,
132 key,
133 enforce_ssl
134 );
135 Some(Self {
136 cert,
137 key,
138 enforce_ssl,
139 })
140 }
141}
142
143impl<S, SM> Drop for PgProtocol<S, SM>
144where
145 SM: SessionManager,
146{
147 fn drop(&mut self) {
148 if let Some(session) = &self.session {
149 self.session_mgr.end_session(session);
151 }
152 }
153}
154
/// Which framing `PgProtocol::read_message` uses for the next incoming message.
enum PgProtocolState {
    /// Startup phase: the message is read with `PgStream::read_startup`.
    Startup,
    /// After startup: the message is read as a header followed by a body.
    Regular,
}
160
161pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
165 let without_null = if b.last() == Some(&0) {
166 &b[..b.len() - 1]
167 } else {
168 &b[..]
169 };
170 std::str::from_utf8(without_null)
171}
172
173fn get_redacted_and_truncated_sql(
174 sql: &str,
175 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
176) -> String {
177 let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
178 && !keywords.is_empty()
179 {
180 redact_sql(sql, keywords)
181 } else {
182 sql.to_owned()
183 };
184 let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
185 truncated.to_string()
186}
187
188fn record_sql_in_span(
190 sql: &str,
191 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
192 span: &mut tracing::Span,
193) {
194 let redacted_and_truncated_sql =
195 get_redacted_and_truncated_sql(sql, redact_sql_option_keywords);
196 span.record("sql", tracing::field::display(&redacted_and_truncated_sql));
197}
198
199fn record_user_in_span(user: &str, span: &mut tracing::Span) {
200 span.record("user", tracing::field::display(user));
201}
202
203fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
205 match Parser::parse_sql(sql) {
206 Ok(sqls) => sqls
207 .into_iter()
208 .map(|sql| sql.to_redacted_string(keywords.clone()))
209 .join(";"),
210 Err(_) => sql.to_owned(),
211 }
212}
213
214#[derive(Clone)]
215pub struct ConnectionContext {
216 pub tls_config: Option<TlsConfig>,
217 pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
218 pub message_memory_manager: MessageMemoryManagerRef,
219}
220
221impl<S, SM> PgProtocol<S, SM>
222where
223 S: PgByteStream,
224 SM: SessionManager,
225{
226 pub fn new(
227 stream: S,
228 session_mgr: Arc<SM>,
229 peer_addr: AddressRef,
230 context: ConnectionContext,
231 ) -> Self {
232 let ConnectionContext {
233 tls_config,
234 redact_sql_option_keywords,
235 message_memory_manager,
236 } = context;
237 Self {
238 stream: PgStream::new(stream),
239 is_terminate: false,
240 state: PgProtocolState::Startup,
241 session_mgr,
242 session: None,
243 tls_context: tls_config
244 .as_ref()
245 .and_then(|e| build_ssl_ctx_from_config(e).ok()),
246 tls_config,
247 result_cache: Default::default(),
248 unnamed_prepare_statement: Default::default(),
249 prepare_statement_store: Default::default(),
250 unnamed_portal: Default::default(),
251 portal_store: Default::default(),
252 statement_portal_dependency: Default::default(),
253 ignore_util_sync: false,
254 peer_addr,
255 redact_sql_option_keywords,
256 message_memory_manager,
257 }
258 }
259
260 pub async fn run(&mut self) {
262 let mut notice_fut = None;
263
264 loop {
265 if notice_fut.is_none()
267 && let Some(session) = self.session.clone()
268 {
269 let mut stream = self.stream.clone();
270 notice_fut = Some(Box::pin(async move {
271 loop {
272 let notice = session.next_notice().await;
273 if let Err(e) = stream.write(BeMessage::NoticeResponse(¬ice)).await {
274 tracing::error!(error = %e.as_report(), notice, "failed to send notice");
275 }
276 }
277 }));
278 }
279
280 let process = std::pin::pin!(async {
282 let (msg, _memory_guard) = match self.read_message().await {
283 Ok(msg) => msg,
284 Err(e) => {
285 tracing::error!(error = %e.as_report(), "error when reading message");
286 return true; }
288 };
289 tracing::trace!(?msg, "received message");
290 self.process(msg).await
291 });
292
293 let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
294 tokio::select! {
295 _ = notice_fut => unreachable!(),
296 terminated = process => terminated,
297 }
298 } else {
299 process.await
300 };
301
302 if terminated {
303 break;
304 }
305 }
306 }
307
308 pub async fn process(&mut self, msg: FeMessage) -> bool {
310 self.do_process(msg).await.is_none() || self.is_terminate
311 }
312
313 fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
321 let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
322 return tracing::Span::none();
323 };
324
325 let mode = match msg {
326 FeMessage::Query(_) => "simple query",
327 FeMessage::Parse(_) => "extended query parse",
328 FeMessage::Execute(_) => "extended query execute",
329 _ => return tracing::Span::none(),
330 };
331
332 let mut span = tracing::info_span!(
333 target: PGWIRE_ROOT_SPAN_TARGET,
334 "handle_query",
335 mode,
336 session_id,
337 sql = tracing::field::Empty,
338 user = tracing::field::Empty,
339 );
340 if let Ok(sql) = msg.get_sql()
341 && let Some(sql) = sql
342 {
343 record_sql_in_span(sql, self.redact_sql_option_keywords.clone(), &mut span);
344 }
345 if let Some(current_session) = self.session.as_ref() {
346 record_user_in_span(¤t_session.user(), &mut span);
347 }
348 span
349 }
350
    /// Processes one message inside the per-message instrumentation stack and reports
    /// any error to the client.
    ///
    /// Wrappers around `do_process_inner`, innermost first:
    /// 1. a `CURRENT_SESSION` task-local scope (only when a session exists);
    /// 2. panic capture: a panic becomes `PsqlError::Panic`;
    /// 3. a "slow query" log every `SLOW_QUERY_LOG_PERIOD` while still running;
    /// 4. query-log "started" / "ok" / "err" events (only when the current span is not
    ///    `none`), plus error logging;
    /// 5. the root span from `root_span_for_msg`.
    ///
    /// Returns `None` when the connection must be closed: EOF, SSL error,
    /// startup/password failure, idle-in-transaction timeout, panic, or a failed write
    /// while reporting an error. Returns `Some(())` otherwise.
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        // NOTE(review): presumably boxed to keep the composed future small — confirm.
        let fut = Box::pin(self.do_process_inner(msg));

        // Expose the session through `CURRENT_SESSION` while the message is processed.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Convert a panic during processing into `PsqlError::Panic`.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Log "slow query" every `period` while the work is still pending.
        // `timeout` is given `&mut fut`, so an expired timeout does not cancel the work.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            let mut elapsed = Duration::ZERO;

            loop {
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Query-log start/finish events and error logging.
        let fut = async move {
            // `root_span_for_msg` returns `Span::none()` for non-query messages, so this
            // skips the query log for them.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = "started",
                );
            }

            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            if let Err(error) = &result {
                // Debug builds outside CI log the report with `?` (Debug formatting);
                // all other builds use `%` (Display formatting).
                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
                    tracing::error!(error = ?error.as_report(), "error when process message");
                } else {
                    tracing::error!(error = %error.as_report(), "error when process message");
                }
            }

            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        let fut = fut.instrument(span);

        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    PsqlError::IoError(io_err) => {
                        // Client disconnected: close quietly.
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                        // Other I/O errors fall through to the final flush and keep the
                        // connection open.
                    }

                    // Close without sending an error response.
                    PsqlError::SslError(_) => {
                        return None;
                    }

                    // Startup/authentication failures: report with FATAL severity, then close.
                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: false,
                                severity: Some(Severity::Fatal),
                            })
                            .ok()?;
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    // Simple-query style errors: the error is followed by ReadyForQuery.
                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                                severity: None,
                            })
                            .ok()?;
                        self.ready_for_query().ok()?;
                    }

                    // Report, flush, and close the connection.
                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                                severity: None,
                            })
                            .ok()?;
                        let _ = self.stream.flush().await;

                        return None;
                    }

                    // No ReadyForQuery here. In extended-query mode `do_process_inner` skips
                    // messages until Sync, and Sync sends ReadyForQuery.
                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                                severity: None,
                            })
                            .ok()?;
                    }
                }
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
531
532 async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
533 if self.ignore_util_sync {
535 if let FeMessage::Sync = msg {
536 } else {
537 tracing::trace!("ignore message {:?} until sync.", msg);
538 return Ok(());
539 }
540 }
541
542 match msg {
543 FeMessage::Gss => self.process_gss_msg().await?,
544 FeMessage::Ssl => self.process_ssl_msg().await?,
545 FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
546 FeMessage::Password(msg) => self.process_password_msg(msg).await?,
547 FeMessage::Query(query_msg) => {
548 let sql = Arc::from(query_msg.get_sql()?);
549 drop(query_msg);
551 self.process_query_msg(sql).await?
552 }
553 FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
554 FeMessage::Terminate => self.process_terminate(),
555 FeMessage::Parse(m) => {
556 if let Err(err) = self.process_parse_msg(m).await {
557 self.ignore_util_sync = true;
558 return Err(err);
559 }
560 }
561 FeMessage::Bind(m) => {
562 if let Err(err) = self.process_bind_msg(m) {
563 self.ignore_util_sync = true;
564 return Err(err);
565 }
566 }
567 FeMessage::Execute(m) => {
568 if let Err(err) = self.process_execute_msg(m).await {
569 self.ignore_util_sync = true;
570 return Err(err);
571 }
572 }
573 FeMessage::Describe(m) => {
574 if let Err(err) = self.process_describe_msg(m) {
575 self.ignore_util_sync = true;
576 return Err(err);
577 }
578 }
579 FeMessage::Sync => {
580 self.ignore_util_sync = false;
581 self.ready_for_query()?
582 }
583 FeMessage::Close(m) => {
584 if let Err(err) = self.process_close_msg(m) {
585 self.ignore_util_sync = true;
586 return Err(err);
587 }
588 }
589 FeMessage::Flush => {
590 if let Err(err) = self.stream.flush().await {
591 self.ignore_util_sync = true;
592 return Err(err.into());
593 }
594 }
595 FeMessage::HealthCheck => self.process_health_check(),
596 FeMessage::ServerThrottle(reason) => match reason {
597 ServerThrottleReason::TooLargeMessage => {
598 return Err(PsqlError::ServerThrottle(format!(
599 "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
600 self.message_memory_manager.max_filter_bytes
601 )));
602 }
603 ServerThrottleReason::TooManyMemoryUsage => {
604 return Err(PsqlError::ServerThrottle(format!(
605 "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
606 self.message_memory_manager.max_running_bytes
607 )));
608 }
609 },
610 }
611 self.stream.flush().await?;
612 Ok(())
613 }
614
615 pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
616 match self.state {
617 PgProtocolState::Startup => self
618 .stream
619 .read_startup()
620 .await
621 .map(|message: FeMessage| (message, None)),
622 PgProtocolState::Regular => {
623 self.stream.read_header().await?;
624 let guard = if let Some(ref header) = self.stream.read_header {
625 let payload_len = std::cmp::max(header.payload_len, 0) as u64;
626 let (reason, guard) = self.message_memory_manager.add(payload_len);
627 if let Some(reason) = reason {
628 drop(guard);
630 self.stream.skip_body().await?;
631 return Ok((FeMessage::ServerThrottle(reason), None));
632 }
633 guard
634 } else {
635 None
636 };
637 let message = self.stream.read_body().await?;
638 Ok((message, guard))
639 }
640 }
641 }
642
643 fn ready_for_query(&mut self) -> io::Result<()> {
645 self.stream.write_no_flush(BeMessage::ReadyForQuery(
646 self.session
647 .as_ref()
648 .map(|s| s.transaction_status())
649 .unwrap_or(TransactionStatus::Idle),
650 ))
651 }
652
653 async fn process_gss_msg(&mut self) -> PsqlResult<()> {
654 self.stream.write(BeMessage::EncryptionResponseNo).await?;
656 Ok(())
657 }
658
659 async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
660 if let Some(context) = self.tls_context.as_ref() {
661 self.stream.write(BeMessage::EncryptionResponseSsl).await?;
664 self.stream.upgrade_to_ssl(context).await?;
665 } else {
666 self.stream.write(BeMessage::EncryptionResponseNo).await?;
668 }
669
670 Ok(())
671 }
672
673 async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
674 if let Some(ref tls_config) = self.tls_config
676 && tls_config.enforce_ssl
677 && !self.stream.is_ssl_connection().await
678 {
679 return Err(PsqlError::StartupError(
680 "SSL connection is required but not established".into(),
681 ));
682 }
683
684 let db_name = msg
685 .config
686 .get("database")
687 .cloned()
688 .unwrap_or_else(|| "dev".to_owned());
689 let user_name = msg
690 .config
691 .get("user")
692 .cloned()
693 .unwrap_or_else(|| "root".to_owned());
694
695 let session = self
696 .session_mgr
697 .connect(&db_name, &user_name, self.peer_addr.clone())
698 .map_err(|e| PsqlError::StartupError(e.into()))?;
699
700 if let Some(options) = msg.config.get("options") {
701 for (key, value) in parse_options(options)? {
702 session
703 .set_config(&key, value)
704 .map_err(|e| PsqlError::StartupError(e.into()))?;
705 }
706 }
707 let application_name = msg.config.get("application_name");
709 if let Some(application_name) = application_name {
710 session
711 .set_config("application_name", application_name.clone())
712 .map_err(|e| PsqlError::StartupError(e.into()))?;
713 }
714
715 match session.user_authenticator() {
716 UserAuthenticator::None => {
717 self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
718
719 self.stream
722 .write_no_flush(BeMessage::BackendKeyData(session.id()))?;
723
724 self.stream.write_no_flush(BeMessage::ParameterStatus(
725 BeParameterStatusMessage::TimeZone(
726 &session
727 .get_config("timezone")
728 .map_err(|e| PsqlError::StartupError(e.into()))?,
729 ),
730 ))?;
731 self.stream
732 .write_parameter_status_msg_no_flush(&ParameterStatus {
733 application_name: application_name.cloned(),
734 })?;
735 self.ready_for_query()?;
736 }
737 UserAuthenticator::ClearText(_)
738 | UserAuthenticator::OAuth { .. }
739 | UserAuthenticator::Ldap(..) => {
740 self.stream
741 .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
742 }
743 UserAuthenticator::Md5WithSalt { salt, .. } => {
744 self.stream
745 .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
746 }
747 }
748
749 self.session = Some(session);
750 self.state = PgProtocolState::Regular;
751 Ok(())
752 }
753
754 async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
755 let session = self.session.as_ref().unwrap();
756 let authenticator = session.user_authenticator();
757 authenticator.authenticate(&msg.password).await?;
758 self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
759 let timezone = session
760 .get_config("timezone")
761 .map_err(|e| PsqlError::StartupError(e.into()))?;
762 self.stream.write_no_flush(BeMessage::ParameterStatus(
763 BeParameterStatusMessage::TimeZone(&timezone),
764 ))?;
765 self.stream
766 .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
767 self.ready_for_query()?;
768 self.state = PgProtocolState::Regular;
769 Ok(())
770 }
771
772 fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
773 let session_id = (m.target_process_id, m.target_secret_key);
774 tracing::trace!("cancel query in session: {:?}", session_id);
775 self.session_mgr.cancel_queries_in_session(session_id);
776 self.session_mgr.cancel_creating_jobs_in_session(session_id);
777 self.is_terminate = true;
778 Ok(())
779 }
780
781 async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
782 let truncated_sql =
783 get_redacted_and_truncated_sql(&sql, self.redact_sql_option_keywords.clone());
784 let session = self.session.clone().unwrap();
785
786 session.check_idle_in_transaction_timeout()?;
787 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
789 self.inner_process_query_msg(sql, session.clone()).await
790 }
791
792 async fn inner_process_query_msg(
793 &mut self,
794 sql: Arc<str>,
795 session: Arc<SM::Session>,
796 ) -> PsqlResult<()> {
797 let stmts =
799 Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
800 drop(sql);
802 if stmts.is_empty() {
803 self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
804 }
805
806 for stmt in stmts {
808 self.inner_process_query_msg_one_stmt(stmt, session.clone())
809 .await?;
810 }
811 self.ready_for_query()?;
814 Ok(())
815 }
816
817 async fn inner_process_query_msg_one_stmt(
818 &mut self,
819 stmt: Statement,
820 session: Arc<SM::Session>,
821 ) -> PsqlResult<()> {
822 let session = session.clone();
823
824 let res = session.clone().run_one_query(stmt, Format::Text).await;
826
827 while let Some(notice) = session.next_notice().now_or_never() {
829 self.stream
830 .write_no_flush(BeMessage::NoticeResponse(¬ice))?;
831 }
832
833 let mut res = res.map_err(|e| PsqlError::SimpleQueryError(e.into()))?;
834
835 for notice in res.notices() {
836 self.stream
837 .write_no_flush(BeMessage::NoticeResponse(notice))?;
838 }
839
840 let status = res.status();
841 if let Some(ref application_name) = status.application_name {
842 self.stream.write_no_flush(BeMessage::ParameterStatus(
843 BeParameterStatusMessage::ApplicationName(application_name),
844 ))?;
845 }
846
847 if res.is_copy_query_to_stdout() {
848 self.stream
849 .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
850 let mut count = 0;
851 while let Some(row_set) = res.values_stream().next().await {
852 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
853 for row in row_set {
854 self.stream.write_no_flush(BeMessage::CopyData(&row))?;
855 count += 1;
856 }
857 }
858
859 self.stream.write_no_flush(BeMessage::CopyDone)?;
860
861 res.run_callback().await?;
863
864 self.stream
865 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
866 stmt_type: res.stmt_type(),
867 rows_cnt: count,
868 }))?;
869 } else if res.is_query() {
870 self.stream
871 .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;
872
873 let mut rows_cnt = 0;
874
875 while let Some(row_set) = res.values_stream().next().await {
876 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
877 for row in row_set {
878 self.stream.write_no_flush(BeMessage::DataRow(&row))?;
879 rows_cnt += 1;
880 }
881 }
882
883 res.run_callback().await?;
885
886 self.stream
887 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
888 stmt_type: res.stmt_type(),
889 rows_cnt,
890 }))?;
891 } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
892 let first_row_set = res.values_stream().next().await;
893 let first_row_set = match first_row_set {
894 None => {
895 return Err(PsqlError::Uncategorized(
896 anyhow::anyhow!("no affected rows in output").into(),
897 ));
898 }
899 Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
900 };
901 let affected_rows_str = first_row_set[0].values()[0]
902 .as_ref()
903 .expect("compute node should return affected rows in output");
904
905 assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
906 let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
907 .unwrap()
908 .parse()
909 .unwrap_or_default();
910
911 res.run_callback().await?;
913
914 self.stream
915 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
916 stmt_type: res.stmt_type(),
917 rows_cnt: affected_rows_cnt,
918 }))?;
919 } else {
920 res.run_callback().await?;
922
923 self.stream
924 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
925 stmt_type: res.stmt_type(),
926 rows_cnt: 0,
927 }))?;
928 }
929
930 Ok(())
931 }
932
933 fn process_terminate(&mut self) {
934 self.is_terminate = true;
935 }
936
937 fn process_health_check(&mut self) {
938 tracing::debug!("health check");
939 self.is_terminate = true;
940 }
941
942 async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
943 let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
944 let session = self.session.clone().unwrap();
945 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
946 let type_ids = std::mem::take(&mut msg.type_ids);
947 drop(msg);
949 self.inner_process_parse_msg(session, sql, statement_name, type_ids)
950 .await?;
951 Ok(())
952 }
953
954 async fn inner_process_parse_msg(
955 &mut self,
956 session: Arc<SM::Session>,
957 sql: Arc<str>,
958 statement_name: String,
959 type_ids: Vec<i32>,
960 ) -> PsqlResult<()> {
961 if statement_name.is_empty() {
962 self.unnamed_prepare_statement.take();
965 } else if self.prepare_statement_store.contains_key(&statement_name) {
966 return Err(PsqlError::ExtendedPrepareError(
967 "Duplicated statement name".into(),
968 ));
969 }
970
971 let stmt = {
972 let stmts = Parser::parse_sql(&sql)
973 .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
974 drop(sql);
975 if stmts.len() > 1 {
976 return Err(PsqlError::ExtendedPrepareError(
977 "Only one statement is allowed in extended query mode".into(),
978 ));
979 }
980
981 stmts.into_iter().next()
982 };
983
984 let param_types: Vec<Option<DataType>> = type_ids
985 .iter()
986 .map(|&id| {
987 if id == 0 {
990 Ok(None)
991 } else {
992 DataType::from_oid(id)
993 .map(Some)
994 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
995 }
996 })
997 .try_collect()?;
998
999 let prepare_statement = session
1000 .parse(stmt, param_types)
1001 .await
1002 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))?;
1003
1004 if statement_name.is_empty() {
1005 self.unnamed_prepare_statement.replace(prepare_statement);
1006 } else {
1007 self.prepare_statement_store
1008 .insert(statement_name.clone(), prepare_statement);
1009 }
1010
1011 self.statement_portal_dependency
1012 .entry(statement_name)
1013 .or_default()
1014 .clear();
1015
1016 self.stream.write_no_flush(BeMessage::ParseComplete)?;
1017 Ok(())
1018 }
1019
1020 fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
1021 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
1022 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1023 let session = self.session.clone().unwrap();
1024
1025 if self.portal_store.contains_key(&portal_name) {
1026 return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
1027 }
1028
1029 let prepare_statement = self.get_statement(&statement_name)?;
1030
1031 let result_formats = msg
1032 .result_format_codes
1033 .iter()
1034 .map(|&format_code| Format::from_i16(format_code))
1035 .try_collect()?;
1036 let param_formats = msg
1037 .param_format_codes
1038 .iter()
1039 .map(|&format_code| Format::from_i16(format_code))
1040 .try_collect()?;
1041
1042 let portal = session
1043 .bind(prepare_statement, msg.params, param_formats, result_formats)
1044 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1045
1046 if portal_name.is_empty() {
1047 self.result_cache.remove(&portal_name);
1048 self.unnamed_portal.replace(portal);
1049 } else {
1050 assert!(
1051 !self.result_cache.contains_key(&portal_name),
1052 "Named portal never can be overridden."
1053 );
1054 self.portal_store.insert(portal_name.clone(), portal);
1055 }
1056
1057 self.statement_portal_dependency
1058 .get_mut(&statement_name)
1059 .unwrap()
1060 .push(portal_name);
1061
1062 self.stream.write_no_flush(BeMessage::BindComplete)?;
1063 Ok(())
1064 }
1065
1066 async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1067 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1068 let row_max = msg.max_rows as usize;
1069 drop(msg);
1070 let session = self.session.clone().unwrap();
1071
1072 match self.result_cache.remove(&portal_name) {
1073 Some(mut result_cache) => {
1074 assert!(self.portal_store.contains_key(&portal_name));
1075
1076 let is_consume_completed =
1077 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1078
1079 if !is_consume_completed {
1080 self.result_cache.insert(portal_name, result_cache);
1081 }
1082 }
1083 _ => {
1084 let portal = self.get_portal(&portal_name)?;
1085 let sql = format!("{}", portal);
1086 let truncated_sql =
1087 get_redacted_and_truncated_sql(&sql, self.redact_sql_option_keywords.clone());
1088 drop(sql);
1089
1090 session.check_idle_in_transaction_timeout()?;
1091 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1093 let result = session.clone().execute(portal).await;
1094
1095 let pg_response = result.map_err(|e| PsqlError::ExtendedExecuteError(e.into()))?;
1096 let mut result_cache = ResultCache::new(pg_response);
1097 let is_consume_completed =
1098 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1099 if !is_consume_completed {
1100 self.result_cache.insert(portal_name, result_cache);
1101 }
1102 }
1103 }
1104
1105 Ok(())
1106 }
1107
1108 fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1109 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1110 let session = self.session.clone().unwrap();
1111 assert!(msg.kind == b'S' || msg.kind == b'P');
1115 if msg.kind == b'S' {
1116 let prepare_statement = self.get_statement(&name)?;
1117
1118 let (param_types, row_descriptions) = self
1119 .session
1120 .clone()
1121 .unwrap()
1122 .describe_statement(prepare_statement)
1123 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1124 self.stream.write_no_flush(BeMessage::ParameterDescription(
1125 ¶m_types.iter().map(|t| t.to_oid()).collect_vec(),
1126 ))?;
1127
1128 if row_descriptions.is_empty() {
1129 self.stream.write_no_flush(BeMessage::NoData)?;
1132 } else {
1133 self.stream
1134 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1135 }
1136 } else if msg.kind == b'P' {
1137 let portal = self.get_portal(&name)?;
1138
1139 let row_descriptions = session
1140 .describe_portal(portal)
1141 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1142
1143 if row_descriptions.is_empty() {
1144 self.stream.write_no_flush(BeMessage::NoData)?;
1147 } else {
1148 self.stream
1149 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1150 }
1151 }
1152 Ok(())
1153 }
1154
1155 fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1156 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1157 assert!(msg.kind == b'S' || msg.kind == b'P');
1158 if msg.kind == b'S' {
1159 if name.is_empty() {
1160 self.unnamed_prepare_statement = None;
1161 } else {
1162 self.prepare_statement_store.remove(&name);
1163 }
1164 for portal_name in self
1165 .statement_portal_dependency
1166 .remove(&name)
1167 .unwrap_or_default()
1168 {
1169 self.remove_portal(&portal_name);
1170 }
1171 } else if msg.kind == b'P' {
1172 self.remove_portal(&name);
1173 }
1174 self.stream.write_no_flush(BeMessage::CloseComplete)?;
1175 Ok(())
1176 }
1177
1178 fn remove_portal(&mut self, portal_name: &str) {
1179 if portal_name.is_empty() {
1180 self.unnamed_portal = None;
1181 } else {
1182 self.portal_store.remove(portal_name);
1183 }
1184 self.result_cache.remove(portal_name);
1185 }
1186
1187 fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1188 if portal_name.is_empty() {
1189 Ok(self
1190 .unnamed_portal
1191 .as_ref()
1192 .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1193 .clone())
1194 } else {
1195 Ok(self
1196 .portal_store
1197 .get(portal_name)
1198 .ok_or_else(|| {
1199 PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1200 })?
1201 .clone())
1202 }
1203 }
1204
1205 fn get_statement(
1206 &self,
1207 statement_name: &str,
1208 ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1209 if statement_name.is_empty() {
1210 Ok(self
1211 .unnamed_prepare_statement
1212 .as_ref()
1213 .ok_or_else(|| {
1214 PsqlError::Uncategorized("unnamed prepare statement not found".into())
1215 })?
1216 .clone())
1217 } else {
1218 Ok(self
1219 .prepare_statement_store
1220 .get(statement_name)
1221 .ok_or_else(|| {
1222 PsqlError::Uncategorized(
1223 format!("Prepare statement {} not found", statement_name).into(),
1224 )
1225 })?
1226 .clone())
1227 }
1228 }
1229}
1230
1231enum PgStreamInner<S> {
1232 Placeholder,
1234 Unencrypted(S),
1236 Ssl(SslStream<S>),
1238}
1239
1240pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1242impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1243
1244pub struct PgStream<S> {
1249 stream: Arc<Mutex<PgStreamInner<S>>>,
1251 write_buf: BytesMut,
1253 read_header: Option<FeMessageHeader>,
1254}
1255
1256impl<S> PgStream<S> {
1257 pub fn new(stream: S) -> Self {
1259 const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1260
1261 Self {
1262 stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1263 write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1264 read_header: None,
1265 }
1266 }
1267
1268 async fn is_ssl_connection(&self) -> bool {
1270 let stream = self.stream.lock().await;
1271 matches!(*stream, PgStreamInner::Ssl(_))
1272 }
1273}
1274
1275impl<S> Clone for PgStream<S> {
1276 fn clone(&self) -> Self {
1277 Self {
1278 stream: Arc::clone(&self.stream),
1279 write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1280 read_header: self.read_header.clone(),
1281 }
1282 }
1283}
1284
/// Optional parameter values reported to the client by `write_parameter_status_msg_no_flush`,
/// in addition to the fixed server-side ones.
#[derive(Clone, Debug, Default)]
pub struct ParameterStatus {
    /// `application_name` to report; nothing is sent for it when `None`.
    pub application_name: Option<String>,
}
1305
1306impl<S> PgStream<S>
1307where
1308 S: PgByteStream,
1309{
1310 async fn read_startup(&mut self) -> io::Result<FeMessage> {
1311 let mut stream = self.stream.lock().await;
1312 match &mut *stream {
1313 PgStreamInner::Placeholder => unreachable!(),
1314 PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1315 PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1316 }
1317 }
1318
1319 async fn read_header(&mut self) -> io::Result<()> {
1320 let mut stream = self.stream.lock().await;
1321 match &mut *stream {
1322 PgStreamInner::Placeholder => unreachable!(),
1323 PgStreamInner::Unencrypted(stream) => {
1324 self.read_header = Some(FeMessage::read_header(stream).await?);
1325 Ok(())
1326 }
1327 PgStreamInner::Ssl(ssl_stream) => {
1328 self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1329 Ok(())
1330 }
1331 }
1332 }
1333
1334 async fn read_body(&mut self) -> io::Result<FeMessage> {
1335 let mut stream = self.stream.lock().await;
1336 let header = self
1337 .read_header
1338 .take()
1339 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1340 match &mut *stream {
1341 PgStreamInner::Placeholder => unreachable!(),
1342 PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1343 PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1344 }
1345 }
1346
1347 async fn skip_body(&mut self) -> io::Result<()> {
1348 let mut stream = self.stream.lock().await;
1349 let header = self
1350 .read_header
1351 .take()
1352 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1353 match &mut *stream {
1354 PgStreamInner::Placeholder => unreachable!(),
1355 PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1356 PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1357 }
1358 }
1359
1360 fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1361 self.write_no_flush(BeMessage::ParameterStatus(
1362 BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1363 ))?;
1364 self.write_no_flush(BeMessage::ParameterStatus(
1365 BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1366 ))?;
1367 self.write_no_flush(BeMessage::ParameterStatus(
1368 BeParameterStatusMessage::ServerVersion(PG_VERSION),
1369 ))?;
1370 if let Some(application_name) = &status.application_name {
1371 self.write_no_flush(BeMessage::ParameterStatus(
1372 BeParameterStatusMessage::ApplicationName(application_name),
1373 ))?;
1374 }
1375 Ok(())
1376 }
1377
1378 pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1379 BeMessage::write(&mut self.write_buf, message)
1380 }
1381
1382 async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1383 self.write_no_flush(message)?;
1384 self.flush().await?;
1385 Ok(())
1386 }
1387
1388 async fn flush(&mut self) -> io::Result<()> {
1389 let mut stream = self.stream.lock().await;
1390 match &mut *stream {
1391 PgStreamInner::Placeholder => unreachable!(),
1392 PgStreamInner::Unencrypted(stream) => {
1393 stream.write_all(&self.write_buf).await?;
1394 stream.flush().await?;
1395 }
1396 PgStreamInner::Ssl(ssl_stream) => {
1397 ssl_stream.write_all(&self.write_buf).await?;
1398 ssl_stream.flush().await?;
1399 }
1400 }
1401 self.write_buf.clear();
1402 Ok(())
1403 }
1404}
1405
1406impl<S> PgStream<S>
1407where
1408 S: PgByteStream,
1409{
1410 async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1412 let mut stream = self.stream.lock().await;
1413
1414 match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1415 PgStreamInner::Unencrypted(unencrypted_stream) => {
1416 let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1417 let mut ssl_stream =
1418 tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1419
1420 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1421 tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1422 let _ = ssl_stream.shutdown().await;
1423 return Err(e.into());
1424 }
1425
1426 *stream = PgStreamInner::Ssl(ssl_stream);
1427 }
1428 PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1429 PgStreamInner::Placeholder => unreachable!(),
1430 }
1431
1432 Ok(())
1433 }
1434}
1435
1436fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1437 let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1438
1439 let key_path = &tls_config.key;
1440 let cert_path = &tls_config.cert;
1441
1442 acceptor
1445 .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1446 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1447 acceptor
1448 .set_ca_file(cert_path)
1449 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1450 acceptor
1451 .set_certificate_chain_file(cert_path)
1452 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1453 let acceptor = acceptor.build();
1454
1455 Ok(acceptor.into_context())
1456}
1457
pub mod truncated_fmt {
    //! `Debug` / `Display` adapters that cap formatted output at a byte budget.

    use std::fmt::*;

    /// `fmt::Write` sink that forwards at most `remaining` bytes to `f`.
    ///
    /// The first chunk that does not fit is cut at the nearest UTF-8 character boundary at or
    /// below the budget. It is followed by `...(truncated,N)`, where `N` is that chunk's byte
    /// length. Everything written after that is discarded.
    struct TruncatedFormatter<'a, 'b> {
        remaining: usize,
        finished: bool,
        f: &'a mut Formatter<'b>,
    }

    impl Write for TruncatedFormatter<'_, '_> {
        fn write_str(&mut self, s: &str) -> Result {
            if self.finished {
                return Ok(());
            }

            if s.len() <= self.remaining {
                self.f.write_str(s)?;
                self.remaining -= s.len();
                return Ok(());
            }

            // Largest char boundary <= budget; 0 is always a boundary, so the search succeeds.
            let cut = (0..=self.remaining)
                .rev()
                .find(|&i| s.is_char_boundary(i))
                .unwrap_or(0);
            self.f.write_str(&s[..cut])?;
            self.remaining -= cut;
            write!(self.f, "...(truncated,{})", s.len())?;
            self.finished = true;
            Ok(())
        }
    }

    /// Formats `.0` with its own `Debug`/`Display` impl, keeping at most `.1` bytes of output.
    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);

    impl<T: Debug> Debug for TruncatedFmt<'_, T> {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            let mut sink = TruncatedFormatter {
                remaining: self.1,
                finished: false,
                f,
            };
            write!(sink, "{:?}", self.0)
        }
    }

    impl<T: Display> Display for TruncatedFmt<'_, T> {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            let mut sink = TruncatedFormatter {
                remaining: self.1,
                finished: false,
                f,
            };
            write!(sink, "{}", self.0)
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            assert_eq!(
                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
                "select '...(truncated,14)",
            );
        }
    }
}
1529
1530fn parse_options(options: &str) -> PsqlResult<Vec<(String, String)>> {
1543 let mut args = Vec::new();
1544 let mut current_arg = String::new();
1545 let mut chars = options.chars().peekable();
1546
1547 while let Some(c) = chars.next() {
1548 if c == '\\' {
1549 if let Some(next_c) = chars.next() {
1550 current_arg.push(next_c);
1551 }
1552 } else if c.is_ascii_whitespace() {
1553 if !current_arg.is_empty() {
1554 args.push(std::mem::take(&mut current_arg));
1555 }
1556 } else {
1557 current_arg.push(c);
1558 }
1559 }
1560 if !current_arg.is_empty() {
1561 args.push(current_arg);
1562 }
1563
1564 let mut args_iter = args.into_iter();
1565 let mut config = Vec::new();
1566
1567 while let Some(arg) = args_iter.next() {
1568 if arg == "-c" {
1569 if let Some(config_str) = args_iter.next() {
1570 if let Some((key, value)) = config_str.split_once('=') {
1571 let key = key.replace("-", "_");
1572 config.push((key, value.to_owned()));
1573 } else {
1574 return Err(PsqlError::StartupError(
1575 format!("invalid config format: {}", config_str).into(),
1576 ));
1577 }
1578 } else {
1579 return Err(PsqlError::StartupError("missing argument for -c".into()));
1580 }
1581 } else if let Some(config_str) = arg.strip_prefix("--") {
1582 if let Some((key, value)) = config_str.split_once('=') {
1583 let key = key.replace("-", "_");
1584 config.push((key, value.to_owned()));
1585 } else {
1586 return Err(PsqlError::StartupError(
1587 format!("invalid config format: {}", config_str).into(),
1588 ));
1589 }
1590 } else {
1591 tracing::warn!(
1592 arg,
1593 "ignoring unrecognized option for backward compatibility"
1594 );
1595 }
1596 }
1597 Ok(config)
1598}
1599
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Converts borrowed `(key, value)` pairs into the owned form `parse_options` returns.
    fn owned_pairs(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    #[test]
    fn test_redact_parsable_sql() {
        let keywords = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()]));
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])";
        assert_eq!(redact_sql(sql, keywords), expected);
    }

    #[test]
    fn test_parse_options() {
        // `-c key=value` settings; quotes are kept verbatim.
        assert_eq!(parse_options("").unwrap(), owned_pairs(&[]));
        assert_eq!(
            parse_options("-c a=1 -c b=2").unwrap(),
            owned_pairs(&[("a", "1"), ("b", "2")])
        );
        assert_eq!(
            parse_options("-c key=value").unwrap(),
            owned_pairs(&[("key", "value")])
        );
        assert_eq!(
            parse_options("-c key='value'").unwrap(),
            owned_pairs(&[("key", "'value'")])
        );

        // Backslash-escaped spaces stay inside the value.
        assert_eq!(
            parse_options(r#"-c key=value\ with\ spaces"#).unwrap(),
            owned_pairs(&[("key", "value with spaces")])
        );
        assert_eq!(
            parse_options(r#"-c search_path=my\ schema"#).unwrap(),
            owned_pairs(&[("search_path", "my schema")])
        );

        // Malformed settings are rejected.
        assert!(parse_options("-c").is_err());
        assert!(parse_options("-c foo").is_err());
        assert!(parse_options("--foo").is_err());

        // `--key=value` settings, alone and mixed with `-c`.
        assert_eq!(
            parse_options("--foo=bar").unwrap(),
            owned_pairs(&[("foo", "bar")])
        );
        assert_eq!(
            parse_options(r#"--foo=bar\ baz"#).unwrap(),
            owned_pairs(&[("foo", "bar baz")])
        );
        assert_eq!(
            parse_options("-c a=1 --b=2").unwrap(),
            owned_pairs(&[("a", "1"), ("b", "2")])
        );

        // A trailing lone backslash is dropped.
        assert_eq!(
            parse_options(r#"-c a=b\"#).unwrap(),
            owned_pairs(&[("a", "b")])
        );
    }
}