1use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
46use crate::net::AddressRef;
47use crate::pg_extended::ResultCache;
48use crate::pg_message::{
49 BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
50 FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
51 FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
52};
53use crate::pg_server::{Session, SessionManager, UserAuthenticator};
54use crate::types::Format;
55
/// Length limit applied (via `TruncatedFmt`) to SQL text recorded in tracing spans and in the
/// session execution context.
///
/// Read once from `RW_QUERY_LOG_TRUNCATE_LEN`; falls back to 65536 when the variable is unset,
/// not valid Unicode, or does not parse as a `usize`.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    // Parse once instead of validating with one `parse` and re-parsing with `unwrap`.
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        .and_then(|len| len.parse::<usize>().ok())
        .unwrap_or(65536)
});
63
tokio::task_local! {
    /// The session handling the message currently being processed, with its concrete type
    /// erased so different session implementations can share this slot.
    ///
    /// `PgProtocol::do_process` scopes a weak reference to the active session around message
    /// processing; holding it weakly means this slot never keeps the session alive by itself.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
68
/// Server-side handler for one PostgreSQL wire-protocol connection.
///
/// Reads frontend messages from `stream`, dispatches them to the session obtained from
/// `session_mgr`, and writes the backend responses. Simple and extended query protocols are
/// both handled here; extended-query objects (prepared statements, portals, partially
/// consumed results) are tracked per connection.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Connection to the client, plain or TLS-wrapped.
    stream: PgStream<S>,
    /// Selects how the next message is framed (startup packet vs. header + body).
    state: PgProtocolState,
    /// Set by Terminate, CancelRequest and health-check handling; `process` then reports that
    /// the connection should be closed.
    is_terminate: bool,

    session_mgr: Arc<SM>,
    /// Session created by the startup message; `None` before that. Ended in `Drop`.
    session: Option<Arc<SM::Session>>,

    /// Results of `Execute` messages that were not fully consumed, keyed by portal name
    /// (empty string for the unnamed portal).
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    /// The unnamed prepared statement (Parse with an empty statement name).
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    /// Named prepared statements.
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    /// The unnamed portal (Bind with an empty portal name).
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    /// Named portals.
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    /// Statement name -> names of portals bound from it, so closing a statement also removes
    /// its portals.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    /// TLS context built from `tls_config` in `new` (build errors are discarded). `None` makes
    /// SSL requests answer "no encryption".
    tls_context: Option<SslContext>,

    /// TLS settings; consulted at startup to enforce SSL.
    tls_config: Option<TlsConfig>,

    /// Set when an extended-query message fails; every message except Sync is then ignored
    /// until Sync clears it. ("util" is a historical misspelling of "until".)
    ignore_util_sync: bool,

    /// Client address, passed to the session manager when connecting.
    peer_addr: AddressRef,

    /// Option keywords whose values are redacted before SQL is recorded in tracing spans.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts the size of incoming messages and decides when to throttle them.
    message_memory_manager: MessageMemoryManagerRef,
}
111
/// TLS settings for pgwire connections.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Certificate setting, taken from `RW_SSL_CERT` by `new_default`
    /// (presumably a file path — confirm in `build_ssl_ctx_from_config`).
    pub cert: String,
    /// Private-key setting, taken from `RW_SSL_KEY` by `new_default`
    /// (presumably a file path — confirm in `build_ssl_ctx_from_config`).
    pub key: String,
    /// When true, startup is rejected on connections that have not been upgraded to SSL.
    pub enforce_ssl: bool,
}
122
123impl TlsConfig {
124 pub fn new_default() -> Option<Self> {
125 let cert = std::env::var("RW_SSL_CERT").ok()?;
126 let key = std::env::var("RW_SSL_KEY").ok()?;
127 let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
128 tracing::info!(
129 "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
130 cert,
131 key,
132 enforce_ssl
133 );
134 Some(Self {
135 cert,
136 key,
137 enforce_ssl,
138 })
139 }
140}
141
142impl<S, SM> Drop for PgProtocol<S, SM>
143where
144 SM: SessionManager,
145{
146 fn drop(&mut self) {
147 if let Some(session) = &self.session {
148 self.session_mgr.end_session(session);
150 }
151 }
152}
153
/// Framing used to read the next frontend message.
enum PgProtocolState {
    /// Before startup completes; messages are read with `PgStream::read_startup`.
    Startup,
    /// After startup; messages are read as a header followed by a body.
    Regular,
}
159
160pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
164 let without_null = if b.last() == Some(&0) {
165 &b[..b.len() - 1]
166 } else {
167 &b[..]
168 };
169 std::str::from_utf8(without_null)
170}
171
172fn record_sql_in_current_span(
173 sql: &str,
174 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
175) -> String {
176 let mut span = tracing::Span::current();
177 record_sql_in_span(sql, redact_sql_option_keywords, &mut span)
178}
179
180fn record_sql_in_span(
182 sql: &str,
183 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
184 span: &mut tracing::Span,
185) -> String {
186 let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
187 && !keywords.is_empty()
188 {
189 redact_sql(sql, keywords)
190 } else {
191 sql.to_owned()
192 };
193 let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
194 span.record("sql", tracing::field::display(&truncated));
195 truncated.to_string()
196}
197
198fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
200 match Parser::parse_sql(sql) {
201 Ok(sqls) => sqls
202 .into_iter()
203 .map(|sql| sql.to_redacted_string(keywords.clone()))
204 .join(";"),
205 Err(_) => sql.to_owned(),
206 }
207}
208
/// Settings consumed by `PgProtocol::new` when a connection is created.
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS settings; `None` means SSL requests are declined.
    pub tls_config: Option<TlsConfig>,
    /// Option keywords whose values are redacted from SQL recorded in tracing spans.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Memory accounting / throttling for incoming messages.
    pub message_memory_manager: MessageMemoryManagerRef,
}
215
216impl<S, SM> PgProtocol<S, SM>
217where
218 S: PgByteStream,
219 SM: SessionManager,
220{
221 pub fn new(
222 stream: S,
223 session_mgr: Arc<SM>,
224 peer_addr: AddressRef,
225 context: ConnectionContext,
226 ) -> Self {
227 let ConnectionContext {
228 tls_config,
229 redact_sql_option_keywords,
230 message_memory_manager,
231 } = context;
232 Self {
233 stream: PgStream::new(stream),
234 is_terminate: false,
235 state: PgProtocolState::Startup,
236 session_mgr,
237 session: None,
238 tls_context: tls_config
239 .as_ref()
240 .and_then(|e| build_ssl_ctx_from_config(e).ok()),
241 tls_config,
242 result_cache: Default::default(),
243 unnamed_prepare_statement: Default::default(),
244 prepare_statement_store: Default::default(),
245 unnamed_portal: Default::default(),
246 portal_store: Default::default(),
247 statement_portal_dependency: Default::default(),
248 ignore_util_sync: false,
249 peer_addr,
250 redact_sql_option_keywords,
251 message_memory_manager,
252 }
253 }
254
    /// Drives the connection until it should be closed.
    ///
    /// Each loop iteration reads one frontend message and processes it. Once a session exists,
    /// a background future forwards the session's asynchronous notices to the client,
    /// concurrently with message processing. It writes through a clone of `self.stream`, which
    /// shares the underlying connection.
    ///
    /// The loop stops when reading a message fails or when `process` reports termination.
    pub async fn run(&mut self) {
        // Created lazily: notices can only be forwarded after startup has produced a session.
        let mut notice_fut = None;

        loop {
            if notice_fut.is_none()
                && let Some(session) = self.session.clone()
            {
                let mut stream = self.stream.clone();
                notice_fut = Some(Box::pin(async move {
                    // Never completes: keeps forwarding notices for the rest of the connection.
                    loop {
                        let notice = session.next_notice().await;
                        if let Err(e) = stream.write(BeMessage::NoticeResponse(&notice)).await {
                            tracing::error!(error = %e.as_report(), notice, "failed to send notice");
                        }
                    }
                }));
            }

            let process = std::pin::pin!(async {
                // `_memory_guard` stays alive until the message has been processed below.
                let (msg, _memory_guard) = match self.read_message().await {
                    Ok(msg) => msg,
                    Err(e) => {
                        tracing::error!(error = %e.as_report(), "error when reading message");
                        return true;
                    }
                };
                tracing::trace!(?msg, "received message");
                self.process(msg).await
            });

            let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
                // The notice loop never finishes, so only the `process` arm can complete.
                tokio::select! {
                    _ = notice_fut => unreachable!(),
                    terminated = process => terminated,
                }
            } else {
                process.await
            };

            if terminated {
                break;
            }
        }
    }
302
303 pub async fn process(&mut self, msg: FeMessage) -> bool {
305 self.do_process(msg).await.is_none() || self.is_terminate
306 }
307
308 fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
316 let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
317 return tracing::Span::none();
318 };
319
320 let mode = match msg {
321 FeMessage::Query(_) => "simple query",
322 FeMessage::Parse(_) => "extended query parse",
323 FeMessage::Execute(_) => "extended query execute",
324 _ => return tracing::Span::none(),
325 };
326
327 let mut span = tracing::info_span!(
328 target: PGWIRE_ROOT_SPAN_TARGET,
329 "handle_query",
330 mode,
331 session_id,
332 sql = tracing::field::Empty,
333 );
334 if let Ok(sql) = msg.get_sql()
335 && let Some(sql) = sql
336 {
337 record_sql_in_span(sql, self.redact_sql_option_keywords.clone(), &mut span);
338 }
339 span
340 }
341
    /// Processes one message inside several wrapper layers and turns any error into the
    /// appropriate backend response.
    ///
    /// Layers, innermost first:
    /// 1. the message handler (`do_process_inner`);
    /// 2. `CURRENT_SESSION` scoped to a weak reference to the session;
    /// 3. panic catching, which produces `PsqlError::Panic`;
    /// 4. a periodic "slow query" log;
    /// 5. start/finish query logging plus error logging;
    /// 6. the root span from `root_span_for_msg`.
    ///
    /// Returns `None` when the connection should be closed and `Some(())` otherwise. The
    /// connection closes on:
    /// - unexpected EOF or an SSL error;
    /// - startup or password errors;
    /// - idle-in-transaction timeout or a panic;
    /// - a failure to buffer the error response.
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        // The handler future, pinned on the heap.
        let fut = Box::pin(self.do_process_inner(msg));

        // Expose the session to code in this task through `CURRENT_SESSION`, when there is one.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Convert a panic during processing into `PsqlError::Panic` carrying the panic message.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Emit a "slow query" log every `SLOW_QUERY_LOG_PERIOD` while processing is unfinished.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            let mut elapsed = Duration::ZERO;

            loop {
                // `timeout` only borrows `fut`, so a timeout just logs and polls it again.
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Query-log the start and outcome (with elapsed time) and log any error.
        let fut = async move {
            // Query logs are skipped when no span is current.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = "started",
                );
            }

            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            if let Err(error) = &result {
                // Debug-format the report in non-CI debug builds; Display-format it otherwise.
                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
                    tracing::error!(error = ?error.as_report(), "error when process message");
                } else {
                    tracing::error!(error = %error.as_report(), "error when process message");
                }
            }

            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        // Run everything inside the root span.
        let fut = fut.instrument(span);

        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    // The client disconnected: close quietly. Any other I/O error falls through
                    // to the final flush without an ErrorResponse and keeps the connection open.
                    PsqlError::IoError(io_err) => {
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                    }

                    // TLS failure: close without replying.
                    PsqlError::SslError(_) => {
                        return None;
                    }

                    // Startup/authentication failure: report (not pretty-printed), then close.
                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: false,
                            })
                            .ok()?;
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    // The simple query (or throttled message) is over: report the error and
                    // follow it with ReadyForQuery.
                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                            })
                            .ok()?;
                        self.ready_for_query().ok()?;
                    }

                    // Report, then close the connection.
                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                            })
                            .ok()?;
                        let _ = self.stream.flush().await;

                        return None;
                    }

                    // Report only: no ReadyForQuery is sent here. Extended-query clients receive
                    // it when they send Sync.
                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                            })
                            .ok()?;
                    }
                }
                // Flush failures are ignored here; the connection stays open.
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
518
519 async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
520 if self.ignore_util_sync {
522 if let FeMessage::Sync = msg {
523 } else {
524 tracing::trace!("ignore message {:?} until sync.", msg);
525 return Ok(());
526 }
527 }
528
529 match msg {
530 FeMessage::Gss => self.process_gss_msg().await?,
531 FeMessage::Ssl => self.process_ssl_msg().await?,
532 FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
533 FeMessage::Password(msg) => self.process_password_msg(msg).await?,
534 FeMessage::Query(query_msg) => {
535 let sql = Arc::from(query_msg.get_sql()?);
536 drop(query_msg);
538 self.process_query_msg(sql).await?
539 }
540 FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
541 FeMessage::Terminate => self.process_terminate(),
542 FeMessage::Parse(m) => {
543 if let Err(err) = self.process_parse_msg(m).await {
544 self.ignore_util_sync = true;
545 return Err(err);
546 }
547 }
548 FeMessage::Bind(m) => {
549 if let Err(err) = self.process_bind_msg(m) {
550 self.ignore_util_sync = true;
551 return Err(err);
552 }
553 }
554 FeMessage::Execute(m) => {
555 if let Err(err) = self.process_execute_msg(m).await {
556 self.ignore_util_sync = true;
557 return Err(err);
558 }
559 }
560 FeMessage::Describe(m) => {
561 if let Err(err) = self.process_describe_msg(m) {
562 self.ignore_util_sync = true;
563 return Err(err);
564 }
565 }
566 FeMessage::Sync => {
567 self.ignore_util_sync = false;
568 self.ready_for_query()?
569 }
570 FeMessage::Close(m) => {
571 if let Err(err) = self.process_close_msg(m) {
572 self.ignore_util_sync = true;
573 return Err(err);
574 }
575 }
576 FeMessage::Flush => {
577 if let Err(err) = self.stream.flush().await {
578 self.ignore_util_sync = true;
579 return Err(err.into());
580 }
581 }
582 FeMessage::HealthCheck => self.process_health_check(),
583 FeMessage::ServerThrottle(reason) => match reason {
584 ServerThrottleReason::TooLargeMessage => {
585 return Err(PsqlError::ServerThrottle(format!(
586 "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
587 self.message_memory_manager.max_filter_bytes
588 )));
589 }
590 ServerThrottleReason::TooManyMemoryUsage => {
591 return Err(PsqlError::ServerThrottle(format!(
592 "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
593 self.message_memory_manager.max_running_bytes
594 )));
595 }
596 },
597 }
598 self.stream.flush().await?;
599 Ok(())
600 }
601
602 pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
603 match self.state {
604 PgProtocolState::Startup => self
605 .stream
606 .read_startup()
607 .await
608 .map(|message: FeMessage| (message, None)),
609 PgProtocolState::Regular => {
610 self.stream.read_header().await?;
611 let guard = if let Some(ref header) = self.stream.read_header {
612 let payload_len = std::cmp::max(header.payload_len, 0) as u64;
613 let (reason, guard) = self.message_memory_manager.add(payload_len);
614 if let Some(reason) = reason {
615 drop(guard);
617 self.stream.skip_body().await?;
618 return Ok((FeMessage::ServerThrottle(reason), None));
619 }
620 guard
621 } else {
622 None
623 };
624 let message = self.stream.read_body().await?;
625 Ok((message, guard))
626 }
627 }
628 }
629
630 fn ready_for_query(&mut self) -> io::Result<()> {
632 self.stream.write_no_flush(BeMessage::ReadyForQuery(
633 self.session
634 .as_ref()
635 .map(|s| s.transaction_status())
636 .unwrap_or(TransactionStatus::Idle),
637 ))
638 }
639
640 async fn process_gss_msg(&mut self) -> PsqlResult<()> {
641 self.stream.write(BeMessage::EncryptionResponseNo).await?;
643 Ok(())
644 }
645
646 async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
647 if let Some(context) = self.tls_context.as_ref() {
648 self.stream.write(BeMessage::EncryptionResponseSsl).await?;
651 self.stream.upgrade_to_ssl(context).await?;
652 } else {
653 self.stream.write(BeMessage::EncryptionResponseNo).await?;
655 }
656
657 Ok(())
658 }
659
660 async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
661 if let Some(ref tls_config) = self.tls_config
663 && tls_config.enforce_ssl
664 && !self.stream.is_ssl_connection().await
665 {
666 return Err(PsqlError::StartupError(
667 "SSL connection is required but not established".into(),
668 ));
669 }
670
671 let db_name = msg
672 .config
673 .get("database")
674 .cloned()
675 .unwrap_or_else(|| "dev".to_owned());
676 let user_name = msg
677 .config
678 .get("user")
679 .cloned()
680 .unwrap_or_else(|| "root".to_owned());
681
682 let session = self
683 .session_mgr
684 .connect(&db_name, &user_name, self.peer_addr.clone())
685 .map_err(|e| PsqlError::StartupError(e.into()))?;
686
687 let application_name = msg.config.get("application_name");
688 if let Some(application_name) = application_name {
689 session
690 .set_config("application_name", application_name.clone())
691 .map_err(|e| PsqlError::StartupError(e.into()))?;
692 }
693
694 match session.user_authenticator() {
695 UserAuthenticator::None => {
696 self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
697
698 self.stream
701 .write_no_flush(BeMessage::BackendKeyData(session.id()))?;
702
703 self.stream.write_no_flush(BeMessage::ParameterStatus(
704 BeParameterStatusMessage::TimeZone(
705 &session
706 .get_config("timezone")
707 .map_err(|e| PsqlError::StartupError(e.into()))?,
708 ),
709 ))?;
710 self.stream
711 .write_parameter_status_msg_no_flush(&ParameterStatus {
712 application_name: application_name.cloned(),
713 })?;
714 self.ready_for_query()?;
715 }
716 UserAuthenticator::ClearText(_)
717 | UserAuthenticator::OAuth { .. }
718 | UserAuthenticator::Ldap(..) => {
719 self.stream
720 .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
721 }
722 UserAuthenticator::Md5WithSalt { salt, .. } => {
723 self.stream
724 .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
725 }
726 }
727
728 self.session = Some(session);
729 self.state = PgProtocolState::Regular;
730 Ok(())
731 }
732
733 async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
734 let session = self.session.as_ref().unwrap();
735 let authenticator = session.user_authenticator();
736 authenticator.authenticate(&msg.password).await?;
737 self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
738 let timezone = session
739 .get_config("timezone")
740 .map_err(|e| PsqlError::StartupError(e.into()))?;
741 self.stream.write_no_flush(BeMessage::ParameterStatus(
742 BeParameterStatusMessage::TimeZone(&timezone),
743 ))?;
744 self.stream
745 .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
746 self.ready_for_query()?;
747 self.state = PgProtocolState::Regular;
748 Ok(())
749 }
750
751 fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
752 let session_id = (m.target_process_id, m.target_secret_key);
753 tracing::trace!("cancel query in session: {:?}", session_id);
754 self.session_mgr.cancel_queries_in_session(session_id);
755 self.session_mgr.cancel_creating_jobs_in_session(session_id);
756 self.is_terminate = true;
757 Ok(())
758 }
759
760 async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
761 let truncated_sql =
762 record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
763 let session = self.session.clone().unwrap();
764
765 session.check_idle_in_transaction_timeout()?;
766 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
768 self.inner_process_query_msg(sql, session.clone()).await
769 }
770
771 async fn inner_process_query_msg(
772 &mut self,
773 sql: Arc<str>,
774 session: Arc<SM::Session>,
775 ) -> PsqlResult<()> {
776 let stmts =
778 Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
779 drop(sql);
781 if stmts.is_empty() {
782 self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
783 }
784
785 for stmt in stmts {
787 self.inner_process_query_msg_one_stmt(stmt, session.clone())
788 .await?;
789 }
790 self.ready_for_query()?;
793 Ok(())
794 }
795
796 async fn inner_process_query_msg_one_stmt(
797 &mut self,
798 stmt: Statement,
799 session: Arc<SM::Session>,
800 ) -> PsqlResult<()> {
801 let session = session.clone();
802
803 let res = session.clone().run_one_query(stmt, Format::Text).await;
805
806 while let Some(notice) = session.next_notice().now_or_never() {
808 self.stream
809 .write_no_flush(BeMessage::NoticeResponse(¬ice))?;
810 }
811
812 let mut res = res.map_err(|e| PsqlError::SimpleQueryError(e.into()))?;
813
814 for notice in res.notices() {
815 self.stream
816 .write_no_flush(BeMessage::NoticeResponse(notice))?;
817 }
818
819 let status = res.status();
820 if let Some(ref application_name) = status.application_name {
821 self.stream.write_no_flush(BeMessage::ParameterStatus(
822 BeParameterStatusMessage::ApplicationName(application_name),
823 ))?;
824 }
825
826 if res.is_copy_query_to_stdout() {
827 self.stream
828 .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
829 let mut count = 0;
830 while let Some(row_set) = res.values_stream().next().await {
831 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
832 for row in row_set {
833 self.stream.write_no_flush(BeMessage::CopyData(&row))?;
834 count += 1;
835 }
836 }
837
838 self.stream.write_no_flush(BeMessage::CopyDone)?;
839
840 res.run_callback().await?;
842
843 self.stream
844 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
845 stmt_type: res.stmt_type(),
846 rows_cnt: count,
847 }))?;
848 } else if res.is_query() {
849 self.stream
850 .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;
851
852 let mut rows_cnt = 0;
853
854 while let Some(row_set) = res.values_stream().next().await {
855 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
856 for row in row_set {
857 self.stream.write_no_flush(BeMessage::DataRow(&row))?;
858 rows_cnt += 1;
859 }
860 }
861
862 res.run_callback().await?;
864
865 self.stream
866 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
867 stmt_type: res.stmt_type(),
868 rows_cnt,
869 }))?;
870 } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
871 let first_row_set = res.values_stream().next().await;
872 let first_row_set = match first_row_set {
873 None => {
874 return Err(PsqlError::Uncategorized(
875 anyhow::anyhow!("no affected rows in output").into(),
876 ));
877 }
878 Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
879 };
880 let affected_rows_str = first_row_set[0].values()[0]
881 .as_ref()
882 .expect("compute node should return affected rows in output");
883
884 assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
885 let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
886 .unwrap()
887 .parse()
888 .unwrap_or_default();
889
890 res.run_callback().await?;
892
893 self.stream
894 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
895 stmt_type: res.stmt_type(),
896 rows_cnt: affected_rows_cnt,
897 }))?;
898 } else {
899 res.run_callback().await?;
901
902 self.stream
903 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
904 stmt_type: res.stmt_type(),
905 rows_cnt: 0,
906 }))?;
907 }
908
909 Ok(())
910 }
911
912 fn process_terminate(&mut self) {
913 self.is_terminate = true;
914 }
915
916 fn process_health_check(&mut self) {
917 tracing::debug!("health check");
918 self.is_terminate = true;
919 }
920
921 async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
922 let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
923 record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
924 let session = self.session.clone().unwrap();
925 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
926 let type_ids = std::mem::take(&mut msg.type_ids);
927 drop(msg);
929 self.inner_process_parse_msg(session, sql, statement_name, type_ids)
930 .await?;
931 Ok(())
932 }
933
934 async fn inner_process_parse_msg(
935 &mut self,
936 session: Arc<SM::Session>,
937 sql: Arc<str>,
938 statement_name: String,
939 type_ids: Vec<i32>,
940 ) -> PsqlResult<()> {
941 if statement_name.is_empty() {
942 self.unnamed_prepare_statement.take();
945 } else if self.prepare_statement_store.contains_key(&statement_name) {
946 return Err(PsqlError::ExtendedPrepareError(
947 "Duplicated statement name".into(),
948 ));
949 }
950
951 let stmt = {
952 let stmts = Parser::parse_sql(&sql)
953 .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
954 drop(sql);
955 if stmts.len() > 1 {
956 return Err(PsqlError::ExtendedPrepareError(
957 "Only one statement is allowed in extended query mode".into(),
958 ));
959 }
960
961 stmts.into_iter().next()
962 };
963
964 let param_types: Vec<Option<DataType>> = type_ids
965 .iter()
966 .map(|&id| {
967 if id == 0 {
970 Ok(None)
971 } else {
972 DataType::from_oid(id)
973 .map(Some)
974 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
975 }
976 })
977 .try_collect()?;
978
979 let prepare_statement = session
980 .parse(stmt, param_types)
981 .await
982 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))?;
983
984 if statement_name.is_empty() {
985 self.unnamed_prepare_statement.replace(prepare_statement);
986 } else {
987 self.prepare_statement_store
988 .insert(statement_name.clone(), prepare_statement);
989 }
990
991 self.statement_portal_dependency
992 .entry(statement_name)
993 .or_default()
994 .clear();
995
996 self.stream.write_no_flush(BeMessage::ParseComplete)?;
997 Ok(())
998 }
999
1000 fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
1001 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
1002 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1003 let session = self.session.clone().unwrap();
1004
1005 if self.portal_store.contains_key(&portal_name) {
1006 return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
1007 }
1008
1009 let prepare_statement = self.get_statement(&statement_name)?;
1010
1011 let result_formats = msg
1012 .result_format_codes
1013 .iter()
1014 .map(|&format_code| Format::from_i16(format_code))
1015 .try_collect()?;
1016 let param_formats = msg
1017 .param_format_codes
1018 .iter()
1019 .map(|&format_code| Format::from_i16(format_code))
1020 .try_collect()?;
1021
1022 let portal = session
1023 .bind(prepare_statement, msg.params, param_formats, result_formats)
1024 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1025
1026 if portal_name.is_empty() {
1027 self.result_cache.remove(&portal_name);
1028 self.unnamed_portal.replace(portal);
1029 } else {
1030 assert!(
1031 !self.result_cache.contains_key(&portal_name),
1032 "Named portal never can be overridden."
1033 );
1034 self.portal_store.insert(portal_name.clone(), portal);
1035 }
1036
1037 self.statement_portal_dependency
1038 .get_mut(&statement_name)
1039 .unwrap()
1040 .push(portal_name);
1041
1042 self.stream.write_no_flush(BeMessage::BindComplete)?;
1043 Ok(())
1044 }
1045
1046 async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1047 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1048 let row_max = msg.max_rows as usize;
1049 drop(msg);
1050 let session = self.session.clone().unwrap();
1051
1052 match self.result_cache.remove(&portal_name) {
1053 Some(mut result_cache) => {
1054 assert!(self.portal_store.contains_key(&portal_name));
1055
1056 let is_consume_completed =
1057 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1058
1059 if !is_consume_completed {
1060 self.result_cache.insert(portal_name, result_cache);
1061 }
1062 }
1063 _ => {
1064 let portal = self.get_portal(&portal_name)?;
1065 let sql = format!("{}", portal);
1066 let truncated_sql =
1067 record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
1068 drop(sql);
1069
1070 session.check_idle_in_transaction_timeout()?;
1071 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1073 let result = session.clone().execute(portal).await;
1074
1075 let pg_response = result.map_err(|e| PsqlError::ExtendedExecuteError(e.into()))?;
1076 let mut result_cache = ResultCache::new(pg_response);
1077 let is_consume_completed =
1078 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1079 if !is_consume_completed {
1080 self.result_cache.insert(portal_name, result_cache);
1081 }
1082 }
1083 }
1084
1085 Ok(())
1086 }
1087
1088 fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1089 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1090 let session = self.session.clone().unwrap();
1091 assert!(msg.kind == b'S' || msg.kind == b'P');
1095 if msg.kind == b'S' {
1096 let prepare_statement = self.get_statement(&name)?;
1097
1098 let (param_types, row_descriptions) = self
1099 .session
1100 .clone()
1101 .unwrap()
1102 .describe_statement(prepare_statement)
1103 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1104 self.stream.write_no_flush(BeMessage::ParameterDescription(
1105 ¶m_types.iter().map(|t| t.to_oid()).collect_vec(),
1106 ))?;
1107
1108 if row_descriptions.is_empty() {
1109 self.stream.write_no_flush(BeMessage::NoData)?;
1112 } else {
1113 self.stream
1114 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1115 }
1116 } else if msg.kind == b'P' {
1117 let portal = self.get_portal(&name)?;
1118
1119 let row_descriptions = session
1120 .describe_portal(portal)
1121 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1122
1123 if row_descriptions.is_empty() {
1124 self.stream.write_no_flush(BeMessage::NoData)?;
1127 } else {
1128 self.stream
1129 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1130 }
1131 }
1132 Ok(())
1133 }
1134
1135 fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1136 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1137 assert!(msg.kind == b'S' || msg.kind == b'P');
1138 if msg.kind == b'S' {
1139 if name.is_empty() {
1140 self.unnamed_prepare_statement = None;
1141 } else {
1142 self.prepare_statement_store.remove(&name);
1143 }
1144 for portal_name in self
1145 .statement_portal_dependency
1146 .remove(&name)
1147 .unwrap_or_default()
1148 {
1149 self.remove_portal(&portal_name);
1150 }
1151 } else if msg.kind == b'P' {
1152 self.remove_portal(&name);
1153 }
1154 self.stream.write_no_flush(BeMessage::CloseComplete)?;
1155 Ok(())
1156 }
1157
1158 fn remove_portal(&mut self, portal_name: &str) {
1159 if portal_name.is_empty() {
1160 self.unnamed_portal = None;
1161 } else {
1162 self.portal_store.remove(portal_name);
1163 }
1164 self.result_cache.remove(portal_name);
1165 }
1166
1167 fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1168 if portal_name.is_empty() {
1169 Ok(self
1170 .unnamed_portal
1171 .as_ref()
1172 .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1173 .clone())
1174 } else {
1175 Ok(self
1176 .portal_store
1177 .get(portal_name)
1178 .ok_or_else(|| {
1179 PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1180 })?
1181 .clone())
1182 }
1183 }
1184
1185 fn get_statement(
1186 &self,
1187 statement_name: &str,
1188 ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1189 if statement_name.is_empty() {
1190 Ok(self
1191 .unnamed_prepare_statement
1192 .as_ref()
1193 .ok_or_else(|| {
1194 PsqlError::Uncategorized("unnamed prepare statement not found".into())
1195 })?
1196 .clone())
1197 } else {
1198 Ok(self
1199 .prepare_statement_store
1200 .get(statement_name)
1201 .ok_or_else(|| {
1202 PsqlError::Uncategorized(
1203 format!("Prepare statement {} not found", statement_name).into(),
1204 )
1205 })?
1206 .clone())
1207 }
1208 }
1209}
1210
/// The client connection shared by all clones of a `PgStream`.
enum PgStreamInner<S> {
    /// Stand-in while the inner stream is temporarily moved out (presumably during the TLS
    /// upgrade — confirm in `upgrade_to_ssl`). The read paths treat it as unreachable.
    Placeholder,
    /// Plain, unencrypted connection.
    Unencrypted(S),
    /// Connection upgraded to TLS.
    Ssl(SslStream<S>),
}
1219
1220pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1222impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1223
/// Cloneable handle to the client connection.
///
/// Clones share the underlying connection behind an async mutex, but each clone owns its own
/// write buffer.
pub struct PgStream<S> {
    /// Connection shared by every clone; the mutex is locked for each read.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// This handle's output buffer. A clone starts empty with the same capacity.
    write_buf: BytesMut,
    /// Header produced by `read_header` and taken by `read_body`.
    read_header: Option<FeMessageHeader>,
}
1235
1236impl<S> PgStream<S> {
1237 pub fn new(stream: S) -> Self {
1239 const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1240
1241 Self {
1242 stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1243 write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1244 read_header: None,
1245 }
1246 }
1247
1248 async fn is_ssl_connection(&self) -> bool {
1250 let stream = self.stream.lock().await;
1251 matches!(*stream, PgStreamInner::Ssl(_))
1252 }
1253}
1254
1255impl<S> Clone for PgStream<S> {
1256 fn clone(&self) -> Self {
1257 Self {
1258 stream: Arc::clone(&self.stream),
1259 write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1260 read_header: self.read_header.clone(),
1261 }
1262 }
1263}
1264
/// Per-connection values emitted as `ParameterStatus` messages, alongside the fixed server
/// parameters (client encoding, standard_conforming_strings, server version).
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// Sent as the `ApplicationName` parameter when set; nothing is sent when `None`.
    pub application_name: Option<String>,
}
1285
1286impl<S> PgStream<S>
1287where
1288 S: PgByteStream,
1289{
1290 async fn read_startup(&mut self) -> io::Result<FeMessage> {
1291 let mut stream = self.stream.lock().await;
1292 match &mut *stream {
1293 PgStreamInner::Placeholder => unreachable!(),
1294 PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1295 PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1296 }
1297 }
1298
1299 async fn read_header(&mut self) -> io::Result<()> {
1300 let mut stream = self.stream.lock().await;
1301 match &mut *stream {
1302 PgStreamInner::Placeholder => unreachable!(),
1303 PgStreamInner::Unencrypted(stream) => {
1304 self.read_header = Some(FeMessage::read_header(stream).await?);
1305 Ok(())
1306 }
1307 PgStreamInner::Ssl(ssl_stream) => {
1308 self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1309 Ok(())
1310 }
1311 }
1312 }
1313
1314 async fn read_body(&mut self) -> io::Result<FeMessage> {
1315 let mut stream = self.stream.lock().await;
1316 let header = self
1317 .read_header
1318 .take()
1319 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1320 match &mut *stream {
1321 PgStreamInner::Placeholder => unreachable!(),
1322 PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1323 PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1324 }
1325 }
1326
1327 async fn skip_body(&mut self) -> io::Result<()> {
1328 let mut stream = self.stream.lock().await;
1329 let header = self
1330 .read_header
1331 .take()
1332 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1333 match &mut *stream {
1334 PgStreamInner::Placeholder => unreachable!(),
1335 PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1336 PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1337 }
1338 }
1339
1340 fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1341 self.write_no_flush(BeMessage::ParameterStatus(
1342 BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1343 ))?;
1344 self.write_no_flush(BeMessage::ParameterStatus(
1345 BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1346 ))?;
1347 self.write_no_flush(BeMessage::ParameterStatus(
1348 BeParameterStatusMessage::ServerVersion(PG_VERSION),
1349 ))?;
1350 if let Some(application_name) = &status.application_name {
1351 self.write_no_flush(BeMessage::ParameterStatus(
1352 BeParameterStatusMessage::ApplicationName(application_name),
1353 ))?;
1354 }
1355 Ok(())
1356 }
1357
1358 pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1359 BeMessage::write(&mut self.write_buf, message)
1360 }
1361
1362 async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1363 self.write_no_flush(message)?;
1364 self.flush().await?;
1365 Ok(())
1366 }
1367
1368 async fn flush(&mut self) -> io::Result<()> {
1369 let mut stream = self.stream.lock().await;
1370 match &mut *stream {
1371 PgStreamInner::Placeholder => unreachable!(),
1372 PgStreamInner::Unencrypted(stream) => {
1373 stream.write_all(&self.write_buf).await?;
1374 stream.flush().await?;
1375 }
1376 PgStreamInner::Ssl(ssl_stream) => {
1377 ssl_stream.write_all(&self.write_buf).await?;
1378 ssl_stream.flush().await?;
1379 }
1380 }
1381 self.write_buf.clear();
1382 Ok(())
1383 }
1384}
1385
1386impl<S> PgStream<S>
1387where
1388 S: PgByteStream,
1389{
1390 async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1392 let mut stream = self.stream.lock().await;
1393
1394 match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1395 PgStreamInner::Unencrypted(unencrypted_stream) => {
1396 let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1397 let mut ssl_stream =
1398 tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1399
1400 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1401 tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1402 let _ = ssl_stream.shutdown().await;
1403 return Err(e.into());
1404 }
1405
1406 *stream = PgStreamInner::Ssl(ssl_stream);
1407 }
1408 PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1409 PgStreamInner::Placeholder => unreachable!(),
1410 }
1411
1412 Ok(())
1413 }
1414}
1415
1416fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1417 let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1418
1419 let key_path = &tls_config.key;
1420 let cert_path = &tls_config.cert;
1421
1422 acceptor
1425 .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1426 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1427 acceptor
1428 .set_ca_file(cert_path)
1429 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1430 acceptor
1431 .set_certificate_chain_file(cert_path)
1432 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1433 let acceptor = acceptor.build();
1434
1435 Ok(acceptor.into_context())
1436}
1437
pub mod truncated_fmt {
    //! Formatting adapters that cap how many bytes a value renders to.

    use std::fmt::*;

    /// Returns the largest char boundary of `s` that is `<= index`, clamped to `s.len()`.
    ///
    /// This is a stable stand-in for `str::floor_char_boundary`, which is not available on every
    /// toolchain.
    fn floor_char_boundary(s: &str, index: usize) -> usize {
        if index >= s.len() {
            return s.len();
        }
        // Offset 0 is always a boundary, so the search cannot fail. Because a UTF-8 scalar is
        // at most 4 bytes, it looks back at most 3 positions.
        (0..=index)
            .rev()
            .find(|&i| s.is_char_boundary(i))
            .unwrap_or(0)
    }

    /// A [`Write`] sink that forwards at most `remaining` bytes to `f`.
    ///
    /// After the budget is exceeded it appends a single `...(truncated,N)` marker and discards
    /// all further output.
    struct TruncatedFormatter<'a, 'b> {
        /// Bytes that may still be forwarded before truncation.
        remaining: usize,
        /// Set once the marker has been written; later writes are ignored.
        finished: bool,
        f: &'a mut Formatter<'b>,
    }

    impl Write for TruncatedFormatter<'_, '_> {
        fn write_str(&mut self, s: &str) -> Result {
            if self.finished {
                return Ok(());
            }

            if s.len() <= self.remaining {
                self.f.write_str(s)?;
                self.remaining -= s.len();
                return Ok(());
            }

            // Cut on a char boundary so that a multi-byte character is never split.
            let keep = floor_char_boundary(s, self.remaining);
            self.f.write_str(&s[..keep])?;
            // N is the byte length of the chunk that overflowed the budget. A plain string
            // formatted with `{}` arrives as a single chunk, so there N is the value's full
            // length (see `test_trunc_utf8`).
            write!(self.f, "...(truncated,{})", s.len())?;
            self.finished = true;
            Ok(())
        }
    }

    /// Wraps `&T` so that its `Debug`/`Display` output is limited to `.1` bytes.
    ///
    /// The cut is rounded down to a char boundary and followed by a `...(truncated,N)` marker.
    /// `T` may be unsized, e.g. `TruncatedFmt(s, 100)` with `s: &str`.
    pub struct TruncatedFmt<'a, T: ?Sized>(pub &'a T, pub usize);

    impl<T> Debug for TruncatedFmt<'_, T>
    where
        T: Debug + ?Sized,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            TruncatedFormatter {
                remaining: self.1,
                finished: false,
                f,
            }
            .write_fmt(format_args!("{:?}", self.0))
        }
    }

    impl<T> Display for TruncatedFmt<'_, T>
    where
        T: Display + ?Sized,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            TruncatedFormatter {
                remaining: self.1,
                finished: false,
                f,
            }
            .write_fmt(format_args!("{}", self.0))
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            assert_eq!(
                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
                "select '...(truncated,14)",
            );
        }

        #[test]
        fn test_trunc_unsized_and_exact_fit() {
            let s: &str = "hello";
            assert_eq!(format!("{}", TruncatedFmt(s, 5)), "hello");
            assert_eq!(format!("{}", TruncatedFmt(s, 4)), "hell...(truncated,5)");
        }
    }
}
1509
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Values of the listed option keys must be masked in the re-rendered SQL.
    /// Keys not in the list must be kept verbatim.
    #[test]
    fn test_redact_parsable_sql() {
        let sensitive_keys: HashSet<_> = ["v2", "v4", "b"].into_iter().map(Into::into).collect();
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";

        let redacted = redact_sql(sql, Arc::new(sensitive_keys));
        assert_eq!(
            redacted,
            "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])"
        );
    }
}