1use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::panic::FutureCatchUnwindExt;
33use risingwave_common::util::query_log::*;
34use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
35use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
36use risingwave_sqlparser::parser::Parser;
37use thiserror_ext::AsReport;
38use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
39use tokio::sync::Mutex;
40use tokio_openssl::SslStream;
41use tracing::Instrument;
42
43use crate::error::{PsqlError, PsqlResult};
44use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
45use crate::net::AddressRef;
46use crate::pg_extended::ResultCache;
47use crate::pg_message::{
48 BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
49 FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
50 FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
51};
52use crate::pg_server::{Session, SessionManager, UserAuthenticator};
53use crate::types::Format;
54
/// Maximum number of characters of SQL recorded in tracing spans / query logs.
///
/// Read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. When the variable is
/// unset, not valid Unicode, or not a valid `usize`, it falls back to 65536 in debug builds and
/// 1024 in release builds.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    let fallback = if cfg!(debug_assertions) { 65536 } else { 1024 };
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok())
        .unwrap_or(fallback)
});
68
tokio::task_local! {
    /// Weak, type-erased handle to the session whose message is currently being processed.
    ///
    /// `PgProtocol::do_process` enters this scope (`CURRENT_SESSION.scope`) around message
    /// processing only when a session exists; outside that scope no value is set.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
73
/// Per-connection driver of the PostgreSQL frontend/backend wire protocol.
///
/// One instance serves one client connection. It reads frontend messages from `stream`,
/// dispatches them to the session obtained from `session_mgr`, and writes backend replies.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Buffered connection to the client (possibly upgraded to TLS).
    stream: PgStream<S>,
    /// Whether the next message is read as a startup packet or as a regular header + body.
    state: PgProtocolState,
    /// Set by `Terminate` or a health check; makes `run` stop after the current message.
    is_terminate: bool,

    /// Creates, cancels and ends sessions.
    session_mgr: Arc<SM>,
    /// Session created by the startup message; `None` until startup succeeds.
    session: Option<Arc<SM::Session>>,

    /// Results of portals whose `Execute` stopped at a row limit, keyed by portal name
    /// (the empty name is the unnamed portal).
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    /// Statement created by a `Parse` with an empty name.
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    /// Named prepared statements.
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    /// Portal created by a `Bind` with an empty name.
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    /// Named portals.
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    /// Portal names bound from each statement name, so closing a statement also closes them.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    /// TLS context used to accept `SSLRequest`; `None` means SSL requests are declined.
    tls_context: Option<SslContext>,

    /// After an extended-query error, every message except `Sync` is ignored.
    ignore_util_sync: bool,

    /// Address of the client, passed to `SessionManager::connect`.
    peer_addr: AddressRef,

    /// Option keywords whose values are redacted before SQL is recorded in tracing spans.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts the size of incoming messages and reports when a throttle limit is hit.
    message_memory_manager: MessageMemoryManagerRef,
}
113
/// Certificate and private key from which the server's TLS context is built.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Certificate, taken from `RW_SSL_CERT` by `new_default`.
    /// NOTE(review): presumably a file path (it is logged verbatim) — confirm against
    /// `build_ssl_ctx_from_config`.
    pub cert: String,
    /// Private key, taken from `RW_SSL_KEY` by `new_default`; same caveat as `cert`.
    pub key: String,
}
122
123impl TlsConfig {
124 pub fn new_default() -> Option<Self> {
125 let cert = std::env::var("RW_SSL_CERT").ok()?;
126 let key = std::env::var("RW_SSL_KEY").ok()?;
127 tracing::info!("RW_SSL_CERT={}, RW_SSL_KEY={}", cert, key);
128 Some(Self { cert, key })
129 }
130}
131
132impl<S, SM> Drop for PgProtocol<S, SM>
133where
134 SM: SessionManager,
135{
136 fn drop(&mut self) {
137 if let Some(session) = &self.session {
138 self.session_mgr.end_session(session);
140 }
141 }
142}
143
/// Selects how `PgProtocol::read_message` reads the next incoming message.
enum PgProtocolState {
    /// Before startup: the next message is read with `PgStream::read_startup`.
    Startup,
    /// After startup: messages are read as a header followed by a body.
    Regular,
}
149
/// Decodes a client-supplied C string as UTF-8, dropping one trailing NUL terminator if present.
///
/// Accepts anything viewable as a byte slice (`Bytes`, `BytesMut`, `&[u8]`, `Vec<u8>`, byte
/// arrays), so existing callers passing `&Bytes` keep working unchanged.
///
/// # Errors
///
/// Returns the [`Utf8Error`] if the bytes, minus the terminator, are not valid UTF-8.
pub fn cstr_to_str<B>(b: &B) -> Result<&str, Utf8Error>
where
    B: AsRef<[u8]> + ?Sized,
{
    let bytes = b.as_ref();
    // Only a single trailing NUL is stripped; any other NUL bytes are kept as-is.
    let without_null = bytes.strip_suffix(b"\0").unwrap_or(bytes);
    std::str::from_utf8(without_null)
}
161
162fn record_sql_in_span(
164 sql: &str,
165 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
166) -> String {
167 let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
168 && !keywords.is_empty()
169 {
170 redact_sql(sql, keywords)
171 } else {
172 sql.to_owned()
173 };
174 let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
175 tracing::Span::current().record("sql", tracing::field::display(&truncated));
176 truncated.to_string()
177}
178
179fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
181 match Parser::parse_sql(sql) {
182 Ok(sqls) => sqls
183 .into_iter()
184 .map(|sql| sql.to_redacted_string(keywords.clone()))
185 .join(";"),
186 Err(_) => sql.to_owned(),
187 }
188}
189
/// Connection-independent settings handed to `PgProtocol::new` for each new connection.
#[derive(Clone)]
pub struct ConnectionContext {
    /// Source of the TLS context; `None` makes the server decline SSL requests.
    pub tls_config: Option<TlsConfig>,
    /// Option keywords whose values are redacted before SQL is recorded in tracing spans.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Shared accounting of incoming message sizes, used for throttling.
    pub message_memory_manager: MessageMemoryManagerRef,
}
196
197impl<S, SM> PgProtocol<S, SM>
198where
199 S: PgByteStream,
200 SM: SessionManager,
201{
202 pub fn new(
203 stream: S,
204 session_mgr: Arc<SM>,
205 peer_addr: AddressRef,
206 context: ConnectionContext,
207 ) -> Self {
208 let ConnectionContext {
209 tls_config,
210 redact_sql_option_keywords,
211 message_memory_manager,
212 } = context;
213 Self {
214 stream: PgStream::new(stream),
215 is_terminate: false,
216 state: PgProtocolState::Startup,
217 session_mgr,
218 session: None,
219 tls_context: tls_config
220 .as_ref()
221 .and_then(|e| build_ssl_ctx_from_config(e).ok()),
222 result_cache: Default::default(),
223 unnamed_prepare_statement: Default::default(),
224 prepare_statement_store: Default::default(),
225 unnamed_portal: Default::default(),
226 portal_store: Default::default(),
227 statement_portal_dependency: Default::default(),
228 ignore_util_sync: false,
229 peer_addr,
230 redact_sql_option_keywords,
231 message_memory_manager,
232 }
233 }
234
235 pub async fn run(&mut self) {
237 let mut notice_fut = None;
238
239 loop {
240 if notice_fut.is_none()
242 && let Some(session) = self.session.clone()
243 {
244 let mut stream = self.stream.clone();
245 notice_fut = Some(Box::pin(async move {
246 loop {
247 let notice = session.next_notice().await;
248 if let Err(e) = stream.write(&BeMessage::NoticeResponse(¬ice)).await {
249 tracing::error!(error = %e.as_report(), notice, "failed to send notice");
250 }
251 }
252 }));
253 }
254
255 let process = std::pin::pin!(async {
257 let (msg, _memory_guard) = match self.read_message().await {
258 Ok(msg) => msg,
259 Err(e) => {
260 tracing::error!(error = %e.as_report(), "error when reading message");
261 return true; }
263 };
264 tracing::trace!(?msg, "received message");
265 self.process(msg).await
266 });
267
268 let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
269 tokio::select! {
270 _ = notice_fut => unreachable!(),
271 terminated = process => terminated,
272 }
273 } else {
274 process.await
275 };
276
277 if terminated {
278 break;
279 }
280 }
281 }
282
283 pub async fn process(&mut self, msg: FeMessage) -> bool {
285 self.do_process(msg).await.is_none() || self.is_terminate
286 }
287
288 fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
296 let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
297 return tracing::Span::none();
298 };
299
300 let mode = match msg {
301 FeMessage::Query(_) => "simple query",
302 FeMessage::Parse(_) => "extended query parse",
303 FeMessage::Execute(_) => "extended query execute",
304 _ => return tracing::Span::none(),
305 };
306
307 tracing::info_span!(
308 target: PGWIRE_ROOT_SPAN_TARGET,
309 "handle_query",
310 mode,
311 session_id,
312 sql = tracing::field::Empty, )
314 }
315
316 async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
320 let span = self.root_span_for_msg(&msg);
321 let weak_session = self
322 .session
323 .as_ref()
324 .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);
325
326 let fut = Box::pin(self.do_process_inner(msg));
331
332 let fut = async move {
334 if let Some(session) = weak_session {
335 CURRENT_SESSION.scope(session, fut).await
336 } else {
337 fut.await
338 }
339 };
340
341 let fut = async move {
343 AssertUnwindSafe(fut)
344 .rw_catch_unwind()
345 .await
346 .unwrap_or_else(|payload| {
347 Err(PsqlError::Panic(
348 panic_message::panic_message(&payload).to_owned(),
349 ))
350 })
351 };
352
353 let fut = async move {
355 let period = *SLOW_QUERY_LOG_PERIOD;
356 let mut fut = std::pin::pin!(fut);
357 let mut elapsed = Duration::ZERO;
358
359 loop {
361 match tokio::time::timeout(period, &mut fut).await {
362 Ok(result) => break result,
363 Err(_) => {
364 elapsed += period;
365 tracing::info!(
366 target: PGWIRE_SLOW_QUERY_LOG,
367 elapsed = %format_args!("{}ms", elapsed.as_millis()),
368 "slow query"
369 );
370 }
371 }
372 }
373 };
374
375 let fut = async move {
377 let start = Instant::now();
378 let result = fut.await;
379 let elapsed = start.elapsed();
380
381 if let Err(error) = &result {
385 if cfg!(debug_assertions) && !Deployment::current().is_ci() {
386 tracing::error!(error = ?error.as_report(), "error when process message");
392 } else {
393 tracing::error!(error = %error.as_report(), "error when process message");
394 }
395 }
396
397 if !tracing::Span::current().is_none() {
400 tracing::info!(
401 target: PGWIRE_QUERY_LOG,
402 status = if result.is_ok() { "ok" } else { "err" },
403 time = %format_args!("{}ms", elapsed.as_millis()),
404 );
405 }
406
407 result
408 };
409
410 let fut = fut.instrument(span);
412
413 match fut.await {
415 Ok(()) => Some(()),
416 Err(e) => {
417 match e {
418 PsqlError::IoError(io_err) => {
419 if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
420 return None;
421 }
422 }
423
424 PsqlError::SslError(_) => {
425 return None;
428 }
429
430 PsqlError::StartupError(_) | PsqlError::PasswordError => {
431 self.stream
432 .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
433 .ok()?;
434 let _ = self.stream.flush().await;
435 return None;
436 }
437
438 PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
439 self.stream
440 .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
441 .ok()?;
442 self.ready_for_query().ok()?;
443 }
444
445 PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
446 self.stream
447 .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
448 .ok()?;
449 let _ = self.stream.flush().await;
450
451 return None;
456 }
457
458 PsqlError::Uncategorized(_)
459 | PsqlError::ExtendedPrepareError(_)
460 | PsqlError::ExtendedExecuteError(_) => {
461 self.stream
462 .write_no_flush(&BeMessage::ErrorResponse(Box::new(e)))
463 .ok()?;
464 }
465 }
466 let _ = self.stream.flush().await;
467 Some(())
468 }
469 }
470 }
471
472 async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
473 if self.ignore_util_sync {
475 if let FeMessage::Sync = msg {
476 } else {
477 tracing::trace!("ignore message {:?} until sync.", msg);
478 return Ok(());
479 }
480 }
481
482 match msg {
483 FeMessage::Gss => self.process_gss_msg().await?,
484 FeMessage::Ssl => self.process_ssl_msg().await?,
485 FeMessage::Startup(msg) => self.process_startup_msg(msg)?,
486 FeMessage::Password(msg) => self.process_password_msg(msg).await?,
487 FeMessage::Query(query_msg) => {
488 let sql = Arc::from(query_msg.get_sql()?);
489 drop(query_msg);
491 self.process_query_msg(sql).await?
492 }
493 FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
494 FeMessage::Terminate => self.process_terminate(),
495 FeMessage::Parse(m) => {
496 if let Err(err) = self.process_parse_msg(m).await {
497 self.ignore_util_sync = true;
498 return Err(err);
499 }
500 }
501 FeMessage::Bind(m) => {
502 if let Err(err) = self.process_bind_msg(m) {
503 self.ignore_util_sync = true;
504 return Err(err);
505 }
506 }
507 FeMessage::Execute(m) => {
508 if let Err(err) = self.process_execute_msg(m).await {
509 self.ignore_util_sync = true;
510 return Err(err);
511 }
512 }
513 FeMessage::Describe(m) => {
514 if let Err(err) = self.process_describe_msg(m) {
515 self.ignore_util_sync = true;
516 return Err(err);
517 }
518 }
519 FeMessage::Sync => {
520 self.ignore_util_sync = false;
521 self.ready_for_query()?
522 }
523 FeMessage::Close(m) => {
524 if let Err(err) = self.process_close_msg(m) {
525 self.ignore_util_sync = true;
526 return Err(err);
527 }
528 }
529 FeMessage::Flush => {
530 if let Err(err) = self.stream.flush().await {
531 self.ignore_util_sync = true;
532 return Err(err.into());
533 }
534 }
535 FeMessage::HealthCheck => self.process_health_check(),
536 FeMessage::ServerThrottle(reason) => match reason {
537 ServerThrottleReason::TooLargeMessage => {
538 return Err(PsqlError::ServerThrottle(format!(
539 "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
540 self.message_memory_manager.max_filter_bytes
541 )));
542 }
543 ServerThrottleReason::TooManyMemoryUsage => {
544 return Err(PsqlError::ServerThrottle(format!(
545 "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
546 self.message_memory_manager.max_running_bytes
547 )));
548 }
549 },
550 }
551 self.stream.flush().await?;
552 Ok(())
553 }
554
555 pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
556 match self.state {
557 PgProtocolState::Startup => self
558 .stream
559 .read_startup()
560 .await
561 .map(|message: FeMessage| (message, None)),
562 PgProtocolState::Regular => {
563 self.stream.read_header().await?;
564 let guard = if let Some(ref header) = self.stream.read_header {
565 let payload_len = std::cmp::max(header.payload_len, 0) as u64;
566 let (reason, guard) = self.message_memory_manager.add(payload_len);
567 if let Some(reason) = reason {
568 drop(guard);
570 self.stream.skip_body().await?;
571 return Ok((FeMessage::ServerThrottle(reason), None));
572 }
573 guard
574 } else {
575 None
576 };
577 let message = self.stream.read_body().await?;
578 Ok((message, guard))
579 }
580 }
581 }
582
583 fn ready_for_query(&mut self) -> io::Result<()> {
585 self.stream.write_no_flush(&BeMessage::ReadyForQuery(
586 self.session
587 .as_ref()
588 .map(|s| s.transaction_status())
589 .unwrap_or(TransactionStatus::Idle),
590 ))
591 }
592
593 async fn process_gss_msg(&mut self) -> PsqlResult<()> {
594 self.stream.write(&BeMessage::EncryptionResponseNo).await?;
596 Ok(())
597 }
598
599 async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
600 if let Some(context) = self.tls_context.as_ref() {
601 self.stream.write(&BeMessage::EncryptionResponseSsl).await?;
604 self.stream.upgrade_to_ssl(context).await?;
605 } else {
606 self.stream.write(&BeMessage::EncryptionResponseNo).await?;
608 }
609
610 Ok(())
611 }
612
613 fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
614 let db_name = msg
615 .config
616 .get("database")
617 .cloned()
618 .unwrap_or_else(|| "dev".to_owned());
619 let user_name = msg
620 .config
621 .get("user")
622 .cloned()
623 .unwrap_or_else(|| "root".to_owned());
624
625 let session = self
626 .session_mgr
627 .connect(&db_name, &user_name, self.peer_addr.clone())
628 .map_err(PsqlError::StartupError)?;
629
630 let application_name = msg.config.get("application_name");
631 if let Some(application_name) = application_name {
632 session
633 .set_config("application_name", application_name.clone())
634 .map_err(PsqlError::StartupError)?;
635 }
636
637 match session.user_authenticator() {
638 UserAuthenticator::None => {
639 self.stream.write_no_flush(&BeMessage::AuthenticationOk)?;
640
641 self.stream
644 .write_no_flush(&BeMessage::BackendKeyData(session.id()))?;
645
646 self.stream
647 .write_parameter_status_msg_no_flush(&ParameterStatus {
648 application_name: application_name.cloned(),
649 })?;
650 self.ready_for_query()?;
651 }
652 UserAuthenticator::ClearText(_) | UserAuthenticator::OAuth(_) => {
653 self.stream
654 .write_no_flush(&BeMessage::AuthenticationCleartextPassword)?;
655 }
656 UserAuthenticator::Md5WithSalt { salt, .. } => {
657 self.stream
658 .write_no_flush(&BeMessage::AuthenticationMd5Password(salt))?;
659 }
660 }
661
662 self.session = Some(session);
663 self.state = PgProtocolState::Regular;
664 Ok(())
665 }
666
667 async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
668 let authenticator = self.session.as_ref().unwrap().user_authenticator();
669 authenticator.authenticate(&msg.password).await?;
670 self.stream.write_no_flush(&BeMessage::AuthenticationOk)?;
671 self.stream
672 .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
673 self.ready_for_query()?;
674 self.state = PgProtocolState::Regular;
675 Ok(())
676 }
677
678 fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
679 let session_id = (m.target_process_id, m.target_secret_key);
680 tracing::trace!("cancel query in session: {:?}", session_id);
681 self.session_mgr.cancel_queries_in_session(session_id);
682 self.session_mgr.cancel_creating_jobs_in_session(session_id);
683 self.stream.write_no_flush(&BeMessage::EmptyQueryResponse)?;
684 Ok(())
685 }
686
687 async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
688 let truncated_sql = record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
689 let session = self.session.clone().unwrap();
690
691 session.check_idle_in_transaction_timeout()?;
692 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
694 self.inner_process_query_msg(sql, session.clone()).await
695 }
696
697 async fn inner_process_query_msg(
698 &mut self,
699 sql: Arc<str>,
700 session: Arc<SM::Session>,
701 ) -> PsqlResult<()> {
702 let stmts =
704 Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
705 drop(sql);
707 if stmts.is_empty() {
708 self.stream.write_no_flush(&BeMessage::EmptyQueryResponse)?;
709 }
710
711 for stmt in stmts {
713 self.inner_process_query_msg_one_stmt(stmt, session.clone())
714 .await?;
715 }
716 self.ready_for_query()?;
719 Ok(())
720 }
721
722 async fn inner_process_query_msg_one_stmt(
723 &mut self,
724 stmt: Statement,
725 session: Arc<SM::Session>,
726 ) -> PsqlResult<()> {
727 let session = session.clone();
728
729 let res = session.clone().run_one_query(stmt, Format::Text).await;
731
732 while let Some(notice) = session.next_notice().now_or_never() {
734 self.stream
735 .write_no_flush(&BeMessage::NoticeResponse(¬ice))?;
736 }
737
738 let mut res = res.map_err(PsqlError::SimpleQueryError)?;
739
740 for notice in res.notices() {
741 self.stream
742 .write_no_flush(&BeMessage::NoticeResponse(notice))?;
743 }
744
745 let status = res.status();
746 if let Some(ref application_name) = status.application_name {
747 self.stream.write_no_flush(&BeMessage::ParameterStatus(
748 BeParameterStatusMessage::ApplicationName(application_name),
749 ))?;
750 }
751
752 if res.is_query() {
753 self.stream
754 .write_no_flush(&BeMessage::RowDescription(&res.row_desc()))?;
755
756 let mut rows_cnt = 0;
757
758 while let Some(row_set) = res.values_stream().next().await {
759 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
760 for row in row_set {
761 self.stream.write_no_flush(&BeMessage::DataRow(&row))?;
762 rows_cnt += 1;
763 }
764 }
765
766 res.run_callback().await?;
768
769 self.stream
770 .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
771 stmt_type: res.stmt_type(),
772 rows_cnt,
773 }))?;
774 } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
775 let first_row_set = res.values_stream().next().await;
776 let first_row_set = match first_row_set {
777 None => {
778 return Err(PsqlError::Uncategorized(
779 anyhow::anyhow!("no affected rows in output").into(),
780 ));
781 }
782 Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
783 };
784 let affected_rows_str = first_row_set[0].values()[0]
785 .as_ref()
786 .expect("compute node should return affected rows in output");
787
788 assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
789 let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
790 .unwrap()
791 .parse()
792 .unwrap_or_default();
793
794 res.run_callback().await?;
796
797 self.stream
798 .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
799 stmt_type: res.stmt_type(),
800 rows_cnt: affected_rows_cnt,
801 }))?;
802 } else {
803 res.run_callback().await?;
805
806 self.stream
807 .write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
808 stmt_type: res.stmt_type(),
809 rows_cnt: 0,
810 }))?;
811 }
812
813 Ok(())
814 }
815
816 fn process_terminate(&mut self) {
817 self.is_terminate = true;
818 }
819
820 fn process_health_check(&mut self) {
821 tracing::debug!("health check");
822 self.is_terminate = true;
823 }
824
825 async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
826 let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
827 record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
828 let session = self.session.clone().unwrap();
829 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
830 let type_ids = std::mem::take(&mut msg.type_ids);
831 drop(msg);
833 self.inner_process_parse_msg(session, sql, statement_name, type_ids)
834 .await?;
835 Ok(())
836 }
837
838 async fn inner_process_parse_msg(
839 &mut self,
840 session: Arc<SM::Session>,
841 sql: Arc<str>,
842 statement_name: String,
843 type_ids: Vec<i32>,
844 ) -> PsqlResult<()> {
845 if statement_name.is_empty() {
846 self.unnamed_prepare_statement.take();
849 } else if self.prepare_statement_store.contains_key(&statement_name) {
850 return Err(PsqlError::ExtendedPrepareError(
851 "Duplicated statement name".into(),
852 ));
853 }
854
855 let stmt = {
856 let stmts = Parser::parse_sql(&sql)
857 .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
858 drop(sql);
859 if stmts.len() > 1 {
860 return Err(PsqlError::ExtendedPrepareError(
861 "Only one statement is allowed in extended query mode".into(),
862 ));
863 }
864
865 stmts.into_iter().next()
866 };
867
868 let param_types: Vec<Option<DataType>> = type_ids
869 .iter()
870 .map(|&id| {
871 if id == 0 {
874 Ok(None)
875 } else {
876 DataType::from_oid(id)
877 .map(Some)
878 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
879 }
880 })
881 .try_collect()?;
882
883 let prepare_statement = session
884 .parse(stmt, param_types)
885 .await
886 .map_err(PsqlError::ExtendedPrepareError)?;
887
888 if statement_name.is_empty() {
889 self.unnamed_prepare_statement.replace(prepare_statement);
890 } else {
891 self.prepare_statement_store
892 .insert(statement_name.clone(), prepare_statement);
893 }
894
895 self.statement_portal_dependency
896 .entry(statement_name)
897 .or_default()
898 .clear();
899
900 self.stream.write_no_flush(&BeMessage::ParseComplete)?;
901 Ok(())
902 }
903
904 fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
905 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
906 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
907 let session = self.session.clone().unwrap();
908
909 if self.portal_store.contains_key(&portal_name) {
910 return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
911 }
912
913 let prepare_statement = self.get_statement(&statement_name)?;
914
915 let result_formats = msg
916 .result_format_codes
917 .iter()
918 .map(|&format_code| Format::from_i16(format_code))
919 .try_collect()?;
920 let param_formats = msg
921 .param_format_codes
922 .iter()
923 .map(|&format_code| Format::from_i16(format_code))
924 .try_collect()?;
925
926 let portal = session
927 .bind(prepare_statement, msg.params, param_formats, result_formats)
928 .map_err(PsqlError::Uncategorized)?;
929
930 if portal_name.is_empty() {
931 self.result_cache.remove(&portal_name);
932 self.unnamed_portal.replace(portal);
933 } else {
934 assert!(
935 !self.result_cache.contains_key(&portal_name),
936 "Named portal never can be overridden."
937 );
938 self.portal_store.insert(portal_name.clone(), portal);
939 }
940
941 self.statement_portal_dependency
942 .get_mut(&statement_name)
943 .unwrap()
944 .push(portal_name);
945
946 self.stream.write_no_flush(&BeMessage::BindComplete)?;
947 Ok(())
948 }
949
950 async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
951 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
952 let row_max = msg.max_rows as usize;
953 drop(msg);
954 let session = self.session.clone().unwrap();
955
956 match self.result_cache.remove(&portal_name) {
957 Some(mut result_cache) => {
958 assert!(self.portal_store.contains_key(&portal_name));
959
960 let is_cosume_completed =
961 result_cache.consume::<S>(row_max, &mut self.stream).await?;
962
963 if !is_cosume_completed {
964 self.result_cache.insert(portal_name, result_cache);
965 }
966 }
967 _ => {
968 let portal = self.get_portal(&portal_name)?;
969 let sql = format!("{}", portal);
970 let truncated_sql =
971 record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
972 drop(sql);
973
974 session.check_idle_in_transaction_timeout()?;
975 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
977 let result = session.clone().execute(portal).await;
978
979 let pg_response = result.map_err(PsqlError::ExtendedExecuteError)?;
980 let mut result_cache = ResultCache::new(pg_response);
981 let is_consume_completed =
982 result_cache.consume::<S>(row_max, &mut self.stream).await?;
983 if !is_consume_completed {
984 self.result_cache.insert(portal_name, result_cache);
985 }
986 }
987 }
988
989 Ok(())
990 }
991
992 fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
993 let name = cstr_to_str(&msg.name).unwrap().to_owned();
994 let session = self.session.clone().unwrap();
995 assert!(msg.kind == b'S' || msg.kind == b'P');
999 if msg.kind == b'S' {
1000 let prepare_statement = self.get_statement(&name)?;
1001
1002 let (param_types, row_descriptions) = self
1003 .session
1004 .clone()
1005 .unwrap()
1006 .describe_statement(prepare_statement)
1007 .map_err(PsqlError::Uncategorized)?;
1008 self.stream
1009 .write_no_flush(&BeMessage::ParameterDescription(
1010 ¶m_types.iter().map(|t| t.to_oid()).collect_vec(),
1011 ))?;
1012
1013 if row_descriptions.is_empty() {
1014 self.stream.write_no_flush(&BeMessage::NoData)?;
1017 } else {
1018 self.stream
1019 .write_no_flush(&BeMessage::RowDescription(&row_descriptions))?;
1020 }
1021 } else if msg.kind == b'P' {
1022 let portal = self.get_portal(&name)?;
1023
1024 let row_descriptions = session
1025 .describe_portal(portal)
1026 .map_err(PsqlError::Uncategorized)?;
1027
1028 if row_descriptions.is_empty() {
1029 self.stream.write_no_flush(&BeMessage::NoData)?;
1032 } else {
1033 self.stream
1034 .write_no_flush(&BeMessage::RowDescription(&row_descriptions))?;
1035 }
1036 }
1037 Ok(())
1038 }
1039
1040 fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1041 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1042 assert!(msg.kind == b'S' || msg.kind == b'P');
1043 if msg.kind == b'S' {
1044 if name.is_empty() {
1045 self.unnamed_prepare_statement = None;
1046 } else {
1047 self.prepare_statement_store.remove(&name);
1048 }
1049 for portal_name in self
1050 .statement_portal_dependency
1051 .remove(&name)
1052 .unwrap_or_default()
1053 {
1054 self.remove_portal(&portal_name);
1055 }
1056 } else if msg.kind == b'P' {
1057 self.remove_portal(&name);
1058 }
1059 self.stream.write_no_flush(&BeMessage::CloseComplete)?;
1060 Ok(())
1061 }
1062
1063 fn remove_portal(&mut self, portal_name: &str) {
1064 if portal_name.is_empty() {
1065 self.unnamed_portal = None;
1066 } else {
1067 self.portal_store.remove(portal_name);
1068 }
1069 self.result_cache.remove(portal_name);
1070 }
1071
1072 fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1073 if portal_name.is_empty() {
1074 Ok(self
1075 .unnamed_portal
1076 .as_ref()
1077 .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1078 .clone())
1079 } else {
1080 Ok(self
1081 .portal_store
1082 .get(portal_name)
1083 .ok_or_else(|| {
1084 PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1085 })?
1086 .clone())
1087 }
1088 }
1089
1090 fn get_statement(
1091 &self,
1092 statement_name: &str,
1093 ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1094 if statement_name.is_empty() {
1095 Ok(self
1096 .unnamed_prepare_statement
1097 .as_ref()
1098 .ok_or_else(|| {
1099 PsqlError::Uncategorized("unnamed prepare statement not found".into())
1100 })?
1101 .clone())
1102 } else {
1103 Ok(self
1104 .prepare_statement_store
1105 .get(statement_name)
1106 .ok_or_else(|| {
1107 PsqlError::Uncategorized(
1108 format!("Prepare statement {} not found", statement_name).into(),
1109 )
1110 })?
1111 .clone())
1112 }
1113 }
1114}
1115
/// The client connection, either as accepted or wrapped in TLS.
enum PgStreamInner<S> {
    /// Temporary value left in place while `upgrade_to_ssl` moves the plain stream out.
    /// Every reader and writer treats it as `unreachable!()`.
    Placeholder,
    /// Plain connection, as accepted.
    Unencrypted(S),
    /// Connection wrapped in TLS by `upgrade_to_ssl`.
    Ssl(SslStream<S>),
}
1124
/// Byte stream a `PgProtocol` can serve: an `Unpin`, `Send + 'static` async reader and writer.
pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
/// Every type meeting the bounds is automatically a `PgByteStream`.
impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1128
/// Buffered handle to the client connection.
///
/// Clones share the underlying connection through an async mutex. This lets the
/// notice-forwarding future in `PgProtocol::run` write while messages are being processed.
pub struct PgStream<S> {
    /// The shared connection; locked for each read and each flush.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// Outgoing bytes accumulated by `write_no_flush` until the next `flush`.
    write_buf: BytesMut,
    /// Header stored by `read_header` and consumed by `read_body` or `skip_body`.
    read_header: Option<FeMessageHeader>,
}
1140
1141impl<S> PgStream<S> {
1142 pub fn new(stream: S) -> Self {
1144 const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1145
1146 Self {
1147 stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1148 write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1149 read_header: None,
1150 }
1151 }
1152}
1153
1154impl<S> Clone for PgStream<S> {
1155 fn clone(&self) -> Self {
1156 Self {
1157 stream: Arc::clone(&self.stream),
1158 write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1159 read_header: self.read_header.clone(),
1160 }
1161 }
1162}
1163
/// Optional parameters reported through `ParameterStatus` messages, alongside the fixed ones
/// written by `write_parameter_status_msg_no_flush`.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// Reported as `application_name` when set.
    pub application_name: Option<String>,
}
1184
1185impl<S> PgStream<S>
1186where
1187 S: PgByteStream,
1188{
1189 async fn read_startup(&mut self) -> io::Result<FeMessage> {
1190 let mut stream = self.stream.lock().await;
1191 match &mut *stream {
1192 PgStreamInner::Placeholder => unreachable!(),
1193 PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1194 PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1195 }
1196 }
1197
1198 async fn read_header(&mut self) -> io::Result<()> {
1199 let mut stream = self.stream.lock().await;
1200 match &mut *stream {
1201 PgStreamInner::Placeholder => unreachable!(),
1202 PgStreamInner::Unencrypted(stream) => {
1203 self.read_header = Some(FeMessage::read_header(stream).await?);
1204 Ok(())
1205 }
1206 PgStreamInner::Ssl(ssl_stream) => {
1207 self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1208 Ok(())
1209 }
1210 }
1211 }
1212
1213 async fn read_body(&mut self) -> io::Result<FeMessage> {
1214 let mut stream = self.stream.lock().await;
1215 let header = self
1216 .read_header
1217 .take()
1218 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1219 match &mut *stream {
1220 PgStreamInner::Placeholder => unreachable!(),
1221 PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1222 PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1223 }
1224 }
1225
1226 async fn skip_body(&mut self) -> io::Result<()> {
1227 let mut stream = self.stream.lock().await;
1228 let header = self
1229 .read_header
1230 .take()
1231 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1232 match &mut *stream {
1233 PgStreamInner::Placeholder => unreachable!(),
1234 PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1235 PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1236 }
1237 }
1238
1239 fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1240 self.write_no_flush(&BeMessage::ParameterStatus(
1241 BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1242 ))?;
1243 self.write_no_flush(&BeMessage::ParameterStatus(
1244 BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1245 ))?;
1246 self.write_no_flush(&BeMessage::ParameterStatus(
1247 BeParameterStatusMessage::ServerVersion(PG_VERSION),
1248 ))?;
1249 if let Some(application_name) = &status.application_name {
1250 self.write_no_flush(&BeMessage::ParameterStatus(
1251 BeParameterStatusMessage::ApplicationName(application_name),
1252 ))?;
1253 }
1254 Ok(())
1255 }
1256
1257 pub fn write_no_flush(&mut self, message: &BeMessage<'_>) -> io::Result<()> {
1258 BeMessage::write(&mut self.write_buf, message)
1259 }
1260
1261 async fn write(&mut self, message: &BeMessage<'_>) -> io::Result<()> {
1262 self.write_no_flush(message)?;
1263 self.flush().await?;
1264 Ok(())
1265 }
1266
1267 async fn flush(&mut self) -> io::Result<()> {
1268 let mut stream = self.stream.lock().await;
1269 match &mut *stream {
1270 PgStreamInner::Placeholder => unreachable!(),
1271 PgStreamInner::Unencrypted(stream) => {
1272 stream.write_all(&self.write_buf).await?;
1273 stream.flush().await?;
1274 }
1275 PgStreamInner::Ssl(ssl_stream) => {
1276 ssl_stream.write_all(&self.write_buf).await?;
1277 ssl_stream.flush().await?;
1278 }
1279 }
1280 self.write_buf.clear();
1281 Ok(())
1282 }
1283}
1284
1285impl<S> PgStream<S>
1286where
1287 S: PgByteStream,
1288{
1289 async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1291 let mut stream = self.stream.lock().await;
1292
1293 match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1294 PgStreamInner::Unencrypted(unencrypted_stream) => {
1295 let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1296 let mut ssl_stream =
1297 tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1298
1299 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1300 tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1301 let _ = ssl_stream.shutdown().await;
1302 return Err(e.into());
1303 }
1304
1305 *stream = PgStreamInner::Ssl(ssl_stream);
1306 }
1307 PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1308 PgStreamInner::Placeholder => unreachable!(),
1309 }
1310
1311 Ok(())
1312 }
1313}
1314
1315fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1316 let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1317
1318 let key_path = &tls_config.key;
1319 let cert_path = &tls_config.cert;
1320
1321 acceptor
1324 .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1325 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1326 acceptor
1327 .set_ca_file(cert_path)
1328 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1329 acceptor
1330 .set_certificate_chain_file(cert_path)
1331 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1332 let acceptor = acceptor.build();
1333
1334 Ok(acceptor.into_context())
1335}
1336
pub mod truncated_fmt {
    use std::fmt::*;

    /// Returns the largest char boundary of `s` that is `<= index`, clamped to `s.len()`.
    fn floor_char_boundary(s: &str, index: usize) -> usize {
        if index >= s.len() {
            return s.len();
        }
        // Index 0 is always a char boundary, and a UTF-8 char spans at most 4 bytes, so this
        // inspects at most 4 positions.
        (0..=index)
            .rev()
            .find(|&i| s.is_char_boundary(i))
            .unwrap_or(0)
    }

    /// A [`Write`] adapter that forwards at most `remaining` bytes to the wrapped
    /// [`Formatter`], cutting on a UTF-8 char boundary.
    ///
    /// It keeps counting the total length of everything written to it, including the
    /// discarded part.
    struct TruncatedFormatter<'a, 'b> {
        /// Bytes still allowed to be forwarded to `f`.
        remaining: usize,
        /// Total bytes written to this adapter so far, including discarded ones.
        total: usize,
        /// Set once the output has been cut; later writes are only counted.
        truncated: bool,
        f: &'a mut Formatter<'b>,
    }

    impl Write for TruncatedFormatter<'_, '_> {
        fn write_str(&mut self, s: &str) -> Result {
            self.total += s.len();
            if self.truncated {
                return Ok(());
            }

            if s.len() <= self.remaining {
                self.remaining -= s.len();
                self.f.write_str(s)
            } else {
                let cut = floor_char_boundary(s, self.remaining);
                self.remaining = 0;
                self.truncated = true;
                self.f.write_str(&s[..cut])
            }
        }
    }

    /// Formats `args` into `f`, keeping at most `limit` bytes of output.
    ///
    /// When the output had to be cut, `...(truncated,N)` is appended exactly once, where `N`
    /// is the full length in bytes of the untruncated output.
    fn write_truncated(f: &mut Formatter<'_>, limit: usize, args: Arguments<'_>) -> Result {
        let mut w = TruncatedFormatter {
            remaining: limit,
            total: 0,
            truncated: false,
            f,
        };
        w.write_fmt(args)?;
        if w.truncated {
            write!(w.f, "...(truncated,{})", w.total)?;
        }
        Ok(())
    }

    /// Wraps a value so that its `Debug`/`Display` output is limited to `self.1` bytes.
    ///
    /// The cut never splits a UTF-8 character. Truncated output ends with
    /// `...(truncated,N)`, where `N` is the full output length in bytes.
    pub struct TruncatedFmt<'a, T: ?Sized>(pub &'a T, pub usize);

    impl<T> Debug for TruncatedFmt<'_, T>
    where
        T: Debug + ?Sized,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            write_truncated(f, self.1, format_args!("{:?}", self.0))
        }
    }

    impl<T> Display for TruncatedFmt<'_, T>
    where
        T: Display + ?Sized,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            write_truncated(f, self.1, format_args!("{}", self.0))
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            assert_eq!(
                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
                "select '...(truncated,14)",
            );
        }

        #[test]
        fn test_no_trunc_within_limit() {
            assert_eq!(format!("{}", TruncatedFmt(&"abc", 3)), "abc");
        }

        #[test]
        fn test_trunc_reports_total_len_across_writes() {
            // `Debug` for `str` writes the quotes separately from the contents; the reported
            // length must cover the whole output, not only the chunk that crossed the limit.
            assert_eq!(
                format!("{:?}", TruncatedFmt(&"abcdef", 3)),
                "\"ab...(truncated,8)",
            );
        }
    }
}
1408
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Checks that options whose keys are in the redaction set are rendered as `[REDACTED]`.
    ///
    /// This covers both the `WITH (...)` clause and the `ENCODE ... (...)` options. The rest
    /// of the statement is re-rendered in canonical form.
    #[test]
    fn test_redact_parsable_sql() {
        let redacted_keys = Arc::new(HashSet::from(["v2", "v4", "b"].map(Into::into)));
        let input = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])";

        let actual = redact_sql(input, redacted_keys);
        assert_eq!(actual, expected);
    }
}