1use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
46use crate::net::AddressRef;
47use crate::pg_extended::ResultCache;
48use crate::pg_message::{
49 BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
50 FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
51 FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
52};
53use crate::pg_server::{Session, SessionManager, UserAuthenticator};
54use crate::types::Format;
55
/// Maximum length of the SQL text recorded in tracing spans (see `record_sql_in_span`).
///
/// Read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. When the variable is
/// unset or is not a valid `usize`, it defaults to 65536 in debug builds and 1024 otherwise.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    parse_query_log_truncate_len(std::env::var("RW_QUERY_LOG_TRUNCATE_LEN").ok().as_deref())
});

/// Parses a raw `RW_QUERY_LOG_TRUNCATE_LEN` value.
///
/// Falls back to the build-dependent default when `raw` is `None` or not a valid `usize`.
/// The value is parsed exactly once, instead of a validity check followed by a second parse.
fn parse_query_log_truncate_len(raw: Option<&str>) -> usize {
    const DEFAULT_LEN: usize = if cfg!(debug_assertions) { 65536 } else { 1024 };
    raw.and_then(|value| value.parse().ok())
        .unwrap_or(DEFAULT_LEN)
}
69
tokio::task_local! {
    /// Weak, type-erased handle to the session whose message is currently being handled.
    ///
    /// `PgProtocol::do_process` scopes the message-handling future with this value once a
    /// session exists; before startup completes it is not set.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
74
/// Per-connection handler for the PostgreSQL frontend/backend wire protocol.
///
/// Owns the client byte stream and, once startup succeeds, the session created by
/// `session_mgr`. It also tracks this connection's extended-query state: prepared
/// statements, portals, and partially consumed results.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// The client connection (plain, or upgraded to SSL).
    stream: PgStream<S>,
    /// Selects whether the next message is read as a startup packet or as a regular message.
    state: PgProtocolState,
    /// Set when the connection should be closed after the current message.
    is_terminate: bool,

    session_mgr: Arc<SM>,
    /// The session created by the startup message; `None` until then.
    session: Option<Arc<SM::Session>>,

    /// Results of `Execute` messages that were not fully consumed, keyed by portal name
    /// (the unnamed portal uses the empty string).
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    /// Statement name -> names of the portals bound from it; closing a statement also
    /// closes these portals.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    /// SSL context built from `tls_config`; `None` when TLS is not configured or the
    /// context could not be built.
    tls_context: Option<SslContext>,

    tls_config: Option<TlsConfig>,

    /// After an extended-query error, every message other than `Sync` is ignored.
    ignore_util_sync: bool,

    /// Address of the client, passed to `SessionManager::connect`.
    peer_addr: AddressRef,

    /// Keywords used to redact SQL before it is recorded in tracing spans.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts memory for incoming messages and reports when they must be throttled.
    message_memory_manager: MessageMemoryManagerRef,
}
117
/// TLS settings for client connections.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Certificate setting; `new_default` reads it from `RW_SSL_CERT`
    /// (presumably a PEM file path — TODO confirm).
    pub cert: String,
    /// Private-key setting; `new_default` reads it from `RW_SSL_KEY`
    /// (presumably a file path — TODO confirm).
    pub key: String,
    /// When true, a startup message on a connection not upgraded to SSL is rejected.
    pub enforce_ssl: bool,
}
128
129impl TlsConfig {
130 pub fn new_default() -> Option<Self> {
131 let cert = std::env::var("RW_SSL_CERT").ok()?;
132 let key = std::env::var("RW_SSL_KEY").ok()?;
133 let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
134 tracing::info!(
135 "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
136 cert,
137 key,
138 enforce_ssl
139 );
140 Some(Self {
141 cert,
142 key,
143 enforce_ssl,
144 })
145 }
146}
147
148impl<S, SM> Drop for PgProtocol<S, SM>
149where
150 SM: SessionManager,
151{
152 fn drop(&mut self) {
153 if let Some(session) = &self.session {
154 self.session_mgr.end_session(session);
156 }
157 }
158}
159
/// Which kind of message `PgProtocol::read_message` reads next.
enum PgProtocolState {
    /// Read a startup-format message (`PgStream::read_startup`).
    Startup,
    /// Read a regular message: header first, then body.
    Regular,
}
165
166pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
170 let without_null = if b.last() == Some(&0) {
171 &b[..b.len() - 1]
172 } else {
173 &b[..]
174 };
175 std::str::from_utf8(without_null)
176}
177
178fn record_sql_in_current_span(
179 sql: &str,
180 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
181) -> String {
182 let mut span = tracing::Span::current();
183 record_sql_in_span(sql, redact_sql_option_keywords, &mut span)
184}
185
186fn record_sql_in_span(
188 sql: &str,
189 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
190 span: &mut tracing::Span,
191) -> String {
192 let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
193 && !keywords.is_empty()
194 {
195 redact_sql(sql, keywords)
196 } else {
197 sql.to_owned()
198 };
199 let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
200 span.record("sql", tracing::field::display(&truncated));
201 truncated.to_string()
202}
203
204fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
206 match Parser::parse_sql(sql) {
207 Ok(sqls) => sqls
208 .into_iter()
209 .map(|sql| sql.to_redacted_string(keywords.clone()))
210 .join(";"),
211 Err(_) => sql.to_owned(),
212 }
213}
214
/// Connection-level settings that `PgProtocol::new` unpacks into the handler.
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS settings; when set, `PgProtocol::new` tries to build an SSL context from them.
    pub tls_config: Option<TlsConfig>,
    /// Keywords used to redact SQL recorded in tracing spans.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Memory accounting and throttling for incoming messages.
    pub message_memory_manager: MessageMemoryManagerRef,
}
221
222impl<S, SM> PgProtocol<S, SM>
223where
224 S: PgByteStream,
225 SM: SessionManager,
226{
227 pub fn new(
228 stream: S,
229 session_mgr: Arc<SM>,
230 peer_addr: AddressRef,
231 context: ConnectionContext,
232 ) -> Self {
233 let ConnectionContext {
234 tls_config,
235 redact_sql_option_keywords,
236 message_memory_manager,
237 } = context;
238 Self {
239 stream: PgStream::new(stream),
240 is_terminate: false,
241 state: PgProtocolState::Startup,
242 session_mgr,
243 session: None,
244 tls_context: tls_config
245 .as_ref()
246 .and_then(|e| build_ssl_ctx_from_config(e).ok()),
247 tls_config,
248 result_cache: Default::default(),
249 unnamed_prepare_statement: Default::default(),
250 prepare_statement_store: Default::default(),
251 unnamed_portal: Default::default(),
252 portal_store: Default::default(),
253 statement_portal_dependency: Default::default(),
254 ignore_util_sync: false,
255 peer_addr,
256 redact_sql_option_keywords,
257 message_memory_manager,
258 }
259 }
260
261 pub async fn run(&mut self) {
263 let mut notice_fut = None;
264
265 loop {
266 if notice_fut.is_none()
268 && let Some(session) = self.session.clone()
269 {
270 let mut stream = self.stream.clone();
271 notice_fut = Some(Box::pin(async move {
272 loop {
273 let notice = session.next_notice().await;
274 if let Err(e) = stream.write(BeMessage::NoticeResponse(¬ice)).await {
275 tracing::error!(error = %e.as_report(), notice, "failed to send notice");
276 }
277 }
278 }));
279 }
280
281 let process = std::pin::pin!(async {
283 let (msg, _memory_guard) = match self.read_message().await {
284 Ok(msg) => msg,
285 Err(e) => {
286 tracing::error!(error = %e.as_report(), "error when reading message");
287 return true; }
289 };
290 tracing::trace!(?msg, "received message");
291 self.process(msg).await
292 });
293
294 let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
295 tokio::select! {
296 _ = notice_fut => unreachable!(),
297 terminated = process => terminated,
298 }
299 } else {
300 process.await
301 };
302
303 if terminated {
304 break;
305 }
306 }
307 }
308
309 pub async fn process(&mut self, msg: FeMessage) -> bool {
311 self.do_process(msg).await.is_none() || self.is_terminate
312 }
313
314 fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
322 let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
323 return tracing::Span::none();
324 };
325
326 let mode = match msg {
327 FeMessage::Query(_) => "simple query",
328 FeMessage::Parse(_) => "extended query parse",
329 FeMessage::Execute(_) => "extended query execute",
330 _ => return tracing::Span::none(),
331 };
332
333 let mut span = tracing::info_span!(
334 target: PGWIRE_ROOT_SPAN_TARGET,
335 "handle_query",
336 mode,
337 session_id,
338 sql = tracing::field::Empty,
339 );
340 if let Ok(sql) = msg.get_sql()
341 && let Some(sql) = sql
342 {
343 record_sql_in_span(sql, self.redact_sql_option_keywords.clone(), &mut span);
344 }
345 span
346 }
347
    /// Handles one message and turns any error into the appropriate client response.
    ///
    /// `do_process_inner` is wrapped in several layers, from innermost to outermost:
    /// 1. a `CURRENT_SESSION` task-local scope holding a weak handle to the session (if any);
    /// 2. panic catching, which converts a panic into `PsqlError::Panic`;
    /// 3. a "slow query" log entry every `SLOW_QUERY_LOG_PERIOD` while still running;
    /// 4. query-log "started"/"ok"/"err" entries (only inside a real span) and error logging;
    /// 5. the root span from `root_span_for_msg`.
    ///
    /// Returns `Some(())` to keep the connection open and `None` to close it.
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        // Boxed so the large inner future is not stored inline in this future.
        let fut = Box::pin(self.do_process_inner(msg));

        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            let mut elapsed = Duration::ZERO;

            loop {
                // Each timeout only logs; the same future keeps being polled until it finishes.
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        let fut = async move {
            // Only emit query-log entries inside a real (non-`none`) root span.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = "started",
                );
            }

            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            if let Err(error) = &result {
                // Debug builds outside CI use the `Debug` form of the error report.
                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
                    tracing::error!(error = ?error.as_report(), "error when process message");
                } else {
                    tracing::error!(error = %error.as_report(), "error when process message");
                }
            }

            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        let fut = fut.instrument(span);

        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    // Client hung up: close quietly. Any other I/O error falls through to the
                    // flush below without an ErrorResponse, and the connection stays open.
                    PsqlError::IoError(io_err) => {
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                    }

                    PsqlError::SslError(_) => {
                        return None;
                    }

                    // Startup/authentication failures are fatal: report and close.
                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: false,
                            })
                            .ok()?;
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    // Simple-query mode: the error must be followed by ReadyForQuery.
                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                            })
                            .ok()?;
                        self.ready_for_query().ok()?;
                    }

                    // Report, then close the connection.
                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                            })
                            .ok()?;
                        let _ = self.stream.flush().await;

                        return None;
                    }

                    // No ReadyForQuery here. In extended-query mode it is sent when the client's
                    // `Sync` arrives (`do_process_inner`). NOTE(review): an error of these kinds
                    // raised on the simple-query path would leave the client without
                    // ReadyForQuery — confirm no such path exists.
                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                            })
                            .ok()?;
                    }
                }
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
524
525 async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
526 if self.ignore_util_sync {
528 if let FeMessage::Sync = msg {
529 } else {
530 tracing::trace!("ignore message {:?} until sync.", msg);
531 return Ok(());
532 }
533 }
534
535 match msg {
536 FeMessage::Gss => self.process_gss_msg().await?,
537 FeMessage::Ssl => self.process_ssl_msg().await?,
538 FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
539 FeMessage::Password(msg) => self.process_password_msg(msg).await?,
540 FeMessage::Query(query_msg) => {
541 let sql = Arc::from(query_msg.get_sql()?);
542 drop(query_msg);
544 self.process_query_msg(sql).await?
545 }
546 FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
547 FeMessage::Terminate => self.process_terminate(),
548 FeMessage::Parse(m) => {
549 if let Err(err) = self.process_parse_msg(m).await {
550 self.ignore_util_sync = true;
551 return Err(err);
552 }
553 }
554 FeMessage::Bind(m) => {
555 if let Err(err) = self.process_bind_msg(m) {
556 self.ignore_util_sync = true;
557 return Err(err);
558 }
559 }
560 FeMessage::Execute(m) => {
561 if let Err(err) = self.process_execute_msg(m).await {
562 self.ignore_util_sync = true;
563 return Err(err);
564 }
565 }
566 FeMessage::Describe(m) => {
567 if let Err(err) = self.process_describe_msg(m) {
568 self.ignore_util_sync = true;
569 return Err(err);
570 }
571 }
572 FeMessage::Sync => {
573 self.ignore_util_sync = false;
574 self.ready_for_query()?
575 }
576 FeMessage::Close(m) => {
577 if let Err(err) = self.process_close_msg(m) {
578 self.ignore_util_sync = true;
579 return Err(err);
580 }
581 }
582 FeMessage::Flush => {
583 if let Err(err) = self.stream.flush().await {
584 self.ignore_util_sync = true;
585 return Err(err.into());
586 }
587 }
588 FeMessage::HealthCheck => self.process_health_check(),
589 FeMessage::ServerThrottle(reason) => match reason {
590 ServerThrottleReason::TooLargeMessage => {
591 return Err(PsqlError::ServerThrottle(format!(
592 "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
593 self.message_memory_manager.max_filter_bytes
594 )));
595 }
596 ServerThrottleReason::TooManyMemoryUsage => {
597 return Err(PsqlError::ServerThrottle(format!(
598 "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
599 self.message_memory_manager.max_running_bytes
600 )));
601 }
602 },
603 }
604 self.stream.flush().await?;
605 Ok(())
606 }
607
608 pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
609 match self.state {
610 PgProtocolState::Startup => self
611 .stream
612 .read_startup()
613 .await
614 .map(|message: FeMessage| (message, None)),
615 PgProtocolState::Regular => {
616 self.stream.read_header().await?;
617 let guard = if let Some(ref header) = self.stream.read_header {
618 let payload_len = std::cmp::max(header.payload_len, 0) as u64;
619 let (reason, guard) = self.message_memory_manager.add(payload_len);
620 if let Some(reason) = reason {
621 drop(guard);
623 self.stream.skip_body().await?;
624 return Ok((FeMessage::ServerThrottle(reason), None));
625 }
626 guard
627 } else {
628 None
629 };
630 let message = self.stream.read_body().await?;
631 Ok((message, guard))
632 }
633 }
634 }
635
636 fn ready_for_query(&mut self) -> io::Result<()> {
638 self.stream.write_no_flush(BeMessage::ReadyForQuery(
639 self.session
640 .as_ref()
641 .map(|s| s.transaction_status())
642 .unwrap_or(TransactionStatus::Idle),
643 ))
644 }
645
646 async fn process_gss_msg(&mut self) -> PsqlResult<()> {
647 self.stream.write(BeMessage::EncryptionResponseNo).await?;
649 Ok(())
650 }
651
652 async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
653 if let Some(context) = self.tls_context.as_ref() {
654 self.stream.write(BeMessage::EncryptionResponseSsl).await?;
657 self.stream.upgrade_to_ssl(context).await?;
658 } else {
659 self.stream.write(BeMessage::EncryptionResponseNo).await?;
661 }
662
663 Ok(())
664 }
665
    /// Handles the startup message: creates the session and starts authentication.
    ///
    /// Steps:
    /// - Rejects the connection if TLS is configured with `enforce_ssl` but this stream was
    ///   not upgraded to SSL.
    /// - Connects to `database` (default `"dev"`) as `user` (default `"root"`).
    /// - Applies `application_name`, if the client sent one.
    /// - With no authenticator, completes the handshake immediately. Otherwise it only asks
    ///   for a password, which `process_password_msg` later verifies.
    ///
    /// # Errors
    /// `PsqlError::StartupError` for the SSL requirement, a failed connect, or config access
    /// failures; I/O errors from buffering responses.
    ///
    /// NOTE(review): the state becomes `Regular` and the session is stored in every branch,
    /// including before a password has been verified. Confirm that later message handlers
    /// cannot be reached without a successful `Password` message.
    async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
        if let Some(ref tls_config) = self.tls_config
            && tls_config.enforce_ssl
            && !self.stream.is_ssl_connection().await
        {
            return Err(PsqlError::StartupError(
                "SSL connection is required but not established".into(),
            ));
        }

        let db_name = msg
            .config
            .get("database")
            .cloned()
            .unwrap_or_else(|| "dev".to_owned());
        let user_name = msg
            .config
            .get("user")
            .cloned()
            .unwrap_or_else(|| "root".to_owned());

        let session = self
            .session_mgr
            .connect(&db_name, &user_name, self.peer_addr.clone())
            .map_err(|e| PsqlError::StartupError(e.into()))?;

        let application_name = msg.config.get("application_name");
        if let Some(application_name) = application_name {
            session
                .set_config("application_name", application_name.clone())
                .map_err(|e| PsqlError::StartupError(e.into()))?;
        }

        match session.user_authenticator() {
            UserAuthenticator::None => {
                self.stream.write_no_flush(BeMessage::AuthenticationOk)?;

                // Lets the client target this session with a later CancelRequest.
                self.stream
                    .write_no_flush(BeMessage::BackendKeyData(session.id()))?;

                self.stream.write_no_flush(BeMessage::ParameterStatus(
                    BeParameterStatusMessage::TimeZone(
                        &session
                            .get_config("timezone")
                            .map_err(|e| PsqlError::StartupError(e.into()))?,
                    ),
                ))?;
                self.stream
                    .write_parameter_status_msg_no_flush(&ParameterStatus {
                        application_name: application_name.cloned(),
                    })?;
                // `self.session` is not set yet, so this reports `Idle`.
                self.ready_for_query()?;
            }
            UserAuthenticator::ClearText(_)
            | UserAuthenticator::OAuth { .. }
            | UserAuthenticator::Ldap(..) => {
                self.stream
                    .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
            }
            UserAuthenticator::Md5WithSalt { salt, .. } => {
                self.stream
                    .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
            }
        }

        self.session = Some(session);
        self.state = PgProtocolState::Regular;
        Ok(())
    }
738
739 async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
740 let session = self.session.as_ref().unwrap();
741 let authenticator = session.user_authenticator();
742 authenticator.authenticate(&msg.password).await?;
743 self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
744 let timezone = session
745 .get_config("timezone")
746 .map_err(|e| PsqlError::StartupError(e.into()))?;
747 self.stream.write_no_flush(BeMessage::ParameterStatus(
748 BeParameterStatusMessage::TimeZone(&timezone),
749 ))?;
750 self.stream
751 .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
752 self.ready_for_query()?;
753 self.state = PgProtocolState::Regular;
754 Ok(())
755 }
756
757 fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
758 let session_id = (m.target_process_id, m.target_secret_key);
759 tracing::trace!("cancel query in session: {:?}", session_id);
760 self.session_mgr.cancel_queries_in_session(session_id);
761 self.session_mgr.cancel_creating_jobs_in_session(session_id);
762 self.is_terminate = true;
763 Ok(())
764 }
765
766 async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
767 let truncated_sql =
768 record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
769 let session = self.session.clone().unwrap();
770
771 session.check_idle_in_transaction_timeout()?;
772 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
774 self.inner_process_query_msg(sql, session.clone()).await
775 }
776
777 async fn inner_process_query_msg(
778 &mut self,
779 sql: Arc<str>,
780 session: Arc<SM::Session>,
781 ) -> PsqlResult<()> {
782 let stmts =
784 Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
785 drop(sql);
787 if stmts.is_empty() {
788 self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
789 }
790
791 for stmt in stmts {
793 self.inner_process_query_msg_one_stmt(stmt, session.clone())
794 .await?;
795 }
796 self.ready_for_query()?;
799 Ok(())
800 }
801
802 async fn inner_process_query_msg_one_stmt(
803 &mut self,
804 stmt: Statement,
805 session: Arc<SM::Session>,
806 ) -> PsqlResult<()> {
807 let session = session.clone();
808
809 let res = session.clone().run_one_query(stmt, Format::Text).await;
811
812 while let Some(notice) = session.next_notice().now_or_never() {
814 self.stream
815 .write_no_flush(BeMessage::NoticeResponse(¬ice))?;
816 }
817
818 let mut res = res.map_err(|e| PsqlError::SimpleQueryError(e.into()))?;
819
820 for notice in res.notices() {
821 self.stream
822 .write_no_flush(BeMessage::NoticeResponse(notice))?;
823 }
824
825 let status = res.status();
826 if let Some(ref application_name) = status.application_name {
827 self.stream.write_no_flush(BeMessage::ParameterStatus(
828 BeParameterStatusMessage::ApplicationName(application_name),
829 ))?;
830 }
831
832 if res.is_copy_query_to_stdout() {
833 self.stream
834 .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
835 let mut count = 0;
836 while let Some(row_set) = res.values_stream().next().await {
837 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
838 for row in row_set {
839 self.stream.write_no_flush(BeMessage::CopyData(&row))?;
840 count += 1;
841 }
842 }
843
844 self.stream.write_no_flush(BeMessage::CopyDone)?;
845
846 res.run_callback().await?;
848
849 self.stream
850 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
851 stmt_type: res.stmt_type(),
852 rows_cnt: count,
853 }))?;
854 } else if res.is_query() {
855 self.stream
856 .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;
857
858 let mut rows_cnt = 0;
859
860 while let Some(row_set) = res.values_stream().next().await {
861 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
862 for row in row_set {
863 self.stream.write_no_flush(BeMessage::DataRow(&row))?;
864 rows_cnt += 1;
865 }
866 }
867
868 res.run_callback().await?;
870
871 self.stream
872 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
873 stmt_type: res.stmt_type(),
874 rows_cnt,
875 }))?;
876 } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
877 let first_row_set = res.values_stream().next().await;
878 let first_row_set = match first_row_set {
879 None => {
880 return Err(PsqlError::Uncategorized(
881 anyhow::anyhow!("no affected rows in output").into(),
882 ));
883 }
884 Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
885 };
886 let affected_rows_str = first_row_set[0].values()[0]
887 .as_ref()
888 .expect("compute node should return affected rows in output");
889
890 assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
891 let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
892 .unwrap()
893 .parse()
894 .unwrap_or_default();
895
896 res.run_callback().await?;
898
899 self.stream
900 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
901 stmt_type: res.stmt_type(),
902 rows_cnt: affected_rows_cnt,
903 }))?;
904 } else {
905 res.run_callback().await?;
907
908 self.stream
909 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
910 stmt_type: res.stmt_type(),
911 rows_cnt: 0,
912 }))?;
913 }
914
915 Ok(())
916 }
917
918 fn process_terminate(&mut self) {
919 self.is_terminate = true;
920 }
921
922 fn process_health_check(&mut self) {
923 tracing::debug!("health check");
924 self.is_terminate = true;
925 }
926
927 async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
928 let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
929 record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
930 let session = self.session.clone().unwrap();
931 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
932 let type_ids = std::mem::take(&mut msg.type_ids);
933 drop(msg);
935 self.inner_process_parse_msg(session, sql, statement_name, type_ids)
936 .await?;
937 Ok(())
938 }
939
940 async fn inner_process_parse_msg(
941 &mut self,
942 session: Arc<SM::Session>,
943 sql: Arc<str>,
944 statement_name: String,
945 type_ids: Vec<i32>,
946 ) -> PsqlResult<()> {
947 if statement_name.is_empty() {
948 self.unnamed_prepare_statement.take();
951 } else if self.prepare_statement_store.contains_key(&statement_name) {
952 return Err(PsqlError::ExtendedPrepareError(
953 "Duplicated statement name".into(),
954 ));
955 }
956
957 let stmt = {
958 let stmts = Parser::parse_sql(&sql)
959 .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
960 drop(sql);
961 if stmts.len() > 1 {
962 return Err(PsqlError::ExtendedPrepareError(
963 "Only one statement is allowed in extended query mode".into(),
964 ));
965 }
966
967 stmts.into_iter().next()
968 };
969
970 let param_types: Vec<Option<DataType>> = type_ids
971 .iter()
972 .map(|&id| {
973 if id == 0 {
976 Ok(None)
977 } else {
978 DataType::from_oid(id)
979 .map(Some)
980 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
981 }
982 })
983 .try_collect()?;
984
985 let prepare_statement = session
986 .parse(stmt, param_types)
987 .await
988 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))?;
989
990 if statement_name.is_empty() {
991 self.unnamed_prepare_statement.replace(prepare_statement);
992 } else {
993 self.prepare_statement_store
994 .insert(statement_name.clone(), prepare_statement);
995 }
996
997 self.statement_portal_dependency
998 .entry(statement_name)
999 .or_default()
1000 .clear();
1001
1002 self.stream.write_no_flush(BeMessage::ParseComplete)?;
1003 Ok(())
1004 }
1005
1006 fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
1007 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
1008 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1009 let session = self.session.clone().unwrap();
1010
1011 if self.portal_store.contains_key(&portal_name) {
1012 return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
1013 }
1014
1015 let prepare_statement = self.get_statement(&statement_name)?;
1016
1017 let result_formats = msg
1018 .result_format_codes
1019 .iter()
1020 .map(|&format_code| Format::from_i16(format_code))
1021 .try_collect()?;
1022 let param_formats = msg
1023 .param_format_codes
1024 .iter()
1025 .map(|&format_code| Format::from_i16(format_code))
1026 .try_collect()?;
1027
1028 let portal = session
1029 .bind(prepare_statement, msg.params, param_formats, result_formats)
1030 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1031
1032 if portal_name.is_empty() {
1033 self.result_cache.remove(&portal_name);
1034 self.unnamed_portal.replace(portal);
1035 } else {
1036 assert!(
1037 !self.result_cache.contains_key(&portal_name),
1038 "Named portal never can be overridden."
1039 );
1040 self.portal_store.insert(portal_name.clone(), portal);
1041 }
1042
1043 self.statement_portal_dependency
1044 .get_mut(&statement_name)
1045 .unwrap()
1046 .push(portal_name);
1047
1048 self.stream.write_no_flush(BeMessage::BindComplete)?;
1049 Ok(())
1050 }
1051
1052 async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1053 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1054 let row_max = msg.max_rows as usize;
1055 drop(msg);
1056 let session = self.session.clone().unwrap();
1057
1058 match self.result_cache.remove(&portal_name) {
1059 Some(mut result_cache) => {
1060 assert!(self.portal_store.contains_key(&portal_name));
1061
1062 let is_consume_completed =
1063 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1064
1065 if !is_consume_completed {
1066 self.result_cache.insert(portal_name, result_cache);
1067 }
1068 }
1069 _ => {
1070 let portal = self.get_portal(&portal_name)?;
1071 let sql = format!("{}", portal);
1072 let truncated_sql =
1073 record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
1074 drop(sql);
1075
1076 session.check_idle_in_transaction_timeout()?;
1077 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1079 let result = session.clone().execute(portal).await;
1080
1081 let pg_response = result.map_err(|e| PsqlError::ExtendedExecuteError(e.into()))?;
1082 let mut result_cache = ResultCache::new(pg_response);
1083 let is_consume_completed =
1084 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1085 if !is_consume_completed {
1086 self.result_cache.insert(portal_name, result_cache);
1087 }
1088 }
1089 }
1090
1091 Ok(())
1092 }
1093
1094 fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1095 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1096 let session = self.session.clone().unwrap();
1097 assert!(msg.kind == b'S' || msg.kind == b'P');
1101 if msg.kind == b'S' {
1102 let prepare_statement = self.get_statement(&name)?;
1103
1104 let (param_types, row_descriptions) = self
1105 .session
1106 .clone()
1107 .unwrap()
1108 .describe_statement(prepare_statement)
1109 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1110 self.stream.write_no_flush(BeMessage::ParameterDescription(
1111 ¶m_types.iter().map(|t| t.to_oid()).collect_vec(),
1112 ))?;
1113
1114 if row_descriptions.is_empty() {
1115 self.stream.write_no_flush(BeMessage::NoData)?;
1118 } else {
1119 self.stream
1120 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1121 }
1122 } else if msg.kind == b'P' {
1123 let portal = self.get_portal(&name)?;
1124
1125 let row_descriptions = session
1126 .describe_portal(portal)
1127 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1128
1129 if row_descriptions.is_empty() {
1130 self.stream.write_no_flush(BeMessage::NoData)?;
1133 } else {
1134 self.stream
1135 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1136 }
1137 }
1138 Ok(())
1139 }
1140
1141 fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1142 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1143 assert!(msg.kind == b'S' || msg.kind == b'P');
1144 if msg.kind == b'S' {
1145 if name.is_empty() {
1146 self.unnamed_prepare_statement = None;
1147 } else {
1148 self.prepare_statement_store.remove(&name);
1149 }
1150 for portal_name in self
1151 .statement_portal_dependency
1152 .remove(&name)
1153 .unwrap_or_default()
1154 {
1155 self.remove_portal(&portal_name);
1156 }
1157 } else if msg.kind == b'P' {
1158 self.remove_portal(&name);
1159 }
1160 self.stream.write_no_flush(BeMessage::CloseComplete)?;
1161 Ok(())
1162 }
1163
1164 fn remove_portal(&mut self, portal_name: &str) {
1165 if portal_name.is_empty() {
1166 self.unnamed_portal = None;
1167 } else {
1168 self.portal_store.remove(portal_name);
1169 }
1170 self.result_cache.remove(portal_name);
1171 }
1172
1173 fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1174 if portal_name.is_empty() {
1175 Ok(self
1176 .unnamed_portal
1177 .as_ref()
1178 .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1179 .clone())
1180 } else {
1181 Ok(self
1182 .portal_store
1183 .get(portal_name)
1184 .ok_or_else(|| {
1185 PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1186 })?
1187 .clone())
1188 }
1189 }
1190
1191 fn get_statement(
1192 &self,
1193 statement_name: &str,
1194 ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1195 if statement_name.is_empty() {
1196 Ok(self
1197 .unnamed_prepare_statement
1198 .as_ref()
1199 .ok_or_else(|| {
1200 PsqlError::Uncategorized("unnamed prepare statement not found".into())
1201 })?
1202 .clone())
1203 } else {
1204 Ok(self
1205 .prepare_statement_store
1206 .get(statement_name)
1207 .ok_or_else(|| {
1208 PsqlError::Uncategorized(
1209 format!("Prepare statement {} not found", statement_name).into(),
1210 )
1211 })?
1212 .clone())
1213 }
1214 }
1215}
1216
/// The transport underneath a [`PgStream`].
enum PgStreamInner<S> {
    /// Temporary value; the readers in this file treat it as unreachable. Presumably it is
    /// swapped in while `upgrade_to_ssl` moves the stream out — TODO confirm.
    Placeholder,
    /// Plain, unencrypted connection.
    Unencrypted(S),
    /// Connection upgraded to TLS.
    Ssl(SslStream<S>),
}
1225
/// Byte stream usable as a client connection: async read + write, `Unpin`, `Send`, `'static`.
pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
// Blanket impl: every type meeting the bounds is a `PgByteStream`.
impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1229
/// Clonable handle to a client connection with its own outgoing buffer.
pub struct PgStream<S> {
    /// Shared transport; every clone of this `PgStream` locks the same connection.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// Outgoing bytes, presumably accumulated by `write_no_flush` until `flush` — TODO confirm.
    write_buf: BytesMut,
    /// Header stored by `read_header` and taken by `read_body`.
    read_header: Option<FeMessageHeader>,
}
1241
1242impl<S> PgStream<S> {
1243 pub fn new(stream: S) -> Self {
1245 const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1246
1247 Self {
1248 stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1249 write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1250 read_header: None,
1251 }
1252 }
1253
1254 async fn is_ssl_connection(&self) -> bool {
1256 let stream = self.stream.lock().await;
1257 matches!(*stream, PgStreamInner::Ssl(_))
1258 }
1259}
1260
1261impl<S> Clone for PgStream<S> {
1262 fn clone(&self) -> Self {
1263 Self {
1264 stream: Arc::clone(&self.stream),
1265 write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1266 read_header: self.read_header.clone(),
1267 }
1268 }
1269}
1270
/// Optional parameter-status values reported to the client, in addition to the fixed ones
/// (client encoding, standard_conforming_strings, server version) that
/// `PgStream::write_parameter_status_msg_no_flush` always writes.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// Sent as an `ApplicationName` parameter status when set; omitted when `None`.
    pub application_name: Option<String>,
}
1291
1292impl<S> PgStream<S>
1293where
1294 S: PgByteStream,
1295{
1296 async fn read_startup(&mut self) -> io::Result<FeMessage> {
1297 let mut stream = self.stream.lock().await;
1298 match &mut *stream {
1299 PgStreamInner::Placeholder => unreachable!(),
1300 PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1301 PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1302 }
1303 }
1304
1305 async fn read_header(&mut self) -> io::Result<()> {
1306 let mut stream = self.stream.lock().await;
1307 match &mut *stream {
1308 PgStreamInner::Placeholder => unreachable!(),
1309 PgStreamInner::Unencrypted(stream) => {
1310 self.read_header = Some(FeMessage::read_header(stream).await?);
1311 Ok(())
1312 }
1313 PgStreamInner::Ssl(ssl_stream) => {
1314 self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1315 Ok(())
1316 }
1317 }
1318 }
1319
1320 async fn read_body(&mut self) -> io::Result<FeMessage> {
1321 let mut stream = self.stream.lock().await;
1322 let header = self
1323 .read_header
1324 .take()
1325 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1326 match &mut *stream {
1327 PgStreamInner::Placeholder => unreachable!(),
1328 PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1329 PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1330 }
1331 }
1332
1333 async fn skip_body(&mut self) -> io::Result<()> {
1334 let mut stream = self.stream.lock().await;
1335 let header = self
1336 .read_header
1337 .take()
1338 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1339 match &mut *stream {
1340 PgStreamInner::Placeholder => unreachable!(),
1341 PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1342 PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1343 }
1344 }
1345
1346 fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1347 self.write_no_flush(BeMessage::ParameterStatus(
1348 BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1349 ))?;
1350 self.write_no_flush(BeMessage::ParameterStatus(
1351 BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1352 ))?;
1353 self.write_no_flush(BeMessage::ParameterStatus(
1354 BeParameterStatusMessage::ServerVersion(PG_VERSION),
1355 ))?;
1356 if let Some(application_name) = &status.application_name {
1357 self.write_no_flush(BeMessage::ParameterStatus(
1358 BeParameterStatusMessage::ApplicationName(application_name),
1359 ))?;
1360 }
1361 Ok(())
1362 }
1363
1364 pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1365 BeMessage::write(&mut self.write_buf, message)
1366 }
1367
1368 async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1369 self.write_no_flush(message)?;
1370 self.flush().await?;
1371 Ok(())
1372 }
1373
1374 async fn flush(&mut self) -> io::Result<()> {
1375 let mut stream = self.stream.lock().await;
1376 match &mut *stream {
1377 PgStreamInner::Placeholder => unreachable!(),
1378 PgStreamInner::Unencrypted(stream) => {
1379 stream.write_all(&self.write_buf).await?;
1380 stream.flush().await?;
1381 }
1382 PgStreamInner::Ssl(ssl_stream) => {
1383 ssl_stream.write_all(&self.write_buf).await?;
1384 ssl_stream.flush().await?;
1385 }
1386 }
1387 self.write_buf.clear();
1388 Ok(())
1389 }
1390}
1391
1392impl<S> PgStream<S>
1393where
1394 S: PgByteStream,
1395{
1396 async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1398 let mut stream = self.stream.lock().await;
1399
1400 match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1401 PgStreamInner::Unencrypted(unencrypted_stream) => {
1402 let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1403 let mut ssl_stream =
1404 tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1405
1406 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1407 tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1408 let _ = ssl_stream.shutdown().await;
1409 return Err(e.into());
1410 }
1411
1412 *stream = PgStreamInner::Ssl(ssl_stream);
1413 }
1414 PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1415 PgStreamInner::Placeholder => unreachable!(),
1416 }
1417
1418 Ok(())
1419 }
1420}
1421
1422fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1423 let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1424
1425 let key_path = &tls_config.key;
1426 let cert_path = &tls_config.cert;
1427
1428 acceptor
1431 .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1432 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1433 acceptor
1434 .set_ca_file(cert_path)
1435 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1436 acceptor
1437 .set_certificate_chain_file(cert_path)
1438 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1439 let acceptor = acceptor.build();
1440
1441 Ok(acceptor.into_context())
1442}
1443
pub mod truncated_fmt {
    use std::fmt::*;

    /// Largest index `<= index` that lies on a UTF-8 character boundary of `s`,
    /// clamped to `s.len()`.
    ///
    /// Stable stand-in for `str::floor_char_boundary`, which is feature-gated
    /// (`round_char_boundary`) on older toolchains. A boundary is always found within
    /// three steps, because UTF-8 sequences are at most four bytes long.
    fn floor_char_boundary(s: &str, index: usize) -> usize {
        if index >= s.len() {
            return s.len();
        }
        (0..=index)
            .rev()
            .find(|&i| s.is_char_boundary(i))
            .unwrap_or(0)
    }

    /// [`Write`] adapter that forwards at most `remaining` bytes to the wrapped
    /// [`Formatter`], cutting only on UTF-8 character boundaries.
    ///
    /// Bytes beyond the budget are discarded but still counted in `total`, so the
    /// truncation marker can report the full length of the formatted value even when
    /// the value is written in several chunks.
    struct TruncatedFormatter<'a, 'b> {
        /// Bytes that may still be forwarded to `f`.
        remaining: usize,
        /// Total bytes produced by the value being formatted, forwarded or not.
        total: usize,
        /// Set once the budget is exhausted; afterwards writes are only counted.
        finished: bool,
        f: &'a mut Formatter<'b>,
    }

    impl Write for TruncatedFormatter<'_, '_> {
        fn write_str(&mut self, s: &str) -> Result {
            self.total += s.len();
            if self.finished {
                return Ok(());
            }

            if s.len() <= self.remaining {
                self.f.write_str(s)?;
                self.remaining -= s.len();
            } else {
                let actual = floor_char_boundary(s, self.remaining);
                self.f.write_str(&s[..actual])?;
                self.remaining -= actual;
                // Stop forwarding; `write_truncated` appends the marker exactly once.
                self.finished = true;
            }
            Ok(())
        }
    }

    /// Formats `args` into `f`, emitting at most `limit` bytes of it.
    ///
    /// If the output was cut, appends `...(truncated,N)`, where `N` is the full length
    /// in bytes of the untruncated output.
    fn write_truncated(f: &mut Formatter<'_>, limit: usize, args: Arguments<'_>) -> Result {
        let mut out = TruncatedFormatter {
            remaining: limit,
            total: 0,
            finished: false,
            f,
        };
        out.write_fmt(args)?;
        if out.finished {
            write!(out.f, "...(truncated,{})", out.total)?;
        }
        Ok(())
    }

    /// Formats the wrapped value (`.0`) but keeps at most `.1` bytes of its output,
    /// followed by `...(truncated,N)` when anything was cut.
    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);

    impl<T> Debug for TruncatedFmt<'_, T>
    where
        T: Debug,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            write_truncated(f, self.1, format_args!("{:?}", self.0))
        }
    }

    impl<T> Display for TruncatedFmt<'_, T>
    where
        T: Display,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            write_truncated(f, self.1, format_args!("{}", self.0))
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            assert_eq!(
                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
                "select '...(truncated,14)",
            );
        }

        #[test]
        fn test_trunc_reports_total_len_across_writes() {
            struct TwoParts;
            impl std::fmt::Display for TwoParts {
                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    f.write_str("hello")?;
                    f.write_str(" world")
                }
            }

            assert_eq!(
                format!("{}", TruncatedFmt(&TwoParts, 7)),
                "hello w...(truncated,11)",
            );
            assert_eq!(format!("{}", TruncatedFmt(&TwoParts, 11)), "hello world");
        }
    }
}
1515
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Options whose keys are in the keyword set (`v2` and `v4` in `WITH (...)`, `b` in the
    /// `ENCODE` options) have their values replaced by `[REDACTED]`; the rest are kept.
    #[test]
    fn test_redact_parsable_sql() {
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])";
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let keywords = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()]));

        let redacted = redact_sql(sql, keywords);
        assert_eq!(redacted, expected);
    }
}