1use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::panic::FutureCatchUnwindExt;
32use risingwave_common::util::query_log::*;
33use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
34use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
35use risingwave_sqlparser::parser::Parser;
36use thiserror_ext::AsReport;
37use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
38use tokio::sync::Mutex;
39use tokio_openssl::SslStream;
40use tracing::Instrument;
41
42use crate::error::{PsqlError, PsqlResult};
43use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
44use crate::net::AddressRef;
45use crate::pg_extended::ResultCache;
46use crate::pg_message::{
47 BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
48 FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
49 FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
50};
51use crate::pg_server::{Session, SessionManager, UserAuthenticator};
52use crate::types::Format;
53
/// Maximum length used when formatting SQL text in `record_sql_in_span`.
///
/// Read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. If the variable is
/// unset, not valid Unicode, or not a valid `usize`, it defaults to 65536 in builds with debug
/// assertions and 1024 otherwise.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    let fallback = if cfg!(debug_assertions) { 65536 } else { 1024 };
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok())
        .unwrap_or(fallback)
});
67
tokio::task_local! {
    /// Weak handle to the session whose frontend message is currently being processed.
    ///
    /// `PgProtocol::do_process` enters this scope (via `CURRENT_SESSION.scope`) around message
    /// handling whenever a session exists. The handle is a `Weak`, so holding it does not keep
    /// the session alive.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
72
/// Per-connection handler for the PostgreSQL frontend/backend protocol.
///
/// Reads frontend messages from `stream`, dispatches them to the session obtained from
/// `session_mgr`, and writes backend messages back to the client.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// The client connection (plaintext or SSL). `run` clones it to forward notices.
    stream: PgStream<S>,
    /// `Startup` until the startup packet has been handled, then `Regular`; decides how
    /// `read_message` reads the next message.
    state: PgProtocolState,
    /// Set by `Terminate` and health-check messages; makes `process` report termination.
    is_terminate: bool,

    session_mgr: Arc<SM>,
    /// Created by `process_startup_msg`; ended through `session_mgr.end_session` on drop.
    session: Option<Arc<SM::Session>>,

    /// Results of `Execute` messages that stopped at the row limit, keyed by portal name
    /// ("" = unnamed portal); resumed by the next `Execute` for that portal.
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    /// The unnamed prepared statement (statement name "").
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    /// Named prepared statements.
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    /// The unnamed portal (portal name ""); never stored in `portal_store`.
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    /// Named portals.
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    /// Statement name -> names of portals bound from it; closing a statement also closes
    /// these portals.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    /// SSL context used to accept SSL requests; `None` when TLS is not configured or the
    /// context could not be built.
    tls_context: Option<SslContext>,

    /// Set after an extended-query message fails; incoming messages are then discarded until
    /// the next `Sync` (see `do_process_inner`).
    ignore_util_sync: bool,

    /// Client address, passed to `session_mgr.connect`.
    peer_addr: AddressRef,

    /// Passed to `redact_sql` when SQL text is recorded in tracing spans.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts the payload size of incoming messages in `read_message`; a throttle reason
    /// from it turns the message into `FeMessage::ServerThrottle`.
    message_memory_manager: MessageMemoryManagerRef,
}
112
/// TLS settings from which the SSL context for client connections is built.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Certificate location, taken from `RW_SSL_CERT` by `new_default`.
    /// NOTE(review): presumably a PEM file path; confirm against `build_ssl_ctx_from_config`.
    pub cert: String,
    /// Private key location, taken from `RW_SSL_KEY` by `new_default`.
    /// NOTE(review): presumably a PEM file path; confirm against `build_ssl_ctx_from_config`.
    pub key: String,
}
121
122impl TlsConfig {
123 pub fn new_default() -> Option<Self> {
124 let cert = std::env::var("RW_SSL_CERT").ok()?;
125 let key = std::env::var("RW_SSL_KEY").ok()?;
126 tracing::info!("RW_SSL_CERT={}, RW_SSL_KEY={}", cert, key);
127 Some(Self { cert, key })
128 }
129}
130
131impl<S, SM> Drop for PgProtocol<S, SM>
132where
133 SM: SessionManager,
134{
135 fn drop(&mut self) {
136 if let Some(session) = &self.session {
137 self.session_mgr.end_session(session);
139 }
140 }
141}
142
/// Phase of the connection; selects how `PgProtocol::read_message` reads the next message.
enum PgProtocolState {
    /// Before the startup message has been handled: messages are read with
    /// `PgStream::read_startup`, without memory accounting.
    Startup,
    /// After startup: messages are read as header + body with memory accounting.
    Regular,
}
148
149pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
153 let without_null = if b.last() == Some(&0) {
154 &b[..b.len() - 1]
155 } else {
156 &b[..]
157 };
158 std::str::from_utf8(without_null)
159}
160
161fn record_sql_in_span(
163 sql: &str,
164 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
165) -> String {
166 let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
167 && !keywords.is_empty()
168 {
169 redact_sql(sql, keywords)
170 } else {
171 sql.to_owned()
172 };
173 let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
174 tracing::Span::current().record("sql", tracing::field::display(&truncated));
175 truncated.to_string()
176}
177
178fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
180 match Parser::parse_sql(sql) {
181 Ok(sqls) => sqls
182 .into_iter()
183 .map(|sql| sql.to_redacted_string(keywords.clone()))
184 .join(";"),
185 Err(_) => sql.to_owned(),
186 }
187}
188
/// Per-connection settings consumed by `PgProtocol::new`.
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS settings; when present, `PgProtocol::new` tries to build an SSL context from them.
    pub tls_config: Option<TlsConfig>,
    /// Passed to `redact_sql` when SQL text is recorded in tracing spans.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Memory accounting for incoming messages (see `PgProtocol::read_message`).
    pub message_memory_manager: MessageMemoryManagerRef,
}
195
196impl<S, SM> PgProtocol<S, SM>
197where
198 S: PgByteStream,
199 SM: SessionManager,
200{
201 pub fn new(
202 stream: S,
203 session_mgr: Arc<SM>,
204 peer_addr: AddressRef,
205 context: ConnectionContext,
206 ) -> Self {
207 let ConnectionContext {
208 tls_config,
209 redact_sql_option_keywords,
210 message_memory_manager,
211 } = context;
212 Self {
213 stream: PgStream::new(stream),
214 is_terminate: false,
215 state: PgProtocolState::Startup,
216 session_mgr,
217 session: None,
218 tls_context: tls_config
219 .as_ref()
220 .and_then(|e| build_ssl_ctx_from_config(e).ok()),
221 result_cache: Default::default(),
222 unnamed_prepare_statement: Default::default(),
223 prepare_statement_store: Default::default(),
224 unnamed_portal: Default::default(),
225 portal_store: Default::default(),
226 statement_portal_dependency: Default::default(),
227 ignore_util_sync: false,
228 peer_addr,
229 redact_sql_option_keywords,
230 message_memory_manager,
231 }
232 }
233
234 pub async fn run(&mut self) {
236 let mut notice_fut = None;
237
238 loop {
239 if notice_fut.is_none()
241 && let Some(session) = self.session.clone()
242 {
243 let mut stream = self.stream.clone();
244 notice_fut = Some(Box::pin(async move {
245 loop {
246 let notice = session.next_notice().await;
247 if let Err(e) = stream.write(&BeMessage::NoticeResponse(¬ice)).await {
248 tracing::error!(error = %e.as_report(), notice, "failed to send notice");
249 }
250 }
251 }));
252 }
253
254 let process = std::pin::pin!(async {
256 let (msg, _memory_guard) = match self.read_message().await {
257 Ok(msg) => msg,
258 Err(e) => {
259 tracing::error!(error = %e.as_report(), "error when reading message");
260 return true; }
262 };
263 tracing::trace!(?msg, "received message");
264 self.process(msg).await
265 });
266
267 let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
268 tokio::select! {
269 _ = notice_fut => unreachable!(),
270 terminated = process => terminated,
271 }
272 } else {
273 process.await
274 };
275
276 if terminated {
277 break;
278 }
279 }
280 }
281
282 pub async fn process(&mut self, msg: FeMessage) -> bool {
284 self.do_process(msg).await.is_none() || self.is_terminate
285 }
286
287 fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
295 let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
296 return tracing::Span::none();
297 };
298
299 let mode = match msg {
300 FeMessage::Query(_) => "simple query",
301 FeMessage::Parse(_) => "extended query parse",
302 FeMessage::Execute(_) => "extended query execute",
303 _ => return tracing::Span::none(),
304 };
305
306 tracing::info_span!(
307 target: PGWIRE_ROOT_SPAN_TARGET,
308 "handle_query",
309 mode,
310 session_id,
311 sql = tracing::field::Empty, )
313 }
314
    /// Processes one frontend message, then maps any error to the protocol response.
    ///
    /// Around `do_process_inner`, the following layers are stacked (innermost first):
    /// - the `CURRENT_SESSION` task-local scope;
    /// - panic capture (turned into `PsqlError::Panic`);
    /// - periodic slow-query logging;
    /// - error / query-log logging;
    /// - the root tracing span.
    ///
    /// Returns `None` when the connection should be closed, and `Some(())` otherwise. The
    /// connection is closed on:
    /// - EOF or an SSL error;
    /// - a startup or password failure;
    /// - an idle-in-transaction timeout;
    /// - a panic;
    /// - failure to queue the error response or `ReadyForQuery`.
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        let fut = Box::pin(self.do_process_inner(msg));

        // Expose the current session through the `CURRENT_SESSION` task-local, if any.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Convert a panic during processing into an error instead of unwinding further.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Log "slow query" every `SLOW_QUERY_LOG_PERIOD` for as long as processing runs.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            let mut elapsed = Duration::ZERO;

            loop {
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Log errors, and write a query-log entry only when a root span is active (i.e. for
        // query-carrying messages with a session, see `root_span_for_msg`).
        let fut = async move {
            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            if let Err(error) = &result {
                tracing::error!(error = %error.as_report(), "error when process message");
            }

            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        let fut = fut.instrument(span);

        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    // EOF closes the connection; other I/O errors fall through to the final
                    // flush and the connection is kept.
                    PsqlError::IoError(io_err) => {
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                    }

                    PsqlError::SslError(_) => {
                        return None;
                    }

                    // Handshake failures: report the error, then close.
                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    // Simple-query errors must be followed by `ReadyForQuery`.
                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                        self.ready_for_query().ok()?;
                    }

                    // Fatal for the connection: report the error, then close.
                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                        let _ = self.stream.flush().await;

                        return None;
                    }

                    // Extended-query errors: only the error is sent; `ReadyForQuery` comes
                    // with the client's next `Sync`.
                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        self.stream
                            .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
                            .ok()?;
                    }
                }
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
461
462 async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
463 if self.ignore_util_sync {
465 if let FeMessage::Sync = msg {
466 } else {
467 tracing::trace!("ignore message {:?} until sync.", msg);
468 return Ok(());
469 }
470 }
471
472 match msg {
473 FeMessage::Gss => self.process_gss_msg().await?,
474 FeMessage::Ssl => self.process_ssl_msg().await?,
475 FeMessage::Startup(msg) => self.process_startup_msg(msg)?,
476 FeMessage::Password(msg) => self.process_password_msg(msg).await?,
477 FeMessage::Query(query_msg) => {
478 let sql = Arc::from(query_msg.get_sql()?);
479 drop(query_msg);
481 self.process_query_msg(sql).await?
482 }
483 FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
484 FeMessage::Terminate => self.process_terminate(),
485 FeMessage::Parse(m) => {
486 if let Err(err) = self.process_parse_msg(m).await {
487 self.ignore_util_sync = true;
488 return Err(err);
489 }
490 }
491 FeMessage::Bind(m) => {
492 if let Err(err) = self.process_bind_msg(m) {
493 self.ignore_util_sync = true;
494 return Err(err);
495 }
496 }
497 FeMessage::Execute(m) => {
498 if let Err(err) = self.process_execute_msg(m).await {
499 self.ignore_util_sync = true;
500 return Err(err);
501 }
502 }
503 FeMessage::Describe(m) => {
504 if let Err(err) = self.process_describe_msg(m) {
505 self.ignore_util_sync = true;
506 return Err(err);
507 }
508 }
509 FeMessage::Sync => {
510 self.ignore_util_sync = false;
511 self.ready_for_query()?
512 }
513 FeMessage::Close(m) => {
514 if let Err(err) = self.process_close_msg(m) {
515 self.ignore_util_sync = true;
516 return Err(err);
517 }
518 }
519 FeMessage::Flush => {
520 if let Err(err) = self.stream.flush().await {
521 self.ignore_util_sync = true;
522 return Err(err.into());
523 }
524 }
525 FeMessage::HealthCheck => self.process_health_check(),
526 FeMessage::ServerThrottle(reason) => match reason {
527 ServerThrottleReason::TooLargeMessage => {
528 return Err(PsqlError::ServerThrottle(anyhow::anyhow!("max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit", self.message_memory_manager.max_filter_bytes).into()));
529 }
530 ServerThrottleReason::TooManyMemoryUsage => {
531 return Err(PsqlError::ServerThrottle(anyhow::anyhow!("max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit", self.message_memory_manager.max_running_bytes).into()));
532 }
533 },
534 }
535 self.stream.flush().await?;
536 Ok(())
537 }
538
539 pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
540 match self.state {
541 PgProtocolState::Startup => self
542 .stream
543 .read_startup()
544 .await
545 .map(|message: FeMessage| (message, None)),
546 PgProtocolState::Regular => {
547 self.stream.read_header().await?;
548 let guard = if let Some(ref header) = self.stream.read_header {
549 let payload_len = std::cmp::max(header.payload_len, 0) as u64;
550 let (reason, guard) = self.message_memory_manager.add(payload_len);
551 if let Some(reason) = reason {
552 drop(guard);
554 self.stream.skip_body().await?;
555 return Ok((FeMessage::ServerThrottle(reason), None));
556 }
557 guard
558 } else {
559 None
560 };
561 let message = self.stream.read_body().await?;
562 Ok((message, guard))
563 }
564 }
565 }
566
567 fn ready_for_query(&mut self) -> io::Result<()> {
569 self.stream.write_no_flush(&BeMessage::ReadyForQuery(
570 self.session
571 .as_ref()
572 .map(|s| s.transaction_status())
573 .unwrap_or(TransactionStatus::Idle),
574 ))
575 }
576
577 async fn process_gss_msg(&mut self) -> PsqlResult<()> {
578 self.stream.write(&BeMessage::EncryptionResponseNo).await?;
580 Ok(())
581 }
582
583 async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
584 if let Some(context) = self.tls_context.as_ref() {
585 self.stream.write(&BeMessage::EncryptionResponseSsl).await?;
588 self.stream.upgrade_to_ssl(context).await?;
589 } else {
590 self.stream.write(&BeMessage::EncryptionResponseNo).await?;
592 }
593
594 Ok(())
595 }
596
    /// Handles the startup message: opens a session and starts authentication.
    ///
    /// `database` defaults to `"dev"` and `user` to `"root"` when absent. A client-supplied
    /// `application_name` is applied to the session.
    ///
    /// If the authenticator is `UserAuthenticator::None`, the handshake completes here:
    /// `AuthenticationOk`, `BackendKeyData`, parameter statuses, then `ReadyForQuery`.
    /// Otherwise a password request is sent and the handshake continues in
    /// `process_password_msg`:
    /// - a cleartext request for ClearText/OAuth;
    /// - an MD5 request with the salt for Md5WithSalt.
    ///
    /// In both cases the session is stored and the protocol switches to the regular phase.
    ///
    /// # Errors
    /// Connection and `set_config` failures are returned as `PsqlError::StartupError`.
    fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
        let db_name = msg
            .config
            .get("database")
            .cloned()
            .unwrap_or_else(|| "dev".to_owned());
        let user_name = msg
            .config
            .get("user")
            .cloned()
            .unwrap_or_else(|| "root".to_owned());

        let session = self
            .session_mgr
            .connect(&db_name, &user_name, self.peer_addr.clone())
            .map_err(PsqlError::StartupError)?;

        let application_name = msg.config.get("application_name");
        if let Some(application_name) = application_name {
            session
                .set_config("application_name", application_name.clone())
                .map_err(PsqlError::StartupError)?;
        }

        match session.user_authenticator() {
            UserAuthenticator::None => {
                self.stream.write_no_flush(&BeMessage::AuthenticationOk)?;

                // The session id is what a later CancelRequest must carry
                // (see `process_cancel_msg`).
                self.stream
                    .write_no_flush(&BeMessage::BackendKeyData(session.id()))?;

                self.stream
                    .write_parameter_status_msg_no_flush(&ParameterStatus {
                        application_name: application_name.cloned(),
                    })?;
                // `self.session` is only set below, so this reports `TransactionStatus::Idle`.
                self.ready_for_query()?;
            }
            UserAuthenticator::ClearText(_) | UserAuthenticator::OAuth(_) => {
                self.stream
                    .write_no_flush(&BeMessage::AuthenticationCleartextPassword)?;
            }
            UserAuthenticator::Md5WithSalt { salt, .. } => {
                self.stream
                    .write_no_flush(&BeMessage::AuthenticationMd5Password(salt))?;
            }
        }

        self.session = Some(session);
        self.state = PgProtocolState::Regular;
        Ok(())
    }
650
651 async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
652 let authenticator = self.session.as_ref().unwrap().user_authenticator();
653 authenticator.authenticate(&msg.password).await?;
654 self.stream.write_no_flush(&BeMessage::AuthenticationOk)?;
655 self.stream
656 .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
657 self.ready_for_query()?;
658 self.state = PgProtocolState::Regular;
659 Ok(())
660 }
661
662 fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
663 let session_id = (m.target_process_id, m.target_secret_key);
664 tracing::trace!("cancel query in session: {:?}", session_id);
665 self.session_mgr.cancel_queries_in_session(session_id);
666 self.session_mgr.cancel_creating_jobs_in_session(session_id);
667 self.stream.write_no_flush(&BeMessage::EmptyQueryResponse)?;
668 Ok(())
669 }
670
671 async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
672 let truncated_sql = record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
673 let session = self.session.clone().unwrap();
674
675 session.check_idle_in_transaction_timeout()?;
676 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
678 self.inner_process_query_msg(sql, session.clone()).await
679 }
680
681 async fn inner_process_query_msg(
682 &mut self,
683 sql: Arc<str>,
684 session: Arc<SM::Session>,
685 ) -> PsqlResult<()> {
686 let stmts =
688 Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
689 drop(sql);
691 if stmts.is_empty() {
692 self.stream.write_no_flush(&BeMessage::EmptyQueryResponse)?;
693 }
694
695 for stmt in stmts {
697 self.inner_process_query_msg_one_stmt(stmt, session.clone())
698 .await?;
699 }
700 self.ready_for_query()?;
703 Ok(())
704 }
705
706 async fn inner_process_query_msg_one_stmt(
707 &mut self,
708 stmt: Statement,
709 session: Arc<SM::Session>,
710 ) -> PsqlResult<()> {
711 let session = session.clone();
712
713 let res = session.clone().run_one_query(stmt, Format::Text).await;
715
716 while let Some(notice) = session.next_notice().now_or_never() {
718 self.stream
719 .write_no_flush(&BeMessage::NoticeResponse(¬ice))?;
720 }
721
722 let mut res = res.map_err(PsqlError::SimpleQueryError)?;
723
724 for notice in res.notices() {
725 self.stream
726 .write_no_flush(&BeMessage::NoticeResponse(notice))?;
727 }
728
729 let status = res.status();
730 if let Some(ref application_name) = status.application_name {
731 self.stream.write_no_flush(&BeMessage::ParameterStatus(
732 BeParameterStatusMessage::ApplicationName(application_name),
733 ))?;
734 }
735
736 if res.is_query() {
737 self.stream
738 .write_no_flush(&BeMessage::RowDescription(&res.row_desc()))?;
739
740 let mut rows_cnt = 0;
741
742 while let Some(row_set) = res.values_stream().next().await {
743 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
744 for row in row_set {
745 self.stream.write_no_flush(&BeMessage::DataRow(&row))?;
746 rows_cnt += 1;
747 }
748 }
749
750 res.run_callback().await?;
752
753 self.stream
754 .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
755 stmt_type: res.stmt_type(),
756 rows_cnt,
757 }))?;
758 } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
759 let first_row_set = res.values_stream().next().await;
760 let first_row_set = match first_row_set {
761 None => {
762 return Err(PsqlError::Uncategorized(
763 anyhow::anyhow!("no affected rows in output").into(),
764 ));
765 }
766 Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
767 };
768 let affected_rows_str = first_row_set[0].values()[0]
769 .as_ref()
770 .expect("compute node should return affected rows in output");
771
772 assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
773 let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
774 .unwrap()
775 .parse()
776 .unwrap_or_default();
777
778 res.run_callback().await?;
780
781 self.stream
782 .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
783 stmt_type: res.stmt_type(),
784 rows_cnt: affected_rows_cnt,
785 }))?;
786 } else {
787 res.run_callback().await?;
789
790 self.stream
791 .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
792 stmt_type: res.stmt_type(),
793 rows_cnt: 0,
794 }))?;
795 }
796
797 Ok(())
798 }
799
800 fn process_terminate(&mut self) {
801 self.is_terminate = true;
802 }
803
804 fn process_health_check(&mut self) {
805 tracing::debug!("health check");
806 self.is_terminate = true;
807 }
808
809 async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
810 let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
811 record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
812 let session = self.session.clone().unwrap();
813 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
814 let type_ids = std::mem::take(&mut msg.type_ids);
815 drop(msg);
817 self.inner_process_parse_msg(session, sql, statement_name, type_ids)
818 .await?;
819 Ok(())
820 }
821
822 async fn inner_process_parse_msg(
823 &mut self,
824 session: Arc<SM::Session>,
825 sql: Arc<str>,
826 statement_name: String,
827 type_ids: Vec<i32>,
828 ) -> PsqlResult<()> {
829 if statement_name.is_empty() {
830 self.unnamed_prepare_statement.take();
833 } else if self.prepare_statement_store.contains_key(&statement_name) {
834 return Err(PsqlError::ExtendedPrepareError(
835 "Duplicated statement name".into(),
836 ));
837 }
838
839 let stmt = {
840 let stmts = Parser::parse_sql(&sql)
841 .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
842 drop(sql);
843 if stmts.len() > 1 {
844 return Err(PsqlError::ExtendedPrepareError(
845 "Only one statement is allowed in extended query mode".into(),
846 ));
847 }
848
849 stmts.into_iter().next()
850 };
851
852 let param_types: Vec<Option<DataType>> = type_ids
853 .iter()
854 .map(|&id| {
855 if id == 0 {
858 Ok(None)
859 } else {
860 DataType::from_oid(id)
861 .map(Some)
862 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
863 }
864 })
865 .try_collect()?;
866
867 let prepare_statement = session
868 .parse(stmt, param_types)
869 .await
870 .map_err(PsqlError::ExtendedPrepareError)?;
871
872 if statement_name.is_empty() {
873 self.unnamed_prepare_statement.replace(prepare_statement);
874 } else {
875 self.prepare_statement_store
876 .insert(statement_name.clone(), prepare_statement);
877 }
878
879 self.statement_portal_dependency
880 .entry(statement_name)
881 .or_default()
882 .clear();
883
884 self.stream.write_no_flush(&BeMessage::ParseComplete)?;
885 Ok(())
886 }
887
888 fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
889 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
890 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
891 let session = self.session.clone().unwrap();
892
893 if self.portal_store.contains_key(&portal_name) {
894 return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
895 }
896
897 let prepare_statement = self.get_statement(&statement_name)?;
898
899 let result_formats = msg
900 .result_format_codes
901 .iter()
902 .map(|&format_code| Format::from_i16(format_code))
903 .try_collect()?;
904 let param_formats = msg
905 .param_format_codes
906 .iter()
907 .map(|&format_code| Format::from_i16(format_code))
908 .try_collect()?;
909
910 let portal = session
911 .bind(prepare_statement, msg.params, param_formats, result_formats)
912 .map_err(PsqlError::Uncategorized)?;
913
914 if portal_name.is_empty() {
915 self.result_cache.remove(&portal_name);
916 self.unnamed_portal.replace(portal);
917 } else {
918 assert!(
919 !self.result_cache.contains_key(&portal_name),
920 "Named portal never can be overridden."
921 );
922 self.portal_store.insert(portal_name.clone(), portal);
923 }
924
925 self.statement_portal_dependency
926 .get_mut(&statement_name)
927 .unwrap()
928 .push(portal_name);
929
930 self.stream.write_no_flush(&BeMessage::BindComplete)?;
931 Ok(())
932 }
933
934 async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
935 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
936 let row_max = msg.max_rows as usize;
937 drop(msg);
938 let session = self.session.clone().unwrap();
939
940 match self.result_cache.remove(&portal_name) {
941 Some(mut result_cache) => {
942 assert!(self.portal_store.contains_key(&portal_name));
943
944 let is_cosume_completed =
945 result_cache.consume::<S>(row_max, &mut self.stream).await?;
946
947 if !is_cosume_completed {
948 self.result_cache.insert(portal_name, result_cache);
949 }
950 }
951 _ => {
952 let portal = self.get_portal(&portal_name)?;
953 let sql = format!("{}", portal);
954 let truncated_sql =
955 record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
956 drop(sql);
957
958 session.check_idle_in_transaction_timeout()?;
959 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
961 let result = session.clone().execute(portal).await;
962
963 let pg_response = result.map_err(PsqlError::ExtendedExecuteError)?;
964 let mut result_cache = ResultCache::new(pg_response);
965 let is_consume_completed =
966 result_cache.consume::<S>(row_max, &mut self.stream).await?;
967 if !is_consume_completed {
968 self.result_cache.insert(portal_name, result_cache);
969 }
970 }
971 }
972
973 Ok(())
974 }
975
976 fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
977 let name = cstr_to_str(&msg.name).unwrap().to_owned();
978 let session = self.session.clone().unwrap();
979 assert!(msg.kind == b'S' || msg.kind == b'P');
983 if msg.kind == b'S' {
984 let prepare_statement = self.get_statement(&name)?;
985
986 let (param_types, row_descriptions) = self
987 .session
988 .clone()
989 .unwrap()
990 .describe_statement(prepare_statement)
991 .map_err(PsqlError::Uncategorized)?;
992 self.stream
993 .write_no_flush(&BeMessage::ParameterDescription(
994 ¶m_types.iter().map(|t| t.to_oid()).collect_vec(),
995 ))?;
996
997 if row_descriptions.is_empty() {
998 self.stream.write_no_flush(&BeMessage::NoData)?;
1001 } else {
1002 self.stream
1003 .write_no_flush(&BeMessage::RowDescription(&row_descriptions))?;
1004 }
1005 } else if msg.kind == b'P' {
1006 let portal = self.get_portal(&name)?;
1007
1008 let row_descriptions = session
1009 .describe_portal(portal)
1010 .map_err(PsqlError::Uncategorized)?;
1011
1012 if row_descriptions.is_empty() {
1013 self.stream.write_no_flush(&BeMessage::NoData)?;
1016 } else {
1017 self.stream
1018 .write_no_flush(&BeMessage::RowDescription(&row_descriptions))?;
1019 }
1020 }
1021 Ok(())
1022 }
1023
1024 fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1025 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1026 assert!(msg.kind == b'S' || msg.kind == b'P');
1027 if msg.kind == b'S' {
1028 if name.is_empty() {
1029 self.unnamed_prepare_statement = None;
1030 } else {
1031 self.prepare_statement_store.remove(&name);
1032 }
1033 for portal_name in self
1034 .statement_portal_dependency
1035 .remove(&name)
1036 .unwrap_or_default()
1037 {
1038 self.remove_portal(&portal_name);
1039 }
1040 } else if msg.kind == b'P' {
1041 self.remove_portal(&name);
1042 }
1043 self.stream.write_no_flush(&BeMessage::CloseComplete)?;
1044 Ok(())
1045 }
1046
1047 fn remove_portal(&mut self, portal_name: &str) {
1048 if portal_name.is_empty() {
1049 self.unnamed_portal = None;
1050 } else {
1051 self.portal_store.remove(portal_name);
1052 }
1053 self.result_cache.remove(portal_name);
1054 }
1055
1056 fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1057 if portal_name.is_empty() {
1058 Ok(self
1059 .unnamed_portal
1060 .as_ref()
1061 .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1062 .clone())
1063 } else {
1064 Ok(self
1065 .portal_store
1066 .get(portal_name)
1067 .ok_or_else(|| {
1068 PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1069 })?
1070 .clone())
1071 }
1072 }
1073
1074 fn get_statement(
1075 &self,
1076 statement_name: &str,
1077 ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1078 if statement_name.is_empty() {
1079 Ok(self
1080 .unnamed_prepare_statement
1081 .as_ref()
1082 .ok_or_else(|| {
1083 PsqlError::Uncategorized("unnamed prepare statement not found".into())
1084 })?
1085 .clone())
1086 } else {
1087 Ok(self
1088 .prepare_statement_store
1089 .get(statement_name)
1090 .ok_or_else(|| {
1091 PsqlError::Uncategorized(
1092 format!("Prepare statement {} not found", statement_name).into(),
1093 )
1094 })?
1095 .clone())
1096 }
1097 }
1098}
1099
/// The transport underneath a `PgStream`.
enum PgStreamInner<S> {
    /// Temporary value while `upgrade_to_ssl` has moved the plaintext stream out (via
    /// `std::mem::replace`). Readers and writers treat it as `unreachable!()`.
    Placeholder,
    /// Plaintext connection.
    Unencrypted(S),
    /// Connection upgraded to SSL/TLS by `upgrade_to_ssl`.
    Ssl(SslStream<S>),
}
1108
1109pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1111impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1112
/// Buffered, cloneable handle to a client connection.
pub struct PgStream<S> {
    /// Shared transport. Clones of this `PgStream` refer to the same connection, and access is
    /// serialized through the mutex.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// Outgoing bytes queued by `write_no_flush` until the next `flush`. Each clone has its
    /// own buffer.
    write_buf: BytesMut,
    /// Header read by `read_header`, consumed by the following `read_body` / `skip_body`.
    read_header: Option<FeMessageHeader>,
}
1124
1125impl<S> PgStream<S> {
1126 pub fn new(stream: S) -> Self {
1128 const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1129
1130 Self {
1131 stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1132 write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1133 read_header: None,
1134 }
1135 }
1136}
1137
1138impl<S> Clone for PgStream<S> {
1139 fn clone(&self) -> Self {
1140 Self {
1141 stream: Arc::clone(&self.stream),
1142 write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1143 read_header: self.read_header.clone(),
1144 }
1145 }
1146}
1147
/// Optional parameter statuses reported by `PgStream::write_parameter_status_msg_no_flush`
/// in addition to the fixed ones (client encoding, standard-conforming strings, server
/// version).
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// Reported as an `application_name` `ParameterStatus` when set.
    pub application_name: Option<String>,
}
1168
1169impl<S> PgStream<S>
1170where
1171 S: PgByteStream,
1172{
1173 async fn read_startup(&mut self) -> io::Result<FeMessage> {
1174 let mut stream = self.stream.lock().await;
1175 match &mut *stream {
1176 PgStreamInner::Placeholder => unreachable!(),
1177 PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1178 PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1179 }
1180 }
1181
1182 async fn read_header(&mut self) -> io::Result<()> {
1183 let mut stream = self.stream.lock().await;
1184 match &mut *stream {
1185 PgStreamInner::Placeholder => unreachable!(),
1186 PgStreamInner::Unencrypted(stream) => {
1187 self.read_header = Some(FeMessage::read_header(stream).await?);
1188 Ok(())
1189 }
1190 PgStreamInner::Ssl(ssl_stream) => {
1191 self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1192 Ok(())
1193 }
1194 }
1195 }
1196
1197 async fn read_body(&mut self) -> io::Result<FeMessage> {
1198 let mut stream = self.stream.lock().await;
1199 let header = self
1200 .read_header
1201 .take()
1202 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1203 match &mut *stream {
1204 PgStreamInner::Placeholder => unreachable!(),
1205 PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1206 PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1207 }
1208 }
1209
1210 async fn skip_body(&mut self) -> io::Result<()> {
1211 let mut stream = self.stream.lock().await;
1212 let header = self
1213 .read_header
1214 .take()
1215 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1216 match &mut *stream {
1217 PgStreamInner::Placeholder => unreachable!(),
1218 PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1219 PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1220 }
1221 }
1222
1223 fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1224 self.write_no_flush(&BeMessage::ParameterStatus(
1225 BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1226 ))?;
1227 self.write_no_flush(&BeMessage::ParameterStatus(
1228 BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1229 ))?;
1230 self.write_no_flush(&BeMessage::ParameterStatus(
1231 BeParameterStatusMessage::ServerVersion(PG_VERSION),
1232 ))?;
1233 if let Some(application_name) = &status.application_name {
1234 self.write_no_flush(&BeMessage::ParameterStatus(
1235 BeParameterStatusMessage::ApplicationName(application_name),
1236 ))?;
1237 }
1238 Ok(())
1239 }
1240
1241 pub fn write_no_flush(&mut self, message: &BeMessage<'_>) -> io::Result<()> {
1242 BeMessage::write(&mut self.write_buf, message)
1243 }
1244
1245 async fn write(&mut self, message: &BeMessage<'_>) -> io::Result<()> {
1246 self.write_no_flush(message)?;
1247 self.flush().await?;
1248 Ok(())
1249 }
1250
1251 async fn flush(&mut self) -> io::Result<()> {
1252 let mut stream = self.stream.lock().await;
1253 match &mut *stream {
1254 PgStreamInner::Placeholder => unreachable!(),
1255 PgStreamInner::Unencrypted(stream) => {
1256 stream.write_all(&self.write_buf).await?;
1257 stream.flush().await?;
1258 }
1259 PgStreamInner::Ssl(ssl_stream) => {
1260 ssl_stream.write_all(&self.write_buf).await?;
1261 ssl_stream.flush().await?;
1262 }
1263 }
1264 self.write_buf.clear();
1265 Ok(())
1266 }
1267}
1268
1269impl<S> PgStream<S>
1270where
1271 S: PgByteStream,
1272{
1273 async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1275 let mut stream = self.stream.lock().await;
1276
1277 match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1278 PgStreamInner::Unencrypted(unencrypted_stream) => {
1279 let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1280 let mut ssl_stream =
1281 tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1282
1283 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1284 tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1285 let _ = ssl_stream.shutdown().await;
1286 return Err(e.into());
1287 }
1288
1289 *stream = PgStreamInner::Ssl(ssl_stream);
1290 }
1291 PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1292 PgStreamInner::Placeholder => unreachable!(),
1293 }
1294
1295 Ok(())
1296 }
1297}
1298
1299fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1300 let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1301
1302 let key_path = &tls_config.key;
1303 let cert_path = &tls_config.cert;
1304
1305 acceptor
1308 .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1309 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1310 acceptor
1311 .set_ca_file(cert_path)
1312 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1313 acceptor
1314 .set_certificate_chain_file(cert_path)
1315 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1316 let acceptor = acceptor.build();
1317
1318 Ok(acceptor.into_context())
1319}
1320
pub mod truncated_fmt {
    //! Formatting adapters that cap how many bytes of a value are emitted.

    use std::fmt::*;

    /// Returns the largest char boundary of `s` that is `<= index`.
    ///
    /// Stable equivalent of the unstable `str::floor_char_boundary`.
    fn floor_char_boundary(s: &str, index: usize) -> usize {
        if index >= s.len() {
            return s.len();
        }
        // Index 0 is always a char boundary, so the search always succeeds.
        (0..=index)
            .rev()
            .find(|&i| s.is_char_boundary(i))
            .unwrap_or(0)
    }

    /// A [`Write`] sink that forwards at most `remaining` bytes to `f`.
    ///
    /// Output is cut on a UTF-8 char boundary. Every offered byte is still counted, so the
    /// full rendered length is known when formatting ends.
    struct TruncatedFormatter<'a, 'b> {
        /// Bytes that may still be forwarded to `f`.
        remaining: usize,
        /// Total bytes offered across all `write_str` calls (including dropped ones).
        total: usize,
        /// Set once output has been cut; later writes are only counted.
        truncated: bool,
        f: &'a mut Formatter<'b>,
    }

    impl<'a, 'b> TruncatedFormatter<'a, 'b> {
        fn new(f: &'a mut Formatter<'b>, limit: usize) -> Self {
            Self {
                remaining: limit,
                total: 0,
                truncated: false,
                f,
            }
        }
    }

    impl Write for TruncatedFormatter<'_, '_> {
        fn write_str(&mut self, s: &str) -> Result {
            // Count before the early return: the suffix reports the whole value's length,
            // not just the chunk that happened to overflow the budget.
            self.total += s.len();
            if self.truncated {
                return Ok(());
            }

            if s.len() <= self.remaining {
                self.remaining -= s.len();
                return self.f.write_str(s);
            }

            let cut = floor_char_boundary(s, self.remaining);
            self.remaining = 0;
            self.truncated = true;
            self.f.write_str(&s[..cut])
        }
    }

    /// Renders `args` into `f`, emitting at most `limit` bytes.
    ///
    /// If the rendering is longer than `limit`, `...(truncated,N)` is appended, where `N` is
    /// the byte length of the full rendering.
    fn write_truncated(f: &mut Formatter<'_>, limit: usize, args: Arguments<'_>) -> Result {
        let mut sink = TruncatedFormatter::new(f, limit);
        sink.write_fmt(args)?;
        if sink.truncated {
            write!(sink.f, "...(truncated,{})", sink.total)?;
        }
        Ok(())
    }

    /// Formats `self.0` (via `Display` or `Debug`) but emits at most `self.1` bytes.
    ///
    /// Longer output is cut on a UTF-8 char boundary and followed by
    /// `...(truncated,N)`, where `N` is the byte length of the untruncated output.
    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);

    impl<T> Debug for TruncatedFmt<'_, T>
    where
        T: Debug,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            write_truncated(f, self.1, format_args!("{:?}", self.0))
        }
    }

    impl<T> Display for TruncatedFmt<'_, T>
    where
        T: Display,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            write_truncated(f, self.1, format_args!("{}", self.0))
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            assert_eq!(
                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
                "select '...(truncated,14)",
            );
        }

        #[test]
        fn test_trunc_reports_total_len_across_writes() {
            // Sequence `Debug` impls emit many small writes; the suffix must report the
            // length of the whole rendering ("[1, 2, 3]" = 9 bytes).
            assert_eq!(
                format!("{:?}", TruncatedFmt(&[1, 2, 3], 4)),
                "[1, ...(truncated,9)",
            );
        }

        #[test]
        fn test_trunc_boundaries() {
            assert_eq!(format!("{}", TruncatedFmt(&"abcd", 4)), "abcd");
            assert_eq!(format!("{}", TruncatedFmt(&"abcd", 3)), "abc...(truncated,4)");
            assert_eq!(format!("{}", TruncatedFmt(&"", 0)), "");
        }
    }
}
1392
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Values of options whose keys appear in the redaction set (in both the `WITH`
    /// clause and the `ENCODE` options) must be replaced by `'[REDACTED]'`. All other
    /// options and the rest of the statement are preserved in normalized form.
    #[test]
    fn test_redact_parsable_sql() {
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let redacted_keys = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()]));
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = '[REDACTED]', v3 = false, v4 = '[REDACTED]') FORMAT PLAIN ENCODE JSON (a = '1', b = '[REDACTED]')";

        let actual = redact_sql(sql, redacted_keys);
        assert_eq!(actual, expected);
    }
}