1use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::error_or_notice::Severity;
46use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
47use crate::net::AddressRef;
48use crate::pg_extended::ResultCache;
49use crate::pg_message::{
50 BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
51 FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
52 FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
53};
54use crate::pg_server::{Session, SessionManager, UserAuthenticator};
55use crate::types::Format;
56
/// Maximum length used when truncating the SQL text recorded in query-log spans
/// (see `record_sql_in_span`).
///
/// Read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. Falls back to
/// 65536 when the variable is unset, is not valid UTF-8, or does not parse as a `usize`.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        // Parse exactly once; any failure falls through to the default.
        .and_then(|len| len.parse::<usize>().ok())
        .unwrap_or(65536)
});
64
tokio::task_local! {
    /// The session whose message is currently being processed on this task.
    ///
    /// `PgProtocol::do_process` sets it with `CURRENT_SESSION.scope(..)` for the duration of
    /// handling one frontend message. The session is stored type-erased behind a `Weak`
    /// reference: readers must upgrade and downcast it, and holding it does not keep the
    /// session alive.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
69
/// Handles the PostgreSQL wire protocol for a single client connection: reads frontend
/// messages, dispatches them to the session, and writes backend responses.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Message-level reader/writer over the client transport.
    stream: PgStream<S>,
    /// Whether the startup message has been processed; decides how the next message is read.
    state: PgProtocolState,
    /// Set when the connection should be closed after the current message
    /// (Terminate, cancel request, health check).
    is_terminate: bool,

    /// Creates the session at startup and is told to end it when this handler is dropped.
    session_mgr: Arc<SM>,
    /// Session created by the startup message; `None` before startup.
    session: Option<Arc<SM::Session>>,

    /// Partially consumed `Execute` results (row limit reached), keyed by portal name.
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    /// The unnamed prepared statement (empty statement name).
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    /// Named prepared statements, keyed by statement name.
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    /// The unnamed portal (empty portal name).
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    /// Named portals, keyed by portal name.
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    /// Statement name -> names of portals bound from it, so that closing a statement also
    /// closes its portals.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    /// SSL context built from `tls_config` in `new`; `None` means SSL requests are declined.
    tls_context: Option<SslContext>,

    /// TLS settings, including whether SSL is mandatory at startup.
    tls_config: Option<TlsConfig>,

    /// After an extended-query error, every message except `Sync` is ignored until `Sync`.
    ignore_util_sync: bool,

    /// Client address, passed to the session manager when connecting.
    peer_addr: AddressRef,

    /// Option keywords whose values are redacted when SQL is recorded in tracing spans.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Memory accounting used to throttle oversized or excessive incoming messages.
    message_memory_manager: MessageMemoryManagerRef,
}
112
/// TLS settings for pgwire connections.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Certificate setting, read from `RW_SSL_CERT` by `TlsConfig::new_default`.
    /// NOTE(review): presumably a path to the certificate consumed by
    /// `build_ssl_ctx_from_config` — confirm.
    pub cert: String,
    /// Private-key setting, read from `RW_SSL_KEY` by `TlsConfig::new_default`.
    /// NOTE(review): presumably a path to the key file — confirm.
    pub key: String,
    /// When `true`, a startup message received over a non-SSL connection is rejected
    /// with a startup error.
    pub enforce_ssl: bool,
}
123
124impl TlsConfig {
125 pub fn new_default() -> Option<Self> {
126 let cert = std::env::var("RW_SSL_CERT").ok()?;
127 let key = std::env::var("RW_SSL_KEY").ok()?;
128 let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
129 tracing::info!(
130 "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
131 cert,
132 key,
133 enforce_ssl
134 );
135 Some(Self {
136 cert,
137 key,
138 enforce_ssl,
139 })
140 }
141}
142
143impl<S, SM> Drop for PgProtocol<S, SM>
144where
145 SM: SessionManager,
146{
147 fn drop(&mut self) {
148 if let Some(session) = &self.session {
149 self.session_mgr.end_session(session);
151 }
152 }
153}
154
/// Connection-level state of `PgProtocol`; decides how the next message is read.
enum PgProtocolState {
    /// Before the startup message has been processed: messages are read with
    /// `PgStream::read_startup`.
    Startup,
    /// After the startup message: messages are read as a header followed by a body
    /// (`read_header` / `read_body`).
    Regular,
}
160
161pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
165 let without_null = if b.last() == Some(&0) {
166 &b[..b.len() - 1]
167 } else {
168 &b[..]
169 };
170 std::str::from_utf8(without_null)
171}
172
173fn record_sql_in_current_span(
174 sql: &str,
175 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
176) -> String {
177 let mut span = tracing::Span::current();
178 record_sql_in_span(sql, redact_sql_option_keywords, &mut span)
179}
180
181fn record_sql_in_span(
183 sql: &str,
184 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
185 span: &mut tracing::Span,
186) -> String {
187 let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
188 && !keywords.is_empty()
189 {
190 redact_sql(sql, keywords)
191 } else {
192 sql.to_owned()
193 };
194 let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
195 span.record("sql", tracing::field::display(&truncated));
196 truncated.to_string()
197}
198
199fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
201 match Parser::parse_sql(sql) {
202 Ok(sqls) => sqls
203 .into_iter()
204 .map(|sql| sql.to_redacted_string(keywords.clone()))
205 .join(";"),
206 Err(_) => sql.to_owned(),
207 }
208}
209
/// Connection settings handed to `PgProtocol::new` for each accepted connection.
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS settings; with `None`, SSL requests are declined.
    pub tls_config: Option<TlsConfig>,
    /// Option keywords whose values are redacted when SQL is recorded in tracing spans.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Memory accounting used to throttle oversized or excessive incoming messages.
    pub message_memory_manager: MessageMemoryManagerRef,
}
216
217impl<S, SM> PgProtocol<S, SM>
218where
219 S: PgByteStream,
220 SM: SessionManager,
221{
222 pub fn new(
223 stream: S,
224 session_mgr: Arc<SM>,
225 peer_addr: AddressRef,
226 context: ConnectionContext,
227 ) -> Self {
228 let ConnectionContext {
229 tls_config,
230 redact_sql_option_keywords,
231 message_memory_manager,
232 } = context;
233 Self {
234 stream: PgStream::new(stream),
235 is_terminate: false,
236 state: PgProtocolState::Startup,
237 session_mgr,
238 session: None,
239 tls_context: tls_config
240 .as_ref()
241 .and_then(|e| build_ssl_ctx_from_config(e).ok()),
242 tls_config,
243 result_cache: Default::default(),
244 unnamed_prepare_statement: Default::default(),
245 prepare_statement_store: Default::default(),
246 unnamed_portal: Default::default(),
247 portal_store: Default::default(),
248 statement_portal_dependency: Default::default(),
249 ignore_util_sync: false,
250 peer_addr,
251 redact_sql_option_keywords,
252 message_memory_manager,
253 }
254 }
255
256 pub async fn run(&mut self) {
258 let mut notice_fut = None;
259
260 loop {
261 if notice_fut.is_none()
263 && let Some(session) = self.session.clone()
264 {
265 let mut stream = self.stream.clone();
266 notice_fut = Some(Box::pin(async move {
267 loop {
268 let notice = session.next_notice().await;
269 if let Err(e) = stream.write(BeMessage::NoticeResponse(¬ice)).await {
270 tracing::error!(error = %e.as_report(), notice, "failed to send notice");
271 }
272 }
273 }));
274 }
275
276 let process = std::pin::pin!(async {
278 let (msg, _memory_guard) = match self.read_message().await {
279 Ok(msg) => msg,
280 Err(e) => {
281 tracing::error!(error = %e.as_report(), "error when reading message");
282 return true; }
284 };
285 tracing::trace!(?msg, "received message");
286 self.process(msg).await
287 });
288
289 let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
290 tokio::select! {
291 _ = notice_fut => unreachable!(),
292 terminated = process => terminated,
293 }
294 } else {
295 process.await
296 };
297
298 if terminated {
299 break;
300 }
301 }
302 }
303
304 pub async fn process(&mut self, msg: FeMessage) -> bool {
306 self.do_process(msg).await.is_none() || self.is_terminate
307 }
308
309 fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
317 let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
318 return tracing::Span::none();
319 };
320
321 let mode = match msg {
322 FeMessage::Query(_) => "simple query",
323 FeMessage::Parse(_) => "extended query parse",
324 FeMessage::Execute(_) => "extended query execute",
325 _ => return tracing::Span::none(),
326 };
327
328 let mut span = tracing::info_span!(
329 target: PGWIRE_ROOT_SPAN_TARGET,
330 "handle_query",
331 mode,
332 session_id,
333 sql = tracing::field::Empty,
334 );
335 if let Ok(sql) = msg.get_sql()
336 && let Some(sql) = sql
337 {
338 record_sql_in_span(sql, self.redact_sql_option_keywords.clone(), &mut span);
339 }
340 span
341 }
342
343 async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
347 let span = self.root_span_for_msg(&msg);
348 let weak_session = self
349 .session
350 .as_ref()
351 .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);
352
353 let fut = Box::pin(self.do_process_inner(msg));
358
359 let fut = async move {
361 if let Some(session) = weak_session {
362 CURRENT_SESSION.scope(session, fut).await
363 } else {
364 fut.await
365 }
366 };
367
368 let fut = async move {
370 AssertUnwindSafe(fut)
371 .rw_catch_unwind()
372 .await
373 .unwrap_or_else(|payload| {
374 Err(PsqlError::Panic(
375 panic_message::panic_message(&payload).to_owned(),
376 ))
377 })
378 };
379
380 let fut = async move {
382 let period = *SLOW_QUERY_LOG_PERIOD;
383 let mut fut = std::pin::pin!(fut);
384 let mut elapsed = Duration::ZERO;
385
386 loop {
388 match tokio::time::timeout(period, &mut fut).await {
389 Ok(result) => break result,
390 Err(_) => {
391 elapsed += period;
392 tracing::info!(
393 target: PGWIRE_SLOW_QUERY_LOG,
394 elapsed = %format_args!("{}ms", elapsed.as_millis()),
395 "slow query"
396 );
397 }
398 }
399 }
400 };
401
402 let fut = async move {
404 if !tracing::Span::current().is_none() {
405 tracing::info!(
406 target: PGWIRE_QUERY_LOG,
407 status = "started",
408 );
409 }
410
411 let start = Instant::now();
412 let result = fut.await;
413 let elapsed = start.elapsed();
414
415 if let Err(error) = &result {
419 if cfg!(debug_assertions) && !Deployment::current().is_ci() {
420 tracing::error!(error = ?error.as_report(), "error when process message");
426 } else {
427 tracing::error!(error = %error.as_report(), "error when process message");
428 }
429 }
430
431 if !tracing::Span::current().is_none() {
434 tracing::info!(
435 target: PGWIRE_QUERY_LOG,
436 status = if result.is_ok() { "ok" } else { "err" },
437 time = %format_args!("{}ms", elapsed.as_millis()),
438 );
439 }
440
441 result
442 };
443
444 let fut = fut.instrument(span);
446
447 match fut.await {
449 Ok(()) => Some(()),
450 Err(e) => {
451 match e {
452 PsqlError::IoError(io_err) => {
453 if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
454 return None;
455 }
456 }
457
458 PsqlError::SslError(_) => {
459 return None;
462 }
463
464 PsqlError::StartupError(_) | PsqlError::PasswordError => {
465 self.stream
466 .write_no_flush(BeMessage::ErrorResponse {
467 error: &e,
468 pretty: false,
471 severity: Some(Severity::Fatal),
472 })
473 .ok()?;
474 let _ = self.stream.flush().await;
475 return None;
476 }
477
478 PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
479 self.stream
480 .write_no_flush(BeMessage::ErrorResponse {
481 error: &e,
482 pretty: true,
483 severity: None,
484 })
485 .ok()?;
486 self.ready_for_query().ok()?;
487 }
488
489 PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
490 self.stream
491 .write_no_flush(BeMessage::ErrorResponse {
492 error: &e,
493 pretty: true,
494 severity: None,
495 })
496 .ok()?;
497 let _ = self.stream.flush().await;
498
499 return None;
504 }
505
506 PsqlError::Uncategorized(_)
507 | PsqlError::ExtendedPrepareError(_)
508 | PsqlError::ExtendedExecuteError(_) => {
509 self.stream
510 .write_no_flush(BeMessage::ErrorResponse {
511 error: &e,
512 pretty: true,
513 severity: None,
514 })
515 .ok()?;
516 }
517 }
518 let _ = self.stream.flush().await;
519 Some(())
520 }
521 }
522 }
523
524 async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
525 if self.ignore_util_sync {
527 if let FeMessage::Sync = msg {
528 } else {
529 tracing::trace!("ignore message {:?} until sync.", msg);
530 return Ok(());
531 }
532 }
533
534 match msg {
535 FeMessage::Gss => self.process_gss_msg().await?,
536 FeMessage::Ssl => self.process_ssl_msg().await?,
537 FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
538 FeMessage::Password(msg) => self.process_password_msg(msg).await?,
539 FeMessage::Query(query_msg) => {
540 let sql = Arc::from(query_msg.get_sql()?);
541 drop(query_msg);
543 self.process_query_msg(sql).await?
544 }
545 FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
546 FeMessage::Terminate => self.process_terminate(),
547 FeMessage::Parse(m) => {
548 if let Err(err) = self.process_parse_msg(m).await {
549 self.ignore_util_sync = true;
550 return Err(err);
551 }
552 }
553 FeMessage::Bind(m) => {
554 if let Err(err) = self.process_bind_msg(m) {
555 self.ignore_util_sync = true;
556 return Err(err);
557 }
558 }
559 FeMessage::Execute(m) => {
560 if let Err(err) = self.process_execute_msg(m).await {
561 self.ignore_util_sync = true;
562 return Err(err);
563 }
564 }
565 FeMessage::Describe(m) => {
566 if let Err(err) = self.process_describe_msg(m) {
567 self.ignore_util_sync = true;
568 return Err(err);
569 }
570 }
571 FeMessage::Sync => {
572 self.ignore_util_sync = false;
573 self.ready_for_query()?
574 }
575 FeMessage::Close(m) => {
576 if let Err(err) = self.process_close_msg(m) {
577 self.ignore_util_sync = true;
578 return Err(err);
579 }
580 }
581 FeMessage::Flush => {
582 if let Err(err) = self.stream.flush().await {
583 self.ignore_util_sync = true;
584 return Err(err.into());
585 }
586 }
587 FeMessage::HealthCheck => self.process_health_check(),
588 FeMessage::ServerThrottle(reason) => match reason {
589 ServerThrottleReason::TooLargeMessage => {
590 return Err(PsqlError::ServerThrottle(format!(
591 "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
592 self.message_memory_manager.max_filter_bytes
593 )));
594 }
595 ServerThrottleReason::TooManyMemoryUsage => {
596 return Err(PsqlError::ServerThrottle(format!(
597 "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
598 self.message_memory_manager.max_running_bytes
599 )));
600 }
601 },
602 }
603 self.stream.flush().await?;
604 Ok(())
605 }
606
607 pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
608 match self.state {
609 PgProtocolState::Startup => self
610 .stream
611 .read_startup()
612 .await
613 .map(|message: FeMessage| (message, None)),
614 PgProtocolState::Regular => {
615 self.stream.read_header().await?;
616 let guard = if let Some(ref header) = self.stream.read_header {
617 let payload_len = std::cmp::max(header.payload_len, 0) as u64;
618 let (reason, guard) = self.message_memory_manager.add(payload_len);
619 if let Some(reason) = reason {
620 drop(guard);
622 self.stream.skip_body().await?;
623 return Ok((FeMessage::ServerThrottle(reason), None));
624 }
625 guard
626 } else {
627 None
628 };
629 let message = self.stream.read_body().await?;
630 Ok((message, guard))
631 }
632 }
633 }
634
635 fn ready_for_query(&mut self) -> io::Result<()> {
637 self.stream.write_no_flush(BeMessage::ReadyForQuery(
638 self.session
639 .as_ref()
640 .map(|s| s.transaction_status())
641 .unwrap_or(TransactionStatus::Idle),
642 ))
643 }
644
645 async fn process_gss_msg(&mut self) -> PsqlResult<()> {
646 self.stream.write(BeMessage::EncryptionResponseNo).await?;
648 Ok(())
649 }
650
651 async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
652 if let Some(context) = self.tls_context.as_ref() {
653 self.stream.write(BeMessage::EncryptionResponseSsl).await?;
656 self.stream.upgrade_to_ssl(context).await?;
657 } else {
658 self.stream.write(BeMessage::EncryptionResponseNo).await?;
660 }
661
662 Ok(())
663 }
664
665 async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
666 if let Some(ref tls_config) = self.tls_config
668 && tls_config.enforce_ssl
669 && !self.stream.is_ssl_connection().await
670 {
671 return Err(PsqlError::StartupError(
672 "SSL connection is required but not established".into(),
673 ));
674 }
675
676 let db_name = msg
677 .config
678 .get("database")
679 .cloned()
680 .unwrap_or_else(|| "dev".to_owned());
681 let user_name = msg
682 .config
683 .get("user")
684 .cloned()
685 .unwrap_or_else(|| "root".to_owned());
686
687 let session = self
688 .session_mgr
689 .connect(&db_name, &user_name, self.peer_addr.clone())
690 .map_err(|e| PsqlError::StartupError(e.into()))?;
691
692 if let Some(options) = msg.config.get("options") {
693 for (key, value) in parse_options(options)? {
694 session
695 .set_config(&key, value)
696 .map_err(|e| PsqlError::StartupError(e.into()))?;
697 }
698 }
699 let application_name = msg.config.get("application_name");
701 if let Some(application_name) = application_name {
702 session
703 .set_config("application_name", application_name.clone())
704 .map_err(|e| PsqlError::StartupError(e.into()))?;
705 }
706
707 match session.user_authenticator() {
708 UserAuthenticator::None => {
709 self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
710
711 self.stream
714 .write_no_flush(BeMessage::BackendKeyData(session.id()))?;
715
716 self.stream.write_no_flush(BeMessage::ParameterStatus(
717 BeParameterStatusMessage::TimeZone(
718 &session
719 .get_config("timezone")
720 .map_err(|e| PsqlError::StartupError(e.into()))?,
721 ),
722 ))?;
723 self.stream
724 .write_parameter_status_msg_no_flush(&ParameterStatus {
725 application_name: application_name.cloned(),
726 })?;
727 self.ready_for_query()?;
728 }
729 UserAuthenticator::ClearText(_)
730 | UserAuthenticator::OAuth { .. }
731 | UserAuthenticator::Ldap(..) => {
732 self.stream
733 .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
734 }
735 UserAuthenticator::Md5WithSalt { salt, .. } => {
736 self.stream
737 .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
738 }
739 }
740
741 self.session = Some(session);
742 self.state = PgProtocolState::Regular;
743 Ok(())
744 }
745
746 async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
747 let session = self.session.as_ref().unwrap();
748 let authenticator = session.user_authenticator();
749 authenticator.authenticate(&msg.password).await?;
750 self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
751 let timezone = session
752 .get_config("timezone")
753 .map_err(|e| PsqlError::StartupError(e.into()))?;
754 self.stream.write_no_flush(BeMessage::ParameterStatus(
755 BeParameterStatusMessage::TimeZone(&timezone),
756 ))?;
757 self.stream
758 .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
759 self.ready_for_query()?;
760 self.state = PgProtocolState::Regular;
761 Ok(())
762 }
763
764 fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
765 let session_id = (m.target_process_id, m.target_secret_key);
766 tracing::trace!("cancel query in session: {:?}", session_id);
767 self.session_mgr.cancel_queries_in_session(session_id);
768 self.session_mgr.cancel_creating_jobs_in_session(session_id);
769 self.is_terminate = true;
770 Ok(())
771 }
772
773 async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
774 let truncated_sql =
775 record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
776 let session = self.session.clone().unwrap();
777
778 session.check_idle_in_transaction_timeout()?;
779 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
781 self.inner_process_query_msg(sql, session.clone()).await
782 }
783
784 async fn inner_process_query_msg(
785 &mut self,
786 sql: Arc<str>,
787 session: Arc<SM::Session>,
788 ) -> PsqlResult<()> {
789 let stmts =
791 Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
792 drop(sql);
794 if stmts.is_empty() {
795 self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
796 }
797
798 for stmt in stmts {
800 self.inner_process_query_msg_one_stmt(stmt, session.clone())
801 .await?;
802 }
803 self.ready_for_query()?;
806 Ok(())
807 }
808
809 async fn inner_process_query_msg_one_stmt(
810 &mut self,
811 stmt: Statement,
812 session: Arc<SM::Session>,
813 ) -> PsqlResult<()> {
814 let session = session.clone();
815
816 let res = session.clone().run_one_query(stmt, Format::Text).await;
818
819 while let Some(notice) = session.next_notice().now_or_never() {
821 self.stream
822 .write_no_flush(BeMessage::NoticeResponse(¬ice))?;
823 }
824
825 let mut res = res.map_err(|e| PsqlError::SimpleQueryError(e.into()))?;
826
827 for notice in res.notices() {
828 self.stream
829 .write_no_flush(BeMessage::NoticeResponse(notice))?;
830 }
831
832 let status = res.status();
833 if let Some(ref application_name) = status.application_name {
834 self.stream.write_no_flush(BeMessage::ParameterStatus(
835 BeParameterStatusMessage::ApplicationName(application_name),
836 ))?;
837 }
838
839 if res.is_copy_query_to_stdout() {
840 self.stream
841 .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
842 let mut count = 0;
843 while let Some(row_set) = res.values_stream().next().await {
844 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
845 for row in row_set {
846 self.stream.write_no_flush(BeMessage::CopyData(&row))?;
847 count += 1;
848 }
849 }
850
851 self.stream.write_no_flush(BeMessage::CopyDone)?;
852
853 res.run_callback().await?;
855
856 self.stream
857 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
858 stmt_type: res.stmt_type(),
859 rows_cnt: count,
860 }))?;
861 } else if res.is_query() {
862 self.stream
863 .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;
864
865 let mut rows_cnt = 0;
866
867 while let Some(row_set) = res.values_stream().next().await {
868 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
869 for row in row_set {
870 self.stream.write_no_flush(BeMessage::DataRow(&row))?;
871 rows_cnt += 1;
872 }
873 }
874
875 res.run_callback().await?;
877
878 self.stream
879 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
880 stmt_type: res.stmt_type(),
881 rows_cnt,
882 }))?;
883 } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
884 let first_row_set = res.values_stream().next().await;
885 let first_row_set = match first_row_set {
886 None => {
887 return Err(PsqlError::Uncategorized(
888 anyhow::anyhow!("no affected rows in output").into(),
889 ));
890 }
891 Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
892 };
893 let affected_rows_str = first_row_set[0].values()[0]
894 .as_ref()
895 .expect("compute node should return affected rows in output");
896
897 assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
898 let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
899 .unwrap()
900 .parse()
901 .unwrap_or_default();
902
903 res.run_callback().await?;
905
906 self.stream
907 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
908 stmt_type: res.stmt_type(),
909 rows_cnt: affected_rows_cnt,
910 }))?;
911 } else {
912 res.run_callback().await?;
914
915 self.stream
916 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
917 stmt_type: res.stmt_type(),
918 rows_cnt: 0,
919 }))?;
920 }
921
922 Ok(())
923 }
924
925 fn process_terminate(&mut self) {
926 self.is_terminate = true;
927 }
928
929 fn process_health_check(&mut self) {
930 tracing::debug!("health check");
931 self.is_terminate = true;
932 }
933
934 async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
935 let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
936 record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
937 let session = self.session.clone().unwrap();
938 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
939 let type_ids = std::mem::take(&mut msg.type_ids);
940 drop(msg);
942 self.inner_process_parse_msg(session, sql, statement_name, type_ids)
943 .await?;
944 Ok(())
945 }
946
947 async fn inner_process_parse_msg(
948 &mut self,
949 session: Arc<SM::Session>,
950 sql: Arc<str>,
951 statement_name: String,
952 type_ids: Vec<i32>,
953 ) -> PsqlResult<()> {
954 if statement_name.is_empty() {
955 self.unnamed_prepare_statement.take();
958 } else if self.prepare_statement_store.contains_key(&statement_name) {
959 return Err(PsqlError::ExtendedPrepareError(
960 "Duplicated statement name".into(),
961 ));
962 }
963
964 let stmt = {
965 let stmts = Parser::parse_sql(&sql)
966 .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
967 drop(sql);
968 if stmts.len() > 1 {
969 return Err(PsqlError::ExtendedPrepareError(
970 "Only one statement is allowed in extended query mode".into(),
971 ));
972 }
973
974 stmts.into_iter().next()
975 };
976
977 let param_types: Vec<Option<DataType>> = type_ids
978 .iter()
979 .map(|&id| {
980 if id == 0 {
983 Ok(None)
984 } else {
985 DataType::from_oid(id)
986 .map(Some)
987 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
988 }
989 })
990 .try_collect()?;
991
992 let prepare_statement = session
993 .parse(stmt, param_types)
994 .await
995 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))?;
996
997 if statement_name.is_empty() {
998 self.unnamed_prepare_statement.replace(prepare_statement);
999 } else {
1000 self.prepare_statement_store
1001 .insert(statement_name.clone(), prepare_statement);
1002 }
1003
1004 self.statement_portal_dependency
1005 .entry(statement_name)
1006 .or_default()
1007 .clear();
1008
1009 self.stream.write_no_flush(BeMessage::ParseComplete)?;
1010 Ok(())
1011 }
1012
1013 fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
1014 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
1015 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1016 let session = self.session.clone().unwrap();
1017
1018 if self.portal_store.contains_key(&portal_name) {
1019 return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
1020 }
1021
1022 let prepare_statement = self.get_statement(&statement_name)?;
1023
1024 let result_formats = msg
1025 .result_format_codes
1026 .iter()
1027 .map(|&format_code| Format::from_i16(format_code))
1028 .try_collect()?;
1029 let param_formats = msg
1030 .param_format_codes
1031 .iter()
1032 .map(|&format_code| Format::from_i16(format_code))
1033 .try_collect()?;
1034
1035 let portal = session
1036 .bind(prepare_statement, msg.params, param_formats, result_formats)
1037 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1038
1039 if portal_name.is_empty() {
1040 self.result_cache.remove(&portal_name);
1041 self.unnamed_portal.replace(portal);
1042 } else {
1043 assert!(
1044 !self.result_cache.contains_key(&portal_name),
1045 "Named portal never can be overridden."
1046 );
1047 self.portal_store.insert(portal_name.clone(), portal);
1048 }
1049
1050 self.statement_portal_dependency
1051 .get_mut(&statement_name)
1052 .unwrap()
1053 .push(portal_name);
1054
1055 self.stream.write_no_flush(BeMessage::BindComplete)?;
1056 Ok(())
1057 }
1058
1059 async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1060 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1061 let row_max = msg.max_rows as usize;
1062 drop(msg);
1063 let session = self.session.clone().unwrap();
1064
1065 match self.result_cache.remove(&portal_name) {
1066 Some(mut result_cache) => {
1067 assert!(self.portal_store.contains_key(&portal_name));
1068
1069 let is_consume_completed =
1070 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1071
1072 if !is_consume_completed {
1073 self.result_cache.insert(portal_name, result_cache);
1074 }
1075 }
1076 _ => {
1077 let portal = self.get_portal(&portal_name)?;
1078 let sql = format!("{}", portal);
1079 let truncated_sql =
1080 record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
1081 drop(sql);
1082
1083 session.check_idle_in_transaction_timeout()?;
1084 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1086 let result = session.clone().execute(portal).await;
1087
1088 let pg_response = result.map_err(|e| PsqlError::ExtendedExecuteError(e.into()))?;
1089 let mut result_cache = ResultCache::new(pg_response);
1090 let is_consume_completed =
1091 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1092 if !is_consume_completed {
1093 self.result_cache.insert(portal_name, result_cache);
1094 }
1095 }
1096 }
1097
1098 Ok(())
1099 }
1100
1101 fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1102 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1103 let session = self.session.clone().unwrap();
1104 assert!(msg.kind == b'S' || msg.kind == b'P');
1108 if msg.kind == b'S' {
1109 let prepare_statement = self.get_statement(&name)?;
1110
1111 let (param_types, row_descriptions) = self
1112 .session
1113 .clone()
1114 .unwrap()
1115 .describe_statement(prepare_statement)
1116 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1117 self.stream.write_no_flush(BeMessage::ParameterDescription(
1118 ¶m_types.iter().map(|t| t.to_oid()).collect_vec(),
1119 ))?;
1120
1121 if row_descriptions.is_empty() {
1122 self.stream.write_no_flush(BeMessage::NoData)?;
1125 } else {
1126 self.stream
1127 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1128 }
1129 } else if msg.kind == b'P' {
1130 let portal = self.get_portal(&name)?;
1131
1132 let row_descriptions = session
1133 .describe_portal(portal)
1134 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1135
1136 if row_descriptions.is_empty() {
1137 self.stream.write_no_flush(BeMessage::NoData)?;
1140 } else {
1141 self.stream
1142 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1143 }
1144 }
1145 Ok(())
1146 }
1147
1148 fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1149 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1150 assert!(msg.kind == b'S' || msg.kind == b'P');
1151 if msg.kind == b'S' {
1152 if name.is_empty() {
1153 self.unnamed_prepare_statement = None;
1154 } else {
1155 self.prepare_statement_store.remove(&name);
1156 }
1157 for portal_name in self
1158 .statement_portal_dependency
1159 .remove(&name)
1160 .unwrap_or_default()
1161 {
1162 self.remove_portal(&portal_name);
1163 }
1164 } else if msg.kind == b'P' {
1165 self.remove_portal(&name);
1166 }
1167 self.stream.write_no_flush(BeMessage::CloseComplete)?;
1168 Ok(())
1169 }
1170
1171 fn remove_portal(&mut self, portal_name: &str) {
1172 if portal_name.is_empty() {
1173 self.unnamed_portal = None;
1174 } else {
1175 self.portal_store.remove(portal_name);
1176 }
1177 self.result_cache.remove(portal_name);
1178 }
1179
1180 fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1181 if portal_name.is_empty() {
1182 Ok(self
1183 .unnamed_portal
1184 .as_ref()
1185 .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1186 .clone())
1187 } else {
1188 Ok(self
1189 .portal_store
1190 .get(portal_name)
1191 .ok_or_else(|| {
1192 PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1193 })?
1194 .clone())
1195 }
1196 }
1197
1198 fn get_statement(
1199 &self,
1200 statement_name: &str,
1201 ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1202 if statement_name.is_empty() {
1203 Ok(self
1204 .unnamed_prepare_statement
1205 .as_ref()
1206 .ok_or_else(|| {
1207 PsqlError::Uncategorized("unnamed prepare statement not found".into())
1208 })?
1209 .clone())
1210 } else {
1211 Ok(self
1212 .prepare_statement_store
1213 .get(statement_name)
1214 .ok_or_else(|| {
1215 PsqlError::Uncategorized(
1216 format!("Prepare statement {} not found", statement_name).into(),
1217 )
1218 })?
1219 .clone())
1220 }
1221 }
1222}
1223
/// The transport underneath a `PgStream`.
enum PgStreamInner<S> {
    /// Transient variant holding no transport; the read paths treat encountering it as
    /// unreachable.
    /// NOTE(review): presumably swapped in while the stream is upgraded to SSL — confirm
    /// against `upgrade_to_ssl`.
    Placeholder,
    /// Plain, unencrypted transport.
    Unencrypted(S),
    /// Transport wrapped in an OpenSSL (`tokio_openssl`) stream.
    Ssl(SslStream<S>),
}
1232
1233pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1235impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1236
/// Message-level wrapper over a client transport.
pub struct PgStream<S> {
    /// Shared transport. Clones of a `PgStream` share it, so, for example, the notice
    /// forwarder started in `PgProtocol::run` writes to the same connection.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// Buffer for outgoing messages; each clone gets its own empty buffer.
    write_buf: BytesMut,
    /// Header read by `read_header`, consumed (via `take`) by `read_body`.
    read_header: Option<FeMessageHeader>,
}
1248
1249impl<S> PgStream<S> {
1250 pub fn new(stream: S) -> Self {
1252 const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1253
1254 Self {
1255 stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1256 write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1257 read_header: None,
1258 }
1259 }
1260
1261 async fn is_ssl_connection(&self) -> bool {
1263 let stream = self.stream.lock().await;
1264 matches!(*stream, PgStreamInner::Ssl(_))
1265 }
1266}
1267
1268impl<S> Clone for PgStream<S> {
1269 fn clone(&self) -> Self {
1270 Self {
1271 stream: Arc::clone(&self.stream),
1272 write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1273 read_header: self.read_header.clone(),
1274 }
1275 }
1276}
1277
/// Per-connection values reported to the client in `ParameterStatus` messages.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// The `application_name` to report. When `Some`, `write_parameter_status_msg_no_flush`
    /// emits an extra `ApplicationName` parameter status message.
    pub application_name: Option<String>,
}
1298
impl<S> PgStream<S>
where
    S: PgByteStream,
{
    /// Reads the client's startup message from the underlying stream.
    ///
    /// Holds the stream lock for the whole read. Panics (`unreachable!`) if the stream is in the
    /// transient `Placeholder` state.
    async fn read_startup(&mut self) -> io::Result<FeMessage> {
        let mut stream = self.stream.lock().await;
        match &mut *stream {
            PgStreamInner::Placeholder => unreachable!(),
            PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
            PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
        }
    }

    /// Reads the next message header and stores it in `self.read_header`.
    ///
    /// The body must then be consumed with [`Self::read_body`] or [`Self::skip_body`].
    async fn read_header(&mut self) -> io::Result<()> {
        let mut stream = self.stream.lock().await;
        match &mut *stream {
            PgStreamInner::Placeholder => unreachable!(),
            PgStreamInner::Unencrypted(stream) => {
                self.read_header = Some(FeMessage::read_header(stream).await?);
                Ok(())
            }
            PgStreamInner::Ssl(ssl_stream) => {
                self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
                Ok(())
            }
        }
    }

    /// Reads the body of the message whose header was stored by [`Self::read_header`].
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidInput` ("header not found") if no header is pending. Any error
    /// from the underlying read is returned as-is.
    async fn read_body(&mut self) -> io::Result<FeMessage> {
        // The lock is acquired before the pending header is taken. If this future is dropped
        // while waiting on the lock, `read_header` is therefore left intact.
        let mut stream = self.stream.lock().await;
        let header = self
            .read_header
            .take()
            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
        match &mut *stream {
            PgStreamInner::Placeholder => unreachable!(),
            PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
            PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
        }
    }

    /// Discards the body of the message whose header was stored by [`Self::read_header`].
    ///
    /// # Errors
    ///
    /// Same as [`Self::read_body`]: `ErrorKind::InvalidInput` if no header is pending, or the
    /// underlying read error.
    async fn skip_body(&mut self) -> io::Result<()> {
        // Same lock-then-take ordering as `read_body`.
        let mut stream = self.stream.lock().await;
        let header = self
            .read_header
            .take()
            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
        match &mut *stream {
            PgStreamInner::Placeholder => unreachable!(),
            PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
            PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
        }
    }

    /// Buffers `ParameterStatus` messages without flushing.
    ///
    /// Always buffers the client encoding, standard-conforming-strings and server version
    /// parameters. The application name is buffered only when `status.application_name` is set.
    fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
        self.write_no_flush(BeMessage::ParameterStatus(
            BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
        ))?;
        self.write_no_flush(BeMessage::ParameterStatus(
            BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
        ))?;
        self.write_no_flush(BeMessage::ParameterStatus(
            BeParameterStatusMessage::ServerVersion(PG_VERSION),
        ))?;
        if let Some(application_name) = &status.application_name {
            self.write_no_flush(BeMessage::ParameterStatus(
                BeParameterStatusMessage::ApplicationName(application_name),
            ))?;
        }
        Ok(())
    }

    /// Encodes `message` into the write buffer without touching the underlying stream.
    ///
    /// Call [`Self::flush`] (or use [`Self::write`]) to actually send it.
    pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
        BeMessage::write(&mut self.write_buf, message)
    }

    /// Encodes `message` and immediately flushes everything buffered so far.
    async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
        self.write_no_flush(message)?;
        self.flush().await?;
        Ok(())
    }

    /// Writes the whole write buffer to the underlying stream and flushes it.
    ///
    /// The buffer is cleared only after a successful write and flush. On error its contents are
    /// retained.
    async fn flush(&mut self) -> io::Result<()> {
        let mut stream = self.stream.lock().await;
        match &mut *stream {
            PgStreamInner::Placeholder => unreachable!(),
            PgStreamInner::Unencrypted(stream) => {
                stream.write_all(&self.write_buf).await?;
                stream.flush().await?;
            }
            PgStreamInner::Ssl(ssl_stream) => {
                ssl_stream.write_all(&self.write_buf).await?;
                ssl_stream.flush().await?;
            }
        }
        self.write_buf.clear();
        Ok(())
    }
}
1398
1399impl<S> PgStream<S>
1400where
1401 S: PgByteStream,
1402{
1403 async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1405 let mut stream = self.stream.lock().await;
1406
1407 match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1408 PgStreamInner::Unencrypted(unencrypted_stream) => {
1409 let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1410 let mut ssl_stream =
1411 tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1412
1413 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1414 tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1415 let _ = ssl_stream.shutdown().await;
1416 return Err(e.into());
1417 }
1418
1419 *stream = PgStreamInner::Ssl(ssl_stream);
1420 }
1421 PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1422 PgStreamInner::Placeholder => unreachable!(),
1423 }
1424
1425 Ok(())
1426 }
1427}
1428
1429fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1430 let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1431
1432 let key_path = &tls_config.key;
1433 let cert_path = &tls_config.cert;
1434
1435 acceptor
1438 .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1439 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1440 acceptor
1441 .set_ca_file(cert_path)
1442 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1443 acceptor
1444 .set_certificate_chain_file(cert_path)
1445 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1446 let acceptor = acceptor.build();
1447
1448 Ok(acceptor.into_context())
1449}
1450
/// Formatting adapters that cap the length of rendered output.
pub mod truncated_fmt {
    use std::fmt::*;

    /// A [`Write`] sink that forwards at most `remaining` bytes to `f`.
    ///
    /// When a chunk overflows the budget, it writes the part that fits (cut on a UTF-8 boundary)
    /// followed by a `...(truncated,N)` marker, and drops every later write. `N` is the byte
    /// length of the overflowing chunk, not of the whole rendered value.
    struct TruncatedFormatter<'a, 'b> {
        remaining: usize,
        finished: bool,
        f: &'a mut Formatter<'b>,
    }

    impl Write for TruncatedFormatter<'_, '_> {
        fn write_str(&mut self, chunk: &str) -> Result {
            if self.finished {
                return Ok(());
            }

            // The whole chunk fits in what is left of the budget.
            if chunk.len() <= self.remaining {
                self.f.write_str(chunk)?;
                self.remaining -= chunk.len();
                return Ok(());
            }

            // Cut at the last character boundary inside the budget so the output stays valid
            // UTF-8. Index 0 is always a boundary, so the search never comes up empty.
            let cut = (0..=self.remaining)
                .rev()
                .find(|&idx| chunk.is_char_boundary(idx))
                .unwrap_or(0);
            self.f.write_str(&chunk[..cut])?;
            self.remaining -= cut;
            write!(self.f, "...(truncated,{})", chunk.len())?;
            self.finished = true;
            Ok(())
        }
    }

    /// Renders `args` into `f`, keeping at most `limit` bytes before the truncation marker.
    fn write_truncated(f: &mut Formatter<'_>, limit: usize, args: Arguments<'_>) -> Result {
        let mut sink = TruncatedFormatter {
            remaining: limit,
            finished: false,
            f,
        };
        sink.write_fmt(args)
    }

    /// Pairs a value with a byte limit.
    ///
    /// Formatting it with `{}` or `{:?}` renders the value's `Display` / `Debug` output cut off
    /// after `.1` bytes, followed by a `...(truncated,N)` marker.
    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);

    impl<T: Debug> Debug for TruncatedFmt<'_, T> {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            let Self(value, limit) = *self;
            write_truncated(f, limit, format_args!("{value:?}"))
        }
    }

    impl<T: Display> Display for TruncatedFmt<'_, T> {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            let Self(value, limit) = *self;
            write_truncated(f, limit, format_args!("{value}"))
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            assert_eq!(
                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
                "select '...(truncated,14)",
            );
        }
    }
}
1522
1523fn parse_options(options: &str) -> PsqlResult<Vec<(String, String)>> {
1536 let mut args = Vec::new();
1537 let mut current_arg = String::new();
1538 let mut chars = options.chars().peekable();
1539
1540 while let Some(c) = chars.next() {
1541 if c == '\\' {
1542 if let Some(next_c) = chars.next() {
1543 current_arg.push(next_c);
1544 }
1545 } else if c.is_ascii_whitespace() {
1546 if !current_arg.is_empty() {
1547 args.push(std::mem::take(&mut current_arg));
1548 }
1549 } else {
1550 current_arg.push(c);
1551 }
1552 }
1553 if !current_arg.is_empty() {
1554 args.push(current_arg);
1555 }
1556
1557 let mut args_iter = args.into_iter();
1558 let mut config = Vec::new();
1559
1560 while let Some(arg) = args_iter.next() {
1561 if arg == "-c" {
1562 if let Some(config_str) = args_iter.next() {
1563 if let Some((key, value)) = config_str.split_once('=') {
1564 let key = key.replace("-", "_");
1565 config.push((key, value.to_owned()));
1566 } else {
1567 return Err(PsqlError::StartupError(
1568 format!("invalid config format: {}", config_str).into(),
1569 ));
1570 }
1571 } else {
1572 return Err(PsqlError::StartupError("missing argument for -c".into()));
1573 }
1574 } else if let Some(config_str) = arg.strip_prefix("--") {
1575 if let Some((key, value)) = config_str.split_once('=') {
1576 let key = key.replace("-", "_");
1577 config.push((key, value.to_owned()));
1578 } else {
1579 return Err(PsqlError::StartupError(
1580 format!("invalid config format: {}", config_str).into(),
1581 ));
1582 }
1583 } else {
1584 tracing::warn!(
1585 arg,
1586 "ignoring unrecognized option for backward compatibility"
1587 );
1588 }
1589 }
1590 Ok(config)
1591}
1592
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    #[test]
    fn test_redact_parsable_sql() {
        let keywords = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()]));
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])";
        assert_eq!(redact_sql(sql, keywords), expected);
    }

    /// Converts borrowed `(key, value)` pairs into the owned form `parse_options` returns.
    fn pairs(items: &[(&str, &str)]) -> Vec<(String, String)> {
        items
            .iter()
            .map(|&(k, v)| (k.to_owned(), v.to_owned()))
            .collect()
    }

    #[test]
    fn test_parse_options() {
        let accepted: &[(&str, &[(&str, &str)])] = &[
            ("", &[]),
            ("-c a=1 -c b=2", &[("a", "1"), ("b", "2")]),
            ("-c key=value", &[("key", "value")]),
            // Quotes have no special meaning and are kept verbatim.
            ("-c key='value'", &[("key", "'value'")]),
            // A backslash escapes the following character, including whitespace.
            (r#"-c key=value\ with\ spaces"#, &[("key", "value with spaces")]),
            (r#"-c search_path=my\ schema"#, &[("search_path", "my schema")]),
            ("--foo=bar", &[("foo", "bar")]),
            (r#"--foo=bar\ baz"#, &[("foo", "bar baz")]),
            ("-c a=1 --b=2", &[("a", "1"), ("b", "2")]),
            // A trailing lone backslash is dropped.
            (r#"-c a=b\"#, &[("a", "b")]),
        ];
        for &(input, expected) in accepted {
            assert_eq!(
                parse_options(input).unwrap(),
                pairs(expected),
                "input: {input:?}"
            );
        }

        // Missing `-c` argument, or a setting without `=`.
        for input in ["-c", "-c foo", "--foo"] {
            assert!(parse_options(input).is_err(), "input: {input:?}");
        }
    }
}