1use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::panic::FutureCatchUnwindExt;
33use risingwave_common::util::query_log::*;
34use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
35use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
36use risingwave_sqlparser::parser::Parser;
37use thiserror_ext::AsReport;
38use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
39use tokio::sync::Mutex;
40use tokio_openssl::SslStream;
41use tracing::Instrument;
42
43use crate::error::{PsqlError, PsqlResult};
44use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
45use crate::net::AddressRef;
46use crate::pg_extended::ResultCache;
47use crate::pg_message::{
48 BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
49 FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
50 FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
51};
52use crate::pg_server::{Session, SessionManager, UserAuthenticator};
53use crate::types::Format;
54
/// Maximum number of characters of SQL text recorded in query-log spans.
///
/// Read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. If the variable is
/// unset or not a valid `usize`, defaults to 65536 in debug builds and 1024 otherwise.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    let fallback = if cfg!(debug_assertions) { 65536 } else { 1024 };
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok())
        .unwrap_or(fallback)
});
68
tokio::task_local! {
    /// Weak, type-erased handle to the session whose message is currently being processed.
    ///
    /// `do_process` sets it with `CURRENT_SESSION.scope(..)` for the duration of one message
    /// (only when a session exists). It is held as `Weak`, so reading it does not keep the
    /// session alive. It is erased to `dyn Any`, presumably so code that cannot name the
    /// concrete session type can still reach it — TODO confirm with callers.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
73
/// Server side of a single PostgreSQL wire-protocol connection.
///
/// Reads frontend messages from `stream`, dispatches them to the session obtained from
/// `session_mgr`, and writes backend messages back to the client.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Connection to the client, plain or upgraded to TLS.
    stream: PgStream<S>,
    /// Whether the startup message has been handled; selects how the next message is read.
    state: PgProtocolState,
    /// Set by `Terminate` or a health check; `run` stops after the current message.
    is_terminate: bool,

    /// Creates sessions on startup and ends them when this value is dropped.
    session_mgr: Arc<SM>,
    /// Session created by the startup message; `None` until startup succeeds.
    session: Option<Arc<SM::Session>>,

    /// Partially consumed results of portals executed with a row limit, keyed by portal name.
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    /// Prepared statement registered under the empty name.
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    /// Named prepared statements.
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    /// Portal registered under the empty name.
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    /// Named portals.
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    /// Statement name -> names of portals bound from it, so that closing a statement also
    /// closes its portals.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    /// TLS context built from the connection's `TlsConfig`; `None` means SSL requests are
    /// declined.
    tls_context: Option<SslContext>,

    /// After an extended-query error, every message except `Sync` is skipped while this is set.
    ignore_util_sync: bool,

    /// Address of the client, passed to the session manager on connect.
    peer_addr: AddressRef,

    /// Option keywords whose values are redacted before SQL is recorded in tracing spans.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Memory accounting for incoming message bodies; used to throttle oversized messages.
    message_memory_manager: MessageMemoryManagerRef,
}
113
114#[derive(Debug, Clone)]
116pub struct TlsConfig {
117 pub cert: String,
119 pub key: String,
121}
122
123impl TlsConfig {
124 pub fn new_default() -> Option<Self> {
125 let cert = std::env::var("RW_SSL_CERT").ok()?;
126 let key = std::env::var("RW_SSL_KEY").ok()?;
127 tracing::info!("RW_SSL_CERT={}, RW_SSL_KEY={}", cert, key);
128 Some(Self { cert, key })
129 }
130}
131
132impl<S, SM> Drop for PgProtocol<S, SM>
133where
134 SM: SessionManager,
135{
136 fn drop(&mut self) {
137 if let Some(session) = &self.session {
138 self.session_mgr.end_session(session);
140 }
141 }
142}
143
/// Connection phase; decides how `PgProtocol::read_message` reads the next message.
enum PgProtocolState {
    /// Before the startup message has been handled: messages are read with `read_startup`.
    Startup,
    /// After startup: messages are read as a header followed by a body.
    Regular,
}
149
150pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
154 let without_null = if b.last() == Some(&0) {
155 &b[..b.len() - 1]
156 } else {
157 &b[..]
158 };
159 std::str::from_utf8(without_null)
160}
161
162fn record_sql_in_span(
164 sql: &str,
165 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
166) -> String {
167 let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
168 && !keywords.is_empty()
169 {
170 redact_sql(sql, keywords)
171 } else {
172 sql.to_owned()
173 };
174 let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
175 tracing::Span::current().record("sql", tracing::field::display(&truncated));
176 truncated.to_string()
177}
178
179fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
181 match Parser::parse_sql(sql) {
182 Ok(sqls) => sqls
183 .into_iter()
184 .map(|sql| sql.to_redacted_string(keywords.clone()))
185 .join(";"),
186 Err(_) => sql.to_owned(),
187 }
188}
189
/// Per-connection settings handed to `PgProtocol::new`.
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS configuration; `None` (or a config that fails to build) means SSL requests are
    /// declined.
    pub tls_config: Option<TlsConfig>,
    /// Option keywords whose values are redacted when SQL is recorded in tracing spans.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Memory accounting used by `read_message` to throttle large or excessive messages.
    pub message_memory_manager: MessageMemoryManagerRef,
}
196
197impl<S, SM> PgProtocol<S, SM>
198where
199 S: PgByteStream,
200 SM: SessionManager,
201{
202 pub fn new(
203 stream: S,
204 session_mgr: Arc<SM>,
205 peer_addr: AddressRef,
206 context: ConnectionContext,
207 ) -> Self {
208 let ConnectionContext {
209 tls_config,
210 redact_sql_option_keywords,
211 message_memory_manager,
212 } = context;
213 Self {
214 stream: PgStream::new(stream),
215 is_terminate: false,
216 state: PgProtocolState::Startup,
217 session_mgr,
218 session: None,
219 tls_context: tls_config
220 .as_ref()
221 .and_then(|e| build_ssl_ctx_from_config(e).ok()),
222 result_cache: Default::default(),
223 unnamed_prepare_statement: Default::default(),
224 prepare_statement_store: Default::default(),
225 unnamed_portal: Default::default(),
226 portal_store: Default::default(),
227 statement_portal_dependency: Default::default(),
228 ignore_util_sync: false,
229 peer_addr,
230 redact_sql_option_keywords,
231 message_memory_manager,
232 }
233 }
234
    /// Drives the connection until it should close.
    ///
    /// The loop ends when reading a message fails, when processing reports a fatal error, or
    /// when the client asks to terminate. Each iteration reads one frontend message and
    /// processes it. Once a session exists, a long-lived notice-forwarding future is raced
    /// against message processing. This lets notices that the session emits asynchronously be
    /// written to the client while a message is in flight.
    pub async fn run(&mut self) {
        // Created lazily the first time a session exists, then kept across iterations.
        let mut notice_fut = None;

        loop {
            if notice_fut.is_none()
                && let Some(session) = self.session.clone()
            {
                // The clone shares the underlying socket (behind a mutex) but has its own
                // write buffer.
                let mut stream = self.stream.clone();
                notice_fut = Some(Box::pin(async move {
                    // Never completes: forwards notices for the lifetime of the session.
                    loop {
                        let notice = session.next_notice().await;
                        if let Err(e) = stream.write(&BeMessage::NoticeResponse(&notice)).await {
                            tracing::error!(error = %e.as_report(), notice, "failed to send notice");
                        }
                    }
                }));
            }

            let process = std::pin::pin!(async {
                // The memory guard is kept alive until processing of the message finishes.
                let (msg, _memory_guard) = match self.read_message().await {
                    Ok(msg) => msg,
                    Err(e) => {
                        tracing::error!(error = %e.as_report(), "error when reading message");
                        // Terminate the connection.
                        return true;
                    }
                };
                tracing::trace!(?msg, "received message");
                self.process(msg).await
            });

            // The notice future loops forever, so only `process` can finish the select.
            let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
                tokio::select! {
                    _ = notice_fut => unreachable!(),
                    terminated = process => terminated,
                }
            } else {
                process.await
            };

            if terminated {
                break;
            }
        }
    }
282
283 pub async fn process(&mut self, msg: FeMessage) -> bool {
285 self.do_process(msg).await.is_none() || self.is_terminate
286 }
287
288 fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
296 let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
297 return tracing::Span::none();
298 };
299
300 let mode = match msg {
301 FeMessage::Query(_) => "simple query",
302 FeMessage::Parse(_) => "extended query parse",
303 FeMessage::Execute(_) => "extended query execute",
304 _ => return tracing::Span::none(),
305 };
306
307 tracing::info_span!(
308 target: PGWIRE_ROOT_SPAN_TARGET,
309 "handle_query",
310 mode,
311 session_id,
312 sql = tracing::field::Empty, )
314 }
315
    /// Processes one message and decides whether the connection may continue.
    ///
    /// The inner handler is wrapped, from the inside out, in:
    /// - `CURRENT_SESSION` scoping (when a session exists);
    /// - panic catching (a panic becomes `PsqlError::Panic`);
    /// - periodic slow-query logging;
    /// - error and query-log logging;
    /// - the root tracing span.
    ///
    /// Errors are then reported to the client according to their category.
    ///
    /// Returns `None` when the connection should be closed, `Some(())` otherwise.
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        // Boxed, presumably to keep the large inner future off this future's stack — TODO confirm.
        let fut = Box::pin(self.do_process_inner(msg));

        // Make the session reachable through the task-local while the handler runs.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Turn a panic inside the handler into an error result.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Emit a "slow query" log line every `SLOW_QUERY_LOG_PERIOD` while still running.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            let mut elapsed = Duration::ZERO;

            loop {
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Log errors, and for messages with a real root span, one query-log line.
        let fut = async move {
            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            if let Err(error) = &result {
                // Debug builds outside CI log the report with Debug (`?`) formatting,
                // otherwise with Display (`%`).
                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
                    tracing::error!(error = ?error.as_report(), "error when process message");
                } else {
                    tracing::error!(error = %error.as_report(), "error when process message");
                }
            }

            // Only messages that got a root span from `root_span_for_msg` are query-logged.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        let fut = fut.instrument(span);

        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    PsqlError::IoError(io_err) => {
                        // The client went away: close without replying.
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                        // Other I/O errors fall through to the flush below and keep the
                        // connection open.
                    }

                    PsqlError::SslError(_) => {
                        return None;
                    }

                    // Startup/authentication failures: report, then close.
                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    // These end a simple-query cycle, so the client also needs ReadyForQuery.
                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                        self.ready_for_query().ok()?;
                    }

                    // Fatal for the connection: report, then close.
                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                        let _ = self.stream.flush().await;

                        return None;
                    }

                    // No ReadyForQuery here; in the extended protocol it is sent when the
                    // client's Sync arrives.
                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                    }
                }
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
471
472 async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
473 if self.ignore_util_sync {
475 if let FeMessage::Sync = msg {
476 } else {
477 tracing::trace!("ignore message {:?} until sync.", msg);
478 return Ok(());
479 }
480 }
481
482 match msg {
483 FeMessage::Gss => self.process_gss_msg().await?,
484 FeMessage::Ssl => self.process_ssl_msg().await?,
485 FeMessage::Startup(msg) => self.process_startup_msg(msg)?,
486 FeMessage::Password(msg) => self.process_password_msg(msg).await?,
487 FeMessage::Query(query_msg) => {
488 let sql = Arc::from(query_msg.get_sql()?);
489 drop(query_msg);
491 self.process_query_msg(sql).await?
492 }
493 FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
494 FeMessage::Terminate => self.process_terminate(),
495 FeMessage::Parse(m) => {
496 if let Err(err) = self.process_parse_msg(m).await {
497 self.ignore_util_sync = true;
498 return Err(err);
499 }
500 }
501 FeMessage::Bind(m) => {
502 if let Err(err) = self.process_bind_msg(m) {
503 self.ignore_util_sync = true;
504 return Err(err);
505 }
506 }
507 FeMessage::Execute(m) => {
508 if let Err(err) = self.process_execute_msg(m).await {
509 self.ignore_util_sync = true;
510 return Err(err);
511 }
512 }
513 FeMessage::Describe(m) => {
514 if let Err(err) = self.process_describe_msg(m) {
515 self.ignore_util_sync = true;
516 return Err(err);
517 }
518 }
519 FeMessage::Sync => {
520 self.ignore_util_sync = false;
521 self.ready_for_query()?
522 }
523 FeMessage::Close(m) => {
524 if let Err(err) = self.process_close_msg(m) {
525 self.ignore_util_sync = true;
526 return Err(err);
527 }
528 }
529 FeMessage::Flush => {
530 if let Err(err) = self.stream.flush().await {
531 self.ignore_util_sync = true;
532 return Err(err.into());
533 }
534 }
535 FeMessage::HealthCheck => self.process_health_check(),
536 FeMessage::ServerThrottle(reason) => match reason {
537 ServerThrottleReason::TooLargeMessage => {
538 return Err(PsqlError::ServerThrottle(format!(
539 "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
540 self.message_memory_manager.max_filter_bytes
541 )));
542 }
543 ServerThrottleReason::TooManyMemoryUsage => {
544 return Err(PsqlError::ServerThrottle(format!(
545 "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
546 self.message_memory_manager.max_running_bytes
547 )));
548 }
549 },
550 }
551 self.stream.flush().await?;
552 Ok(())
553 }
554
555 pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
556 match self.state {
557 PgProtocolState::Startup => self
558 .stream
559 .read_startup()
560 .await
561 .map(|message: FeMessage| (message, None)),
562 PgProtocolState::Regular => {
563 self.stream.read_header().await?;
564 let guard = if let Some(ref header) = self.stream.read_header {
565 let payload_len = std::cmp::max(header.payload_len, 0) as u64;
566 let (reason, guard) = self.message_memory_manager.add(payload_len);
567 if let Some(reason) = reason {
568 drop(guard);
570 self.stream.skip_body().await?;
571 return Ok((FeMessage::ServerThrottle(reason), None));
572 }
573 guard
574 } else {
575 None
576 };
577 let message = self.stream.read_body().await?;
578 Ok((message, guard))
579 }
580 }
581 }
582
583 fn ready_for_query(&mut self) -> io::Result<()> {
585 self.stream.write_no_flush(&BeMessage::ReadyForQuery(
586 self.session
587 .as_ref()
588 .map(|s| s.transaction_status())
589 .unwrap_or(TransactionStatus::Idle),
590 ))
591 }
592
593 async fn process_gss_msg(&mut self) -> PsqlResult<()> {
594 self.stream.write(&BeMessage::EncryptionResponseNo).await?;
596 Ok(())
597 }
598
599 async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
600 if let Some(context) = self.tls_context.as_ref() {
601 self.stream.write(&BeMessage::EncryptionResponseSsl).await?;
604 self.stream.upgrade_to_ssl(context).await?;
605 } else {
606 self.stream.write(&BeMessage::EncryptionResponseNo).await?;
608 }
609
610 Ok(())
611 }
612
613 fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
614 let db_name = msg
615 .config
616 .get("database")
617 .cloned()
618 .unwrap_or_else(|| "dev".to_owned());
619 let user_name = msg
620 .config
621 .get("user")
622 .cloned()
623 .unwrap_or_else(|| "root".to_owned());
624
625 let session = self
626 .session_mgr
627 .connect(&db_name, &user_name, self.peer_addr.clone())
628 .map_err(PsqlError::StartupError)?;
629
630 let application_name = msg.config.get("application_name");
631 if let Some(application_name) = application_name {
632 session
633 .set_config("application_name", application_name.clone())
634 .map_err(PsqlError::StartupError)?;
635 }
636
637 match session.user_authenticator() {
638 UserAuthenticator::None => {
639 self.stream.write_no_flush(&BeMessage::AuthenticationOk)?;
640
641 self.stream
644 .write_no_flush(&BeMessage::BackendKeyData(session.id()))?;
645
646 self.stream.write_no_flush(&BeMessage::ParameterStatus(
647 BeParameterStatusMessage::TimeZone(&session.get_config("timezone")?),
648 ))?;
649 self.stream
650 .write_parameter_status_msg_no_flush(&ParameterStatus {
651 application_name: application_name.cloned(),
652 })?;
653 self.ready_for_query()?;
654 }
655 UserAuthenticator::ClearText(_) | UserAuthenticator::OAuth(_) => {
656 self.stream
657 .write_no_flush(&BeMessage::AuthenticationCleartextPassword)?;
658 }
659 UserAuthenticator::Md5WithSalt { salt, .. } => {
660 self.stream
661 .write_no_flush(&BeMessage::AuthenticationMd5Password(salt))?;
662 }
663 }
664
665 self.session = Some(session);
666 self.state = PgProtocolState::Regular;
667 Ok(())
668 }
669
670 async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
671 let session = self.session.as_ref().unwrap();
672 let authenticator = session.user_authenticator();
673 authenticator.authenticate(&msg.password).await?;
674 self.stream.write_no_flush(&BeMessage::AuthenticationOk)?;
675 self.stream.write_no_flush(&BeMessage::ParameterStatus(
676 BeParameterStatusMessage::TimeZone(&session.get_config("timezone")?),
677 ))?;
678 self.stream
679 .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
680 self.ready_for_query()?;
681 self.state = PgProtocolState::Regular;
682 Ok(())
683 }
684
685 fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
686 let session_id = (m.target_process_id, m.target_secret_key);
687 tracing::trace!("cancel query in session: {:?}", session_id);
688 self.session_mgr.cancel_queries_in_session(session_id);
689 self.session_mgr.cancel_creating_jobs_in_session(session_id);
690 self.stream.write_no_flush(&BeMessage::EmptyQueryResponse)?;
691 Ok(())
692 }
693
694 async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
695 let truncated_sql = record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
696 let session = self.session.clone().unwrap();
697
698 session.check_idle_in_transaction_timeout()?;
699 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
701 self.inner_process_query_msg(sql, session.clone()).await
702 }
703
704 async fn inner_process_query_msg(
705 &mut self,
706 sql: Arc<str>,
707 session: Arc<SM::Session>,
708 ) -> PsqlResult<()> {
709 let stmts =
711 Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
712 drop(sql);
714 if stmts.is_empty() {
715 self.stream.write_no_flush(&BeMessage::EmptyQueryResponse)?;
716 }
717
718 for stmt in stmts {
720 self.inner_process_query_msg_one_stmt(stmt, session.clone())
721 .await?;
722 }
723 self.ready_for_query()?;
726 Ok(())
727 }
728
729 async fn inner_process_query_msg_one_stmt(
730 &mut self,
731 stmt: Statement,
732 session: Arc<SM::Session>,
733 ) -> PsqlResult<()> {
734 let session = session.clone();
735
736 let res = session.clone().run_one_query(stmt, Format::Text).await;
738
739 while let Some(notice) = session.next_notice().now_or_never() {
741 self.stream
742 .write_no_flush(&BeMessage::NoticeResponse(¬ice))?;
743 }
744
745 let mut res = res.map_err(PsqlError::SimpleQueryError)?;
746
747 for notice in res.notices() {
748 self.stream
749 .write_no_flush(&BeMessage::NoticeResponse(notice))?;
750 }
751
752 let status = res.status();
753 if let Some(ref application_name) = status.application_name {
754 self.stream.write_no_flush(&BeMessage::ParameterStatus(
755 BeParameterStatusMessage::ApplicationName(application_name),
756 ))?;
757 }
758
759 if res.is_query() {
760 self.stream
761 .write_no_flush(&BeMessage::RowDescription(&res.row_desc()))?;
762
763 let mut rows_cnt = 0;
764
765 while let Some(row_set) = res.values_stream().next().await {
766 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
767 for row in row_set {
768 self.stream.write_no_flush(&BeMessage::DataRow(&row))?;
769 rows_cnt += 1;
770 }
771 }
772
773 res.run_callback().await?;
775
776 self.stream
777 .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
778 stmt_type: res.stmt_type(),
779 rows_cnt,
780 }))?;
781 } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
782 let first_row_set = res.values_stream().next().await;
783 let first_row_set = match first_row_set {
784 None => {
785 return Err(PsqlError::Uncategorized(
786 anyhow::anyhow!("no affected rows in output").into(),
787 ));
788 }
789 Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
790 };
791 let affected_rows_str = first_row_set[0].values()[0]
792 .as_ref()
793 .expect("compute node should return affected rows in output");
794
795 assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
796 let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
797 .unwrap()
798 .parse()
799 .unwrap_or_default();
800
801 res.run_callback().await?;
803
804 self.stream
805 .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
806 stmt_type: res.stmt_type(),
807 rows_cnt: affected_rows_cnt,
808 }))?;
809 } else {
810 res.run_callback().await?;
812
813 self.stream
814 .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
815 stmt_type: res.stmt_type(),
816 rows_cnt: 0,
817 }))?;
818 }
819
820 Ok(())
821 }
822
823 fn process_terminate(&mut self) {
824 self.is_terminate = true;
825 }
826
827 fn process_health_check(&mut self) {
828 tracing::debug!("health check");
829 self.is_terminate = true;
830 }
831
832 async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
833 let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
834 record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
835 let session = self.session.clone().unwrap();
836 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
837 let type_ids = std::mem::take(&mut msg.type_ids);
838 drop(msg);
840 self.inner_process_parse_msg(session, sql, statement_name, type_ids)
841 .await?;
842 Ok(())
843 }
844
845 async fn inner_process_parse_msg(
846 &mut self,
847 session: Arc<SM::Session>,
848 sql: Arc<str>,
849 statement_name: String,
850 type_ids: Vec<i32>,
851 ) -> PsqlResult<()> {
852 if statement_name.is_empty() {
853 self.unnamed_prepare_statement.take();
856 } else if self.prepare_statement_store.contains_key(&statement_name) {
857 return Err(PsqlError::ExtendedPrepareError(
858 "Duplicated statement name".into(),
859 ));
860 }
861
862 let stmt = {
863 let stmts = Parser::parse_sql(&sql)
864 .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
865 drop(sql);
866 if stmts.len() > 1 {
867 return Err(PsqlError::ExtendedPrepareError(
868 "Only one statement is allowed in extended query mode".into(),
869 ));
870 }
871
872 stmts.into_iter().next()
873 };
874
875 let param_types: Vec<Option<DataType>> = type_ids
876 .iter()
877 .map(|&id| {
878 if id == 0 {
881 Ok(None)
882 } else {
883 DataType::from_oid(id)
884 .map(Some)
885 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
886 }
887 })
888 .try_collect()?;
889
890 let prepare_statement = session
891 .parse(stmt, param_types)
892 .await
893 .map_err(PsqlError::ExtendedPrepareError)?;
894
895 if statement_name.is_empty() {
896 self.unnamed_prepare_statement.replace(prepare_statement);
897 } else {
898 self.prepare_statement_store
899 .insert(statement_name.clone(), prepare_statement);
900 }
901
902 self.statement_portal_dependency
903 .entry(statement_name)
904 .or_default()
905 .clear();
906
907 self.stream.write_no_flush(&BeMessage::ParseComplete)?;
908 Ok(())
909 }
910
911 fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
912 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
913 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
914 let session = self.session.clone().unwrap();
915
916 if self.portal_store.contains_key(&portal_name) {
917 return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
918 }
919
920 let prepare_statement = self.get_statement(&statement_name)?;
921
922 let result_formats = msg
923 .result_format_codes
924 .iter()
925 .map(|&format_code| Format::from_i16(format_code))
926 .try_collect()?;
927 let param_formats = msg
928 .param_format_codes
929 .iter()
930 .map(|&format_code| Format::from_i16(format_code))
931 .try_collect()?;
932
933 let portal = session
934 .bind(prepare_statement, msg.params, param_formats, result_formats)
935 .map_err(PsqlError::Uncategorized)?;
936
937 if portal_name.is_empty() {
938 self.result_cache.remove(&portal_name);
939 self.unnamed_portal.replace(portal);
940 } else {
941 assert!(
942 !self.result_cache.contains_key(&portal_name),
943 "Named portal never can be overridden."
944 );
945 self.portal_store.insert(portal_name.clone(), portal);
946 }
947
948 self.statement_portal_dependency
949 .get_mut(&statement_name)
950 .unwrap()
951 .push(portal_name);
952
953 self.stream.write_no_flush(&BeMessage::BindComplete)?;
954 Ok(())
955 }
956
957 async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
958 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
959 let row_max = msg.max_rows as usize;
960 drop(msg);
961 let session = self.session.clone().unwrap();
962
963 match self.result_cache.remove(&portal_name) {
964 Some(mut result_cache) => {
965 assert!(self.portal_store.contains_key(&portal_name));
966
967 let is_cosume_completed =
968 result_cache.consume::<S>(row_max, &mut self.stream).await?;
969
970 if !is_cosume_completed {
971 self.result_cache.insert(portal_name, result_cache);
972 }
973 }
974 _ => {
975 let portal = self.get_portal(&portal_name)?;
976 let sql = format!("{}", portal);
977 let truncated_sql =
978 record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
979 drop(sql);
980
981 session.check_idle_in_transaction_timeout()?;
982 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
984 let result = session.clone().execute(portal).await;
985
986 let pg_response = result.map_err(PsqlError::ExtendedExecuteError)?;
987 let mut result_cache = ResultCache::new(pg_response);
988 let is_consume_completed =
989 result_cache.consume::<S>(row_max, &mut self.stream).await?;
990 if !is_consume_completed {
991 self.result_cache.insert(portal_name, result_cache);
992 }
993 }
994 }
995
996 Ok(())
997 }
998
999 fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1000 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1001 let session = self.session.clone().unwrap();
1002 assert!(msg.kind == b'S' || msg.kind == b'P');
1006 if msg.kind == b'S' {
1007 let prepare_statement = self.get_statement(&name)?;
1008
1009 let (param_types, row_descriptions) = self
1010 .session
1011 .clone()
1012 .unwrap()
1013 .describe_statement(prepare_statement)
1014 .map_err(PsqlError::Uncategorized)?;
1015 self.stream
1016 .write_no_flush(&BeMessage::ParameterDescription(
1017 ¶m_types.iter().map(|t| t.to_oid()).collect_vec(),
1018 ))?;
1019
1020 if row_descriptions.is_empty() {
1021 self.stream.write_no_flush(&BeMessage::NoData)?;
1024 } else {
1025 self.stream
1026 .write_no_flush(&BeMessage::RowDescription(&row_descriptions))?;
1027 }
1028 } else if msg.kind == b'P' {
1029 let portal = self.get_portal(&name)?;
1030
1031 let row_descriptions = session
1032 .describe_portal(portal)
1033 .map_err(PsqlError::Uncategorized)?;
1034
1035 if row_descriptions.is_empty() {
1036 self.stream.write_no_flush(&BeMessage::NoData)?;
1039 } else {
1040 self.stream
1041 .write_no_flush(&BeMessage::RowDescription(&row_descriptions))?;
1042 }
1043 }
1044 Ok(())
1045 }
1046
1047 fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1048 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1049 assert!(msg.kind == b'S' || msg.kind == b'P');
1050 if msg.kind == b'S' {
1051 if name.is_empty() {
1052 self.unnamed_prepare_statement = None;
1053 } else {
1054 self.prepare_statement_store.remove(&name);
1055 }
1056 for portal_name in self
1057 .statement_portal_dependency
1058 .remove(&name)
1059 .unwrap_or_default()
1060 {
1061 self.remove_portal(&portal_name);
1062 }
1063 } else if msg.kind == b'P' {
1064 self.remove_portal(&name);
1065 }
1066 self.stream.write_no_flush(&BeMessage::CloseComplete)?;
1067 Ok(())
1068 }
1069
1070 fn remove_portal(&mut self, portal_name: &str) {
1071 if portal_name.is_empty() {
1072 self.unnamed_portal = None;
1073 } else {
1074 self.portal_store.remove(portal_name);
1075 }
1076 self.result_cache.remove(portal_name);
1077 }
1078
1079 fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1080 if portal_name.is_empty() {
1081 Ok(self
1082 .unnamed_portal
1083 .as_ref()
1084 .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1085 .clone())
1086 } else {
1087 Ok(self
1088 .portal_store
1089 .get(portal_name)
1090 .ok_or_else(|| {
1091 PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1092 })?
1093 .clone())
1094 }
1095 }
1096
1097 fn get_statement(
1098 &self,
1099 statement_name: &str,
1100 ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1101 if statement_name.is_empty() {
1102 Ok(self
1103 .unnamed_prepare_statement
1104 .as_ref()
1105 .ok_or_else(|| {
1106 PsqlError::Uncategorized("unnamed prepare statement not found".into())
1107 })?
1108 .clone())
1109 } else {
1110 Ok(self
1111 .prepare_statement_store
1112 .get(statement_name)
1113 .ok_or_else(|| {
1114 PsqlError::Uncategorized(
1115 format!("Prepare statement {} not found", statement_name).into(),
1116 )
1117 })?
1118 .clone())
1119 }
1120 }
1121}
1122
/// The client connection, either plain or upgraded to TLS in place.
enum PgStreamInner<S> {
    /// Temporary value swapped in while the plain stream is moved out for the TLS upgrade.
    /// Readers and writers treat it as unreachable.
    Placeholder,
    /// Plain, unencrypted stream.
    Unencrypted(S),
    /// Stream wrapped in TLS.
    Ssl(SslStream<S>),
}
1131
/// Trait alias for byte streams a `PgProtocol` can run over.
pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
// Blanket impl: every type satisfying the bounds is a `PgByteStream`.
impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1135
/// Buffered handle to the client connection.
///
/// Writes accumulate in `write_buf` and reach the socket only on `flush`. The underlying
/// stream sits behind a shared async mutex, so clones (e.g. the notice forwarder in `run`)
/// write to the same socket.
pub struct PgStream<S> {
    /// Shared underlying stream; every clone of this `PgStream` points at the same one.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// Outgoing bytes not yet flushed; each clone has its own.
    write_buf: BytesMut,
    /// Header stored by `read_header` and consumed by `read_body` / `skip_body`.
    read_header: Option<FeMessageHeader>,
}
1147
1148impl<S> PgStream<S> {
1149 pub fn new(stream: S) -> Self {
1151 const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1152
1153 Self {
1154 stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1155 write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1156 read_header: None,
1157 }
1158 }
1159}
1160
1161impl<S> Clone for PgStream<S> {
1162 fn clone(&self) -> Self {
1163 Self {
1164 stream: Arc::clone(&self.stream),
1165 write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1166 read_header: self.read_header.clone(),
1167 }
1168 }
1169}
1170
/// Optional parameter-status values sent after authentication, in addition to the fixed
/// ones written by `write_parameter_status_msg_no_flush`.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// Sent as an `application_name` parameter status when set.
    pub application_name: Option<String>,
}
1191
1192impl<S> PgStream<S>
1193where
1194 S: PgByteStream,
1195{
1196 async fn read_startup(&mut self) -> io::Result<FeMessage> {
1197 let mut stream = self.stream.lock().await;
1198 match &mut *stream {
1199 PgStreamInner::Placeholder => unreachable!(),
1200 PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1201 PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1202 }
1203 }
1204
1205 async fn read_header(&mut self) -> io::Result<()> {
1206 let mut stream = self.stream.lock().await;
1207 match &mut *stream {
1208 PgStreamInner::Placeholder => unreachable!(),
1209 PgStreamInner::Unencrypted(stream) => {
1210 self.read_header = Some(FeMessage::read_header(stream).await?);
1211 Ok(())
1212 }
1213 PgStreamInner::Ssl(ssl_stream) => {
1214 self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1215 Ok(())
1216 }
1217 }
1218 }
1219
1220 async fn read_body(&mut self) -> io::Result<FeMessage> {
1221 let mut stream = self.stream.lock().await;
1222 let header = self
1223 .read_header
1224 .take()
1225 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1226 match &mut *stream {
1227 PgStreamInner::Placeholder => unreachable!(),
1228 PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1229 PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1230 }
1231 }
1232
1233 async fn skip_body(&mut self) -> io::Result<()> {
1234 let mut stream = self.stream.lock().await;
1235 let header = self
1236 .read_header
1237 .take()
1238 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1239 match &mut *stream {
1240 PgStreamInner::Placeholder => unreachable!(),
1241 PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1242 PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1243 }
1244 }
1245
1246 fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1247 self.write_no_flush(&BeMessage::ParameterStatus(
1248 BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1249 ))?;
1250 self.write_no_flush(&BeMessage::ParameterStatus(
1251 BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1252 ))?;
1253 self.write_no_flush(&BeMessage::ParameterStatus(
1254 BeParameterStatusMessage::ServerVersion(PG_VERSION),
1255 ))?;
1256 if let Some(application_name) = &status.application_name {
1257 self.write_no_flush(&BeMessage::ParameterStatus(
1258 BeParameterStatusMessage::ApplicationName(application_name),
1259 ))?;
1260 }
1261 Ok(())
1262 }
1263
1264 pub fn write_no_flush(&mut self, message: &BeMessage<'_>) -> io::Result<()> {
1265 BeMessage::write(&mut self.write_buf, message)
1266 }
1267
1268 async fn write(&mut self, message: &BeMessage<'_>) -> io::Result<()> {
1269 self.write_no_flush(message)?;
1270 self.flush().await?;
1271 Ok(())
1272 }
1273
1274 async fn flush(&mut self) -> io::Result<()> {
1275 let mut stream = self.stream.lock().await;
1276 match &mut *stream {
1277 PgStreamInner::Placeholder => unreachable!(),
1278 PgStreamInner::Unencrypted(stream) => {
1279 stream.write_all(&self.write_buf).await?;
1280 stream.flush().await?;
1281 }
1282 PgStreamInner::Ssl(ssl_stream) => {
1283 ssl_stream.write_all(&self.write_buf).await?;
1284 ssl_stream.flush().await?;
1285 }
1286 }
1287 self.write_buf.clear();
1288 Ok(())
1289 }
1290}
1291
1292impl<S> PgStream<S>
1293where
1294 S: PgByteStream,
1295{
1296 async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1298 let mut stream = self.stream.lock().await;
1299
1300 match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1301 PgStreamInner::Unencrypted(unencrypted_stream) => {
1302 let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1303 let mut ssl_stream =
1304 tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1305
1306 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1307 tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1308 let _ = ssl_stream.shutdown().await;
1309 return Err(e.into());
1310 }
1311
1312 *stream = PgStreamInner::Ssl(ssl_stream);
1313 }
1314 PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1315 PgStreamInner::Placeholder => unreachable!(),
1316 }
1317
1318 Ok(())
1319 }
1320}
1321
1322fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1323 let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1324
1325 let key_path = &tls_config.key;
1326 let cert_path = &tls_config.cert;
1327
1328 acceptor
1331 .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1332 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1333 acceptor
1334 .set_ca_file(cert_path)
1335 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1336 acceptor
1337 .set_certificate_chain_file(cert_path)
1338 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1339 let acceptor = acceptor.build();
1340
1341 Ok(acceptor.into_context())
1342}
1343
/// Formatting adaptors that cap how many bytes of a value's rendering are emitted.
pub mod truncated_fmt {
    use std::fmt::*;

    /// Returns the largest `i <= index` that lies on a char boundary of `s`, or `s.len()` when
    /// `index` is at or past the end.
    ///
    /// Stable stand-in for `str::floor_char_boundary`, which is feature-gated on many
    /// toolchains. A UTF-8 scalar is at most 4 bytes, so the backwards scan takes at most 3 steps.
    fn floor_char_boundary(s: &str, index: usize) -> usize {
        if index >= s.len() {
            return s.len();
        }
        // Index 0 is always a char boundary, so the fallback is never actually needed.
        (0..=index)
            .rev()
            .find(|&i| s.is_char_boundary(i))
            .unwrap_or(0)
    }

    /// [`Write`] sink that forwards text to `f` until the byte budget is spent. It then emits a
    /// single `...(truncated,N)` marker and silently drops everything after it.
    struct TruncatedFormatter<'a, 'b> {
        /// Bytes of the value that may still be written.
        remaining: usize,
        /// Set once the truncation marker has been written.
        finished: bool,
        f: &'a mut Formatter<'b>,
    }

    impl Write for TruncatedFormatter<'_, '_> {
        fn write_str(&mut self, s: &str) -> Result {
            if self.finished {
                return Ok(());
            }

            if self.remaining < s.len() {
                // Never split a multi-byte character.
                let actual = floor_char_boundary(s, self.remaining);
                self.f.write_str(&s[..actual])?;
                self.remaining -= actual;
                // `N` is the byte length of the fragment that overflowed the budget. For values
                // rendered in a single fragment (e.g. `&str` via `Display`), that is the full length.
                write!(self.f, "...(truncated,{})", s.len())?;
                // So that the marker is printed exactly once.
                self.finished = true;
            } else {
                self.f.write_str(s)?;
                self.remaining -= s.len();
            }
            Ok(())
        }
    }

    /// Renders `args` into `f` through a [`TruncatedFormatter`] with a budget of `limit` bytes.
    fn write_truncated(f: &mut Formatter<'_>, limit: usize, args: Arguments<'_>) -> Result {
        TruncatedFormatter {
            remaining: limit,
            finished: false,
            f,
        }
        .write_fmt(args)
    }

    /// Wrapper that formats `self.0` with at most `self.1` bytes of output. The cut is rounded
    /// down to a char boundary and followed by a `...(truncated,N)` marker.
    ///
    /// Formatting flags (width, precision, ...) on the outer formatter are not forwarded.
    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);

    impl<T> Debug for TruncatedFmt<'_, T>
    where
        T: Debug,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            write_truncated(f, self.1, format_args!("{:?}", self.0))
        }
    }

    impl<T> Display for TruncatedFmt<'_, T>
    where
        T: Display,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            write_truncated(f, self.1, format_args!("{}", self.0))
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            assert_eq!(
                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
                "select '...(truncated,14)",
            );
        }

        #[test]
        fn test_no_trunc_within_limit() {
            assert_eq!(format!("{}", TruncatedFmt(&"abc", 3)), "abc");
            assert_eq!(format!("{}", TruncatedFmt(&"abc", 10)), "abc");
        }
    }
}
1415
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Options whose key is in the keyword set are rendered as `[REDACTED]`. This covers both the
    /// `WITH (...)` list and the `ENCODE (...)` options; the rest of the statement is re-printed
    /// in canonical form.
    #[test]
    fn test_redact_parsable_sql() {
        let redacted_keys: HashSet<_> = ["v2", "v4", "b"].into_iter().map(Into::into).collect();
        let input_sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])";

        assert_eq!(redact_sql(input_sql, Arc::new(redacted_keys)), expected);
    }
}