1use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
46use crate::net::AddressRef;
47use crate::pg_extended::ResultCache;
48use crate::pg_message::{
49 BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
50 FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
51 FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
52};
53use crate::pg_server::{Session, SessionManager, UserAuthenticator};
54use crate::types::Format;
55
/// Maximum number of characters of SQL text recorded in tracing spans and query logs.
///
/// Read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. When the variable is
/// unset, not valid Unicode, or not a valid `usize`, it defaults to 65536 in builds with debug
/// assertions and 1024 otherwise.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        // Parse once; an unparsable value falls back to the default instead of panicking.
        .and_then(|len| len.parse::<usize>().ok())
        .unwrap_or(if cfg!(debug_assertions) { 65536 } else { 1024 })
});
69
tokio::task_local! {
    /// Weak, type-erased handle to the session whose message is currently being processed.
    ///
    /// `PgProtocol::do_process` sets it (only when a session exists) for the duration of
    /// handling one frontend message. The concrete session type is erased to `dyn Any` so
    /// code that does not know `SM::Session` can still reach it.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
74
/// Handles one PostgreSQL wire-protocol connection.
///
/// It reads frontend messages from `stream`, dispatches them to the session created through
/// `session_mgr`, and writes the backend replies.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Stream used to read frontend messages and write backend messages.
    stream: PgStream<S>,
    /// Current phase of the connection; decides how the next message is read.
    state: PgProtocolState,
    /// Set when the connection should be closed after the current message
    /// (terminate, cancel request or health check).
    is_terminate: bool,

    session_mgr: Arc<SM>,
    /// Session created while handling the startup message; `None` before that.
    session: Option<Arc<SM::Session>>,

    /// Partially consumed execution results keyed by portal name.
    ///
    /// A later `Execute` on the same portal continues fetching rows from here.
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    /// Prepared statement created by a `Parse` with an empty name.
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    /// Prepared statements created by a `Parse` with a non-empty name.
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    /// Portal created by a `Bind` with an empty portal name.
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    /// Portals created by a `Bind` with a non-empty portal name.
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    /// Statement name -> names of the portals bound from it.
    ///
    /// Closing a statement also closes these portals.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    /// SSL context used to upgrade the connection on an SSL request.
    ///
    /// `None` when TLS is not configured or the context could not be built.
    tls_context: Option<SslContext>,

    /// TLS settings; `enforce_ssl` is checked while handling the startup message.
    tls_config: Option<TlsConfig>,

    /// Set after an error in an extended-query message.
    ///
    /// All following messages are skipped until the next `Sync`. (Name keeps the historical
    /// "util" spelling of "until".)
    ignore_util_sync: bool,

    /// Address of the client, passed to the session manager when connecting.
    peer_addr: AddressRef,

    /// Option keywords passed to SQL redaction before SQL is recorded in tracing spans.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts for the payload size of incoming messages.
    ///
    /// When it reports a throttle reason, the message body is skipped and the client receives
    /// an error.
    message_memory_manager: MessageMemoryManagerRef,
}
117
/// TLS settings for accepting client connections.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Certificate setting, taken from `RW_SSL_CERT` by `new_default`.
    ///
    /// NOTE(review): presumably a file path — confirm in `build_ssl_ctx_from_config`.
    pub cert: String,
    /// Private-key setting, taken from `RW_SSL_KEY` by `new_default`.
    ///
    /// NOTE(review): presumably a file path — confirm in `build_ssl_ctx_from_config`.
    pub key: String,
    /// When true, the startup message is rejected unless the connection has been upgraded to
    /// SSL.
    pub enforce_ssl: bool,
}
128
129impl TlsConfig {
130 pub fn new_default() -> Option<Self> {
131 let cert = std::env::var("RW_SSL_CERT").ok()?;
132 let key = std::env::var("RW_SSL_KEY").ok()?;
133 let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
134 tracing::info!(
135 "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
136 cert,
137 key,
138 enforce_ssl
139 );
140 Some(Self {
141 cert,
142 key,
143 enforce_ssl,
144 })
145 }
146}
147
148impl<S, SM> Drop for PgProtocol<S, SM>
149where
150 SM: SessionManager,
151{
152 fn drop(&mut self) {
153 if let Some(session) = &self.session {
154 self.session_mgr.end_session(session);
156 }
157 }
158}
159
/// Phase of a connection, which decides how `PgProtocol::read_message` reads the next message.
enum PgProtocolState {
    /// Before the startup message has been handled; messages are read with
    /// `PgStream::read_startup`.
    Startup,
    /// After the startup message; messages are read as a header (`read_header`) followed by a
    /// body (`read_body`), or the body is skipped when throttled.
    Regular,
}
165
166pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
170 let without_null = if b.last() == Some(&0) {
171 &b[..b.len() - 1]
172 } else {
173 &b[..]
174 };
175 std::str::from_utf8(without_null)
176}
177
178fn record_sql_in_span(
180 sql: &str,
181 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
182) -> String {
183 let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
184 && !keywords.is_empty()
185 {
186 redact_sql(sql, keywords)
187 } else {
188 sql.to_owned()
189 };
190 let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
191 tracing::Span::current().record("sql", tracing::field::display(&truncated));
192 truncated.to_string()
193}
194
195fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
197 match Parser::parse_sql(sql) {
198 Ok(sqls) => sqls
199 .into_iter()
200 .map(|sql| sql.to_redacted_string(keywords.clone()))
201 .join(";"),
202 Err(_) => sql.to_owned(),
203 }
204}
205
/// Per-server settings handed to every new `PgProtocol` connection.
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS settings; when present, an SSL context is built for each connection.
    pub tls_config: Option<TlsConfig>,
    /// Option keywords used to redact SQL before it is recorded in tracing spans.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Memory accounting / throttling for incoming message payloads.
    pub message_memory_manager: MessageMemoryManagerRef,
}
212
213impl<S, SM> PgProtocol<S, SM>
214where
215 S: PgByteStream,
216 SM: SessionManager,
217{
218 pub fn new(
219 stream: S,
220 session_mgr: Arc<SM>,
221 peer_addr: AddressRef,
222 context: ConnectionContext,
223 ) -> Self {
224 let ConnectionContext {
225 tls_config,
226 redact_sql_option_keywords,
227 message_memory_manager,
228 } = context;
229 Self {
230 stream: PgStream::new(stream),
231 is_terminate: false,
232 state: PgProtocolState::Startup,
233 session_mgr,
234 session: None,
235 tls_context: tls_config
236 .as_ref()
237 .and_then(|e| build_ssl_ctx_from_config(e).ok()),
238 tls_config: tls_config.clone(),
239 result_cache: Default::default(),
240 unnamed_prepare_statement: Default::default(),
241 prepare_statement_store: Default::default(),
242 unnamed_portal: Default::default(),
243 portal_store: Default::default(),
244 statement_portal_dependency: Default::default(),
245 ignore_util_sync: false,
246 peer_addr,
247 redact_sql_option_keywords,
248 message_memory_manager,
249 }
250 }
251
252 pub async fn run(&mut self) {
254 let mut notice_fut = None;
255
256 loop {
257 if notice_fut.is_none()
259 && let Some(session) = self.session.clone()
260 {
261 let mut stream = self.stream.clone();
262 notice_fut = Some(Box::pin(async move {
263 loop {
264 let notice = session.next_notice().await;
265 if let Err(e) = stream.write(BeMessage::NoticeResponse(¬ice)).await {
266 tracing::error!(error = %e.as_report(), notice, "failed to send notice");
267 }
268 }
269 }));
270 }
271
272 let process = std::pin::pin!(async {
274 let (msg, _memory_guard) = match self.read_message().await {
275 Ok(msg) => msg,
276 Err(e) => {
277 tracing::error!(error = %e.as_report(), "error when reading message");
278 return true; }
280 };
281 tracing::trace!(?msg, "received message");
282 self.process(msg).await
283 });
284
285 let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
286 tokio::select! {
287 _ = notice_fut => unreachable!(),
288 terminated = process => terminated,
289 }
290 } else {
291 process.await
292 };
293
294 if terminated {
295 break;
296 }
297 }
298 }
299
300 pub async fn process(&mut self, msg: FeMessage) -> bool {
302 self.do_process(msg).await.is_none() || self.is_terminate
303 }
304
305 fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
313 let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
314 return tracing::Span::none();
315 };
316
317 let mode = match msg {
318 FeMessage::Query(_) => "simple query",
319 FeMessage::Parse(_) => "extended query parse",
320 FeMessage::Execute(_) => "extended query execute",
321 _ => return tracing::Span::none(),
322 };
323
324 tracing::info_span!(
325 target: PGWIRE_ROOT_SPAN_TARGET,
326 "handle_query",
327 mode,
328 session_id,
329 sql = tracing::field::Empty, )
331 }
332
333 async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
337 let span = self.root_span_for_msg(&msg);
338 let weak_session = self
339 .session
340 .as_ref()
341 .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);
342
343 let fut = Box::pin(self.do_process_inner(msg));
348
349 let fut = async move {
351 if let Some(session) = weak_session {
352 CURRENT_SESSION.scope(session, fut).await
353 } else {
354 fut.await
355 }
356 };
357
358 let fut = async move {
360 AssertUnwindSafe(fut)
361 .rw_catch_unwind()
362 .await
363 .unwrap_or_else(|payload| {
364 Err(PsqlError::Panic(
365 panic_message::panic_message(&payload).to_owned(),
366 ))
367 })
368 };
369
370 let fut = async move {
372 let period = *SLOW_QUERY_LOG_PERIOD;
373 let mut fut = std::pin::pin!(fut);
374 let mut elapsed = Duration::ZERO;
375
376 loop {
378 match tokio::time::timeout(period, &mut fut).await {
379 Ok(result) => break result,
380 Err(_) => {
381 elapsed += period;
382 tracing::info!(
383 target: PGWIRE_SLOW_QUERY_LOG,
384 elapsed = %format_args!("{}ms", elapsed.as_millis()),
385 "slow query"
386 );
387 }
388 }
389 }
390 };
391
392 let fut = async move {
394 let start = Instant::now();
395 let result = fut.await;
396 let elapsed = start.elapsed();
397
398 if let Err(error) = &result {
402 if cfg!(debug_assertions) && !Deployment::current().is_ci() {
403 tracing::error!(error = ?error.as_report(), "error when process message");
409 } else {
410 tracing::error!(error = %error.as_report(), "error when process message");
411 }
412 }
413
414 if !tracing::Span::current().is_none() {
417 tracing::info!(
418 target: PGWIRE_QUERY_LOG,
419 status = if result.is_ok() { "ok" } else { "err" },
420 time = %format_args!("{}ms", elapsed.as_millis()),
421 );
422 }
423
424 result
425 };
426
427 let fut = fut.instrument(span);
429
430 match fut.await {
432 Ok(()) => Some(()),
433 Err(e) => {
434 match e {
435 PsqlError::IoError(io_err) => {
436 if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
437 return None;
438 }
439 }
440
441 PsqlError::SslError(_) => {
442 return None;
445 }
446
447 PsqlError::StartupError(_) | PsqlError::PasswordError => {
448 self.stream
449 .write_no_flush(BeMessage::ErrorResponse(&e))
450 .ok()?;
451 let _ = self.stream.flush().await;
452 return None;
453 }
454
455 PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
456 self.stream
457 .write_no_flush(BeMessage::ErrorResponse(&e))
458 .ok()?;
459 self.ready_for_query().ok()?;
460 }
461
462 PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
463 self.stream
464 .write_no_flush(BeMessage::ErrorResponse(&e))
465 .ok()?;
466 let _ = self.stream.flush().await;
467
468 return None;
473 }
474
475 PsqlError::Uncategorized(_)
476 | PsqlError::ExtendedPrepareError(_)
477 | PsqlError::ExtendedExecuteError(_) => {
478 self.stream
479 .write_no_flush(BeMessage::ErrorResponse(&e))
480 .ok()?;
481 }
482 }
483 let _ = self.stream.flush().await;
484 Some(())
485 }
486 }
487 }
488
489 async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
490 if self.ignore_util_sync {
492 if let FeMessage::Sync = msg {
493 } else {
494 tracing::trace!("ignore message {:?} until sync.", msg);
495 return Ok(());
496 }
497 }
498
499 match msg {
500 FeMessage::Gss => self.process_gss_msg().await?,
501 FeMessage::Ssl => self.process_ssl_msg().await?,
502 FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
503 FeMessage::Password(msg) => self.process_password_msg(msg).await?,
504 FeMessage::Query(query_msg) => {
505 let sql = Arc::from(query_msg.get_sql()?);
506 drop(query_msg);
508 self.process_query_msg(sql).await?
509 }
510 FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
511 FeMessage::Terminate => self.process_terminate(),
512 FeMessage::Parse(m) => {
513 if let Err(err) = self.process_parse_msg(m).await {
514 self.ignore_util_sync = true;
515 return Err(err);
516 }
517 }
518 FeMessage::Bind(m) => {
519 if let Err(err) = self.process_bind_msg(m) {
520 self.ignore_util_sync = true;
521 return Err(err);
522 }
523 }
524 FeMessage::Execute(m) => {
525 if let Err(err) = self.process_execute_msg(m).await {
526 self.ignore_util_sync = true;
527 return Err(err);
528 }
529 }
530 FeMessage::Describe(m) => {
531 if let Err(err) = self.process_describe_msg(m) {
532 self.ignore_util_sync = true;
533 return Err(err);
534 }
535 }
536 FeMessage::Sync => {
537 self.ignore_util_sync = false;
538 self.ready_for_query()?
539 }
540 FeMessage::Close(m) => {
541 if let Err(err) = self.process_close_msg(m) {
542 self.ignore_util_sync = true;
543 return Err(err);
544 }
545 }
546 FeMessage::Flush => {
547 if let Err(err) = self.stream.flush().await {
548 self.ignore_util_sync = true;
549 return Err(err.into());
550 }
551 }
552 FeMessage::HealthCheck => self.process_health_check(),
553 FeMessage::ServerThrottle(reason) => match reason {
554 ServerThrottleReason::TooLargeMessage => {
555 return Err(PsqlError::ServerThrottle(format!(
556 "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
557 self.message_memory_manager.max_filter_bytes
558 )));
559 }
560 ServerThrottleReason::TooManyMemoryUsage => {
561 return Err(PsqlError::ServerThrottle(format!(
562 "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
563 self.message_memory_manager.max_running_bytes
564 )));
565 }
566 },
567 }
568 self.stream.flush().await?;
569 Ok(())
570 }
571
572 pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
573 match self.state {
574 PgProtocolState::Startup => self
575 .stream
576 .read_startup()
577 .await
578 .map(|message: FeMessage| (message, None)),
579 PgProtocolState::Regular => {
580 self.stream.read_header().await?;
581 let guard = if let Some(ref header) = self.stream.read_header {
582 let payload_len = std::cmp::max(header.payload_len, 0) as u64;
583 let (reason, guard) = self.message_memory_manager.add(payload_len);
584 if let Some(reason) = reason {
585 drop(guard);
587 self.stream.skip_body().await?;
588 return Ok((FeMessage::ServerThrottle(reason), None));
589 }
590 guard
591 } else {
592 None
593 };
594 let message = self.stream.read_body().await?;
595 Ok((message, guard))
596 }
597 }
598 }
599
600 fn ready_for_query(&mut self) -> io::Result<()> {
602 self.stream.write_no_flush(BeMessage::ReadyForQuery(
603 self.session
604 .as_ref()
605 .map(|s| s.transaction_status())
606 .unwrap_or(TransactionStatus::Idle),
607 ))
608 }
609
610 async fn process_gss_msg(&mut self) -> PsqlResult<()> {
611 self.stream.write(BeMessage::EncryptionResponseNo).await?;
613 Ok(())
614 }
615
616 async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
617 if let Some(context) = self.tls_context.as_ref() {
618 self.stream.write(BeMessage::EncryptionResponseSsl).await?;
621 self.stream.upgrade_to_ssl(context).await?;
622 } else {
623 self.stream.write(BeMessage::EncryptionResponseNo).await?;
625 }
626
627 Ok(())
628 }
629
630 async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
631 if let Some(ref tls_config) = self.tls_config
633 && tls_config.enforce_ssl
634 && !self.stream.is_ssl_connection().await
635 {
636 return Err(PsqlError::StartupError(
637 "SSL connection is required but not established".into(),
638 ));
639 }
640
641 let db_name = msg
642 .config
643 .get("database")
644 .cloned()
645 .unwrap_or_else(|| "dev".to_owned());
646 let user_name = msg
647 .config
648 .get("user")
649 .cloned()
650 .unwrap_or_else(|| "root".to_owned());
651
652 let session = self
653 .session_mgr
654 .connect(&db_name, &user_name, self.peer_addr.clone())
655 .map_err(PsqlError::StartupError)?;
656
657 let application_name = msg.config.get("application_name");
658 if let Some(application_name) = application_name {
659 session
660 .set_config("application_name", application_name.clone())
661 .map_err(PsqlError::StartupError)?;
662 }
663
664 match session.user_authenticator() {
665 UserAuthenticator::None => {
666 self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
667
668 self.stream
671 .write_no_flush(BeMessage::BackendKeyData(session.id()))?;
672
673 self.stream.write_no_flush(BeMessage::ParameterStatus(
674 BeParameterStatusMessage::TimeZone(&session.get_config("timezone")?),
675 ))?;
676 self.stream
677 .write_parameter_status_msg_no_flush(&ParameterStatus {
678 application_name: application_name.cloned(),
679 })?;
680 self.ready_for_query()?;
681 }
682 UserAuthenticator::ClearText(_) | UserAuthenticator::OAuth(_) => {
683 self.stream
684 .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
685 }
686 UserAuthenticator::Md5WithSalt { salt, .. } => {
687 self.stream
688 .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
689 }
690 }
691
692 self.session = Some(session);
693 self.state = PgProtocolState::Regular;
694 Ok(())
695 }
696
697 async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
698 let session = self.session.as_ref().unwrap();
699 let authenticator = session.user_authenticator();
700 authenticator.authenticate(&msg.password).await?;
701 self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
702 self.stream.write_no_flush(BeMessage::ParameterStatus(
703 BeParameterStatusMessage::TimeZone(&session.get_config("timezone")?),
704 ))?;
705 self.stream
706 .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
707 self.ready_for_query()?;
708 self.state = PgProtocolState::Regular;
709 Ok(())
710 }
711
712 fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
713 let session_id = (m.target_process_id, m.target_secret_key);
714 tracing::trace!("cancel query in session: {:?}", session_id);
715 self.session_mgr.cancel_queries_in_session(session_id);
716 self.session_mgr.cancel_creating_jobs_in_session(session_id);
717 self.is_terminate = true;
718 Ok(())
719 }
720
721 async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
722 let truncated_sql = record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
723 let session = self.session.clone().unwrap();
724
725 session.check_idle_in_transaction_timeout()?;
726 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
728 self.inner_process_query_msg(sql, session.clone()).await
729 }
730
731 async fn inner_process_query_msg(
732 &mut self,
733 sql: Arc<str>,
734 session: Arc<SM::Session>,
735 ) -> PsqlResult<()> {
736 let stmts =
738 Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
739 drop(sql);
741 if stmts.is_empty() {
742 self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
743 }
744
745 for stmt in stmts {
747 self.inner_process_query_msg_one_stmt(stmt, session.clone())
748 .await?;
749 }
750 self.ready_for_query()?;
753 Ok(())
754 }
755
756 async fn inner_process_query_msg_one_stmt(
757 &mut self,
758 stmt: Statement,
759 session: Arc<SM::Session>,
760 ) -> PsqlResult<()> {
761 let session = session.clone();
762
763 let res = session.clone().run_one_query(stmt, Format::Text).await;
765
766 while let Some(notice) = session.next_notice().now_or_never() {
768 self.stream
769 .write_no_flush(BeMessage::NoticeResponse(¬ice))?;
770 }
771
772 let mut res = res.map_err(PsqlError::SimpleQueryError)?;
773
774 for notice in res.notices() {
775 self.stream
776 .write_no_flush(BeMessage::NoticeResponse(notice))?;
777 }
778
779 let status = res.status();
780 if let Some(ref application_name) = status.application_name {
781 self.stream.write_no_flush(BeMessage::ParameterStatus(
782 BeParameterStatusMessage::ApplicationName(application_name),
783 ))?;
784 }
785
786 if res.is_copy_query_to_stdout() {
787 self.stream
788 .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
789 let mut count = 0;
790 while let Some(row_set) = res.values_stream().next().await {
791 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
792 for row in row_set {
793 self.stream.write_no_flush(BeMessage::CopyData(&row))?;
794 count += 1;
795 }
796 }
797
798 self.stream.write_no_flush(BeMessage::CopyDone)?;
799
800 res.run_callback().await?;
802
803 self.stream
804 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
805 stmt_type: res.stmt_type(),
806 rows_cnt: count,
807 }))?;
808 } else if res.is_query() {
809 self.stream
810 .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;
811
812 let mut rows_cnt = 0;
813
814 while let Some(row_set) = res.values_stream().next().await {
815 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
816 for row in row_set {
817 self.stream.write_no_flush(BeMessage::DataRow(&row))?;
818 rows_cnt += 1;
819 }
820 }
821
822 res.run_callback().await?;
824
825 self.stream
826 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
827 stmt_type: res.stmt_type(),
828 rows_cnt,
829 }))?;
830 } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
831 let first_row_set = res.values_stream().next().await;
832 let first_row_set = match first_row_set {
833 None => {
834 return Err(PsqlError::Uncategorized(
835 anyhow::anyhow!("no affected rows in output").into(),
836 ));
837 }
838 Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
839 };
840 let affected_rows_str = first_row_set[0].values()[0]
841 .as_ref()
842 .expect("compute node should return affected rows in output");
843
844 assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
845 let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
846 .unwrap()
847 .parse()
848 .unwrap_or_default();
849
850 res.run_callback().await?;
852
853 self.stream
854 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
855 stmt_type: res.stmt_type(),
856 rows_cnt: affected_rows_cnt,
857 }))?;
858 } else {
859 res.run_callback().await?;
861
862 self.stream
863 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
864 stmt_type: res.stmt_type(),
865 rows_cnt: 0,
866 }))?;
867 }
868
869 Ok(())
870 }
871
872 fn process_terminate(&mut self) {
873 self.is_terminate = true;
874 }
875
876 fn process_health_check(&mut self) {
877 tracing::debug!("health check");
878 self.is_terminate = true;
879 }
880
881 async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
882 let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
883 record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
884 let session = self.session.clone().unwrap();
885 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
886 let type_ids = std::mem::take(&mut msg.type_ids);
887 drop(msg);
889 self.inner_process_parse_msg(session, sql, statement_name, type_ids)
890 .await?;
891 Ok(())
892 }
893
894 async fn inner_process_parse_msg(
895 &mut self,
896 session: Arc<SM::Session>,
897 sql: Arc<str>,
898 statement_name: String,
899 type_ids: Vec<i32>,
900 ) -> PsqlResult<()> {
901 if statement_name.is_empty() {
902 self.unnamed_prepare_statement.take();
905 } else if self.prepare_statement_store.contains_key(&statement_name) {
906 return Err(PsqlError::ExtendedPrepareError(
907 "Duplicated statement name".into(),
908 ));
909 }
910
911 let stmt = {
912 let stmts = Parser::parse_sql(&sql)
913 .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
914 drop(sql);
915 if stmts.len() > 1 {
916 return Err(PsqlError::ExtendedPrepareError(
917 "Only one statement is allowed in extended query mode".into(),
918 ));
919 }
920
921 stmts.into_iter().next()
922 };
923
924 let param_types: Vec<Option<DataType>> = type_ids
925 .iter()
926 .map(|&id| {
927 if id == 0 {
930 Ok(None)
931 } else {
932 DataType::from_oid(id)
933 .map(Some)
934 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
935 }
936 })
937 .try_collect()?;
938
939 let prepare_statement = session
940 .parse(stmt, param_types)
941 .await
942 .map_err(PsqlError::ExtendedPrepareError)?;
943
944 if statement_name.is_empty() {
945 self.unnamed_prepare_statement.replace(prepare_statement);
946 } else {
947 self.prepare_statement_store
948 .insert(statement_name.clone(), prepare_statement);
949 }
950
951 self.statement_portal_dependency
952 .entry(statement_name)
953 .or_default()
954 .clear();
955
956 self.stream.write_no_flush(BeMessage::ParseComplete)?;
957 Ok(())
958 }
959
960 fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
961 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
962 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
963 let session = self.session.clone().unwrap();
964
965 if self.portal_store.contains_key(&portal_name) {
966 return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
967 }
968
969 let prepare_statement = self.get_statement(&statement_name)?;
970
971 let result_formats = msg
972 .result_format_codes
973 .iter()
974 .map(|&format_code| Format::from_i16(format_code))
975 .try_collect()?;
976 let param_formats = msg
977 .param_format_codes
978 .iter()
979 .map(|&format_code| Format::from_i16(format_code))
980 .try_collect()?;
981
982 let portal = session
983 .bind(prepare_statement, msg.params, param_formats, result_formats)
984 .map_err(PsqlError::Uncategorized)?;
985
986 if portal_name.is_empty() {
987 self.result_cache.remove(&portal_name);
988 self.unnamed_portal.replace(portal);
989 } else {
990 assert!(
991 !self.result_cache.contains_key(&portal_name),
992 "Named portal never can be overridden."
993 );
994 self.portal_store.insert(portal_name.clone(), portal);
995 }
996
997 self.statement_portal_dependency
998 .get_mut(&statement_name)
999 .unwrap()
1000 .push(portal_name);
1001
1002 self.stream.write_no_flush(BeMessage::BindComplete)?;
1003 Ok(())
1004 }
1005
1006 async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1007 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1008 let row_max = msg.max_rows as usize;
1009 drop(msg);
1010 let session = self.session.clone().unwrap();
1011
1012 match self.result_cache.remove(&portal_name) {
1013 Some(mut result_cache) => {
1014 assert!(self.portal_store.contains_key(&portal_name));
1015
1016 let is_cosume_completed =
1017 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1018
1019 if !is_cosume_completed {
1020 self.result_cache.insert(portal_name, result_cache);
1021 }
1022 }
1023 _ => {
1024 let portal = self.get_portal(&portal_name)?;
1025 let sql = format!("{}", portal);
1026 let truncated_sql =
1027 record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
1028 drop(sql);
1029
1030 session.check_idle_in_transaction_timeout()?;
1031 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1033 let result = session.clone().execute(portal).await;
1034
1035 let pg_response = result.map_err(PsqlError::ExtendedExecuteError)?;
1036 let mut result_cache = ResultCache::new(pg_response);
1037 let is_consume_completed =
1038 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1039 if !is_consume_completed {
1040 self.result_cache.insert(portal_name, result_cache);
1041 }
1042 }
1043 }
1044
1045 Ok(())
1046 }
1047
1048 fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1049 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1050 let session = self.session.clone().unwrap();
1051 assert!(msg.kind == b'S' || msg.kind == b'P');
1055 if msg.kind == b'S' {
1056 let prepare_statement = self.get_statement(&name)?;
1057
1058 let (param_types, row_descriptions) = self
1059 .session
1060 .clone()
1061 .unwrap()
1062 .describe_statement(prepare_statement)
1063 .map_err(PsqlError::Uncategorized)?;
1064 self.stream.write_no_flush(BeMessage::ParameterDescription(
1065 ¶m_types.iter().map(|t| t.to_oid()).collect_vec(),
1066 ))?;
1067
1068 if row_descriptions.is_empty() {
1069 self.stream.write_no_flush(BeMessage::NoData)?;
1072 } else {
1073 self.stream
1074 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1075 }
1076 } else if msg.kind == b'P' {
1077 let portal = self.get_portal(&name)?;
1078
1079 let row_descriptions = session
1080 .describe_portal(portal)
1081 .map_err(PsqlError::Uncategorized)?;
1082
1083 if row_descriptions.is_empty() {
1084 self.stream.write_no_flush(BeMessage::NoData)?;
1087 } else {
1088 self.stream
1089 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1090 }
1091 }
1092 Ok(())
1093 }
1094
1095 fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1096 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1097 assert!(msg.kind == b'S' || msg.kind == b'P');
1098 if msg.kind == b'S' {
1099 if name.is_empty() {
1100 self.unnamed_prepare_statement = None;
1101 } else {
1102 self.prepare_statement_store.remove(&name);
1103 }
1104 for portal_name in self
1105 .statement_portal_dependency
1106 .remove(&name)
1107 .unwrap_or_default()
1108 {
1109 self.remove_portal(&portal_name);
1110 }
1111 } else if msg.kind == b'P' {
1112 self.remove_portal(&name);
1113 }
1114 self.stream.write_no_flush(BeMessage::CloseComplete)?;
1115 Ok(())
1116 }
1117
1118 fn remove_portal(&mut self, portal_name: &str) {
1119 if portal_name.is_empty() {
1120 self.unnamed_portal = None;
1121 } else {
1122 self.portal_store.remove(portal_name);
1123 }
1124 self.result_cache.remove(portal_name);
1125 }
1126
1127 fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1128 if portal_name.is_empty() {
1129 Ok(self
1130 .unnamed_portal
1131 .as_ref()
1132 .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1133 .clone())
1134 } else {
1135 Ok(self
1136 .portal_store
1137 .get(portal_name)
1138 .ok_or_else(|| {
1139 PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1140 })?
1141 .clone())
1142 }
1143 }
1144
1145 fn get_statement(
1146 &self,
1147 statement_name: &str,
1148 ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1149 if statement_name.is_empty() {
1150 Ok(self
1151 .unnamed_prepare_statement
1152 .as_ref()
1153 .ok_or_else(|| {
1154 PsqlError::Uncategorized("unnamed prepare statement not found".into())
1155 })?
1156 .clone())
1157 } else {
1158 Ok(self
1159 .prepare_statement_store
1160 .get(statement_name)
1161 .ok_or_else(|| {
1162 PsqlError::Uncategorized(
1163 format!("Prepare statement {} not found", statement_name).into(),
1164 )
1165 })?
1166 .clone())
1167 }
1168 }
1169}
1170
/// The byte stream underlying a connection, either plaintext or SSL-wrapped.
enum PgStreamInner<S> {
    /// Not a usable stream; the read paths treat it as unreachable.
    ///
    /// NOTE(review): presumably a temporary value while the stream is being upgraded to SSL —
    /// confirm in `upgrade_to_ssl`.
    Placeholder,
    /// Plaintext stream.
    Unencrypted(S),
    /// Stream upgraded to SSL/TLS.
    Ssl(SslStream<S>),
}
1179
1180pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1182impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1183
/// Lock-protected handle to a connection's byte stream, plus per-handle buffers.
pub struct PgStream<S> {
    /// Underlying stream.
    ///
    /// Clones of this `PgStream` share it through the `Arc`, and the `Mutex` serializes
    /// access between them.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// Outgoing buffer owned by this handle; clones start with an empty one.
    ///
    /// NOTE(review): presumably accumulates encoded messages until flushed — confirm in
    /// `write_no_flush`/`flush`.
    write_buf: BytesMut,
    /// Header read by `read_header` and not yet consumed by `read_body`/`skip_body`.
    read_header: Option<FeMessageHeader>,
}
1195
1196impl<S> PgStream<S> {
1197 pub fn new(stream: S) -> Self {
1199 const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1200
1201 Self {
1202 stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1203 write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1204 read_header: None,
1205 }
1206 }
1207
1208 async fn is_ssl_connection(&self) -> bool {
1210 let stream = self.stream.lock().await;
1211 matches!(*stream, PgStreamInner::Ssl(_))
1212 }
1213}
1214
1215impl<S> Clone for PgStream<S> {
1216 fn clone(&self) -> Self {
1217 Self {
1218 stream: Arc::clone(&self.stream),
1219 write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1220 read_header: self.read_header.clone(),
1221 }
1222 }
1223}
1224
/// Optional parameter statuses reported to the client when startup completes.
///
/// These are sent in addition to the fixed client encoding, standard-conforming-strings and
/// server version statuses.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// Value reported as the `application_name` parameter status, if any.
    pub application_name: Option<String>,
}
1245
1246impl<S> PgStream<S>
1247where
1248 S: PgByteStream,
1249{
1250 async fn read_startup(&mut self) -> io::Result<FeMessage> {
1251 let mut stream = self.stream.lock().await;
1252 match &mut *stream {
1253 PgStreamInner::Placeholder => unreachable!(),
1254 PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1255 PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1256 }
1257 }
1258
1259 async fn read_header(&mut self) -> io::Result<()> {
1260 let mut stream = self.stream.lock().await;
1261 match &mut *stream {
1262 PgStreamInner::Placeholder => unreachable!(),
1263 PgStreamInner::Unencrypted(stream) => {
1264 self.read_header = Some(FeMessage::read_header(stream).await?);
1265 Ok(())
1266 }
1267 PgStreamInner::Ssl(ssl_stream) => {
1268 self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1269 Ok(())
1270 }
1271 }
1272 }
1273
1274 async fn read_body(&mut self) -> io::Result<FeMessage> {
1275 let mut stream = self.stream.lock().await;
1276 let header = self
1277 .read_header
1278 .take()
1279 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1280 match &mut *stream {
1281 PgStreamInner::Placeholder => unreachable!(),
1282 PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1283 PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1284 }
1285 }
1286
1287 async fn skip_body(&mut self) -> io::Result<()> {
1288 let mut stream = self.stream.lock().await;
1289 let header = self
1290 .read_header
1291 .take()
1292 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1293 match &mut *stream {
1294 PgStreamInner::Placeholder => unreachable!(),
1295 PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1296 PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1297 }
1298 }
1299
1300 fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1301 self.write_no_flush(BeMessage::ParameterStatus(
1302 BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1303 ))?;
1304 self.write_no_flush(BeMessage::ParameterStatus(
1305 BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1306 ))?;
1307 self.write_no_flush(BeMessage::ParameterStatus(
1308 BeParameterStatusMessage::ServerVersion(PG_VERSION),
1309 ))?;
1310 if let Some(application_name) = &status.application_name {
1311 self.write_no_flush(BeMessage::ParameterStatus(
1312 BeParameterStatusMessage::ApplicationName(application_name),
1313 ))?;
1314 }
1315 Ok(())
1316 }
1317
1318 pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1319 BeMessage::write(&mut self.write_buf, message)
1320 }
1321
1322 async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1323 self.write_no_flush(message)?;
1324 self.flush().await?;
1325 Ok(())
1326 }
1327
1328 async fn flush(&mut self) -> io::Result<()> {
1329 let mut stream = self.stream.lock().await;
1330 match &mut *stream {
1331 PgStreamInner::Placeholder => unreachable!(),
1332 PgStreamInner::Unencrypted(stream) => {
1333 stream.write_all(&self.write_buf).await?;
1334 stream.flush().await?;
1335 }
1336 PgStreamInner::Ssl(ssl_stream) => {
1337 ssl_stream.write_all(&self.write_buf).await?;
1338 ssl_stream.flush().await?;
1339 }
1340 }
1341 self.write_buf.clear();
1342 Ok(())
1343 }
1344}
1345
1346impl<S> PgStream<S>
1347where
1348 S: PgByteStream,
1349{
1350 async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1352 let mut stream = self.stream.lock().await;
1353
1354 match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1355 PgStreamInner::Unencrypted(unencrypted_stream) => {
1356 let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1357 let mut ssl_stream =
1358 tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1359
1360 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1361 tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1362 let _ = ssl_stream.shutdown().await;
1363 return Err(e.into());
1364 }
1365
1366 *stream = PgStreamInner::Ssl(ssl_stream);
1367 }
1368 PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1369 PgStreamInner::Placeholder => unreachable!(),
1370 }
1371
1372 Ok(())
1373 }
1374}
1375
1376fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1377 let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1378
1379 let key_path = &tls_config.key;
1380 let cert_path = &tls_config.cert;
1381
1382 acceptor
1385 .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1386 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1387 acceptor
1388 .set_ca_file(cert_path)
1389 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1390 acceptor
1391 .set_certificate_chain_file(cert_path)
1392 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1393 let acceptor = acceptor.build();
1394
1395 Ok(acceptor.into_context())
1396}
1397
pub mod truncated_fmt {
    //! Formatting adapters that cap how many bytes of a value's rendering are emitted.
    //! When the cap is hit they append a `...(truncated,N)` marker.
    use std::fmt::*;

    /// A `Write` sink that forwards at most `remaining` bytes to `f`.
    ///
    /// When a chunk does not fit, the sink writes the longest prefix of it that ends on a char
    /// boundary, followed by `...(truncated,N)`. `N` is the byte length of that chunk, not of
    /// the whole rendering. All later writes are dropped.
    struct TruncatedFormatter<'a, 'b> {
        /// Bytes still allowed through.
        remaining: usize,
        /// Set once the truncation marker has been emitted.
        finished: bool,
        f: &'a mut Formatter<'b>,
    }

    impl Write for TruncatedFormatter<'_, '_> {
        fn write_str(&mut self, chunk: &str) -> Result {
            if self.finished {
                return Ok(());
            }

            if chunk.len() <= self.remaining {
                // The whole chunk fits in the budget.
                self.f.write_str(chunk)?;
                self.remaining -= chunk.len();
                return Ok(());
            }

            // Step back to a char boundary so a multi-byte character is never split.
            // Index 0 is always a boundary, so this loop terminates.
            let mut cut = self.remaining;
            while !chunk.is_char_boundary(cut) {
                cut -= 1;
            }
            self.f.write_str(&chunk[..cut])?;
            self.remaining -= cut;
            write!(self.f, "...(truncated,{})", chunk.len())?;
            self.finished = true;
            Ok(())
        }
    }

    /// Renders the referenced value through `Debug` or `Display`, truncated.
    ///
    /// At most `.1` bytes of the rendering are kept, followed by a truncation marker when the
    /// limit is exceeded.
    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);

    impl<T: Debug> Debug for TruncatedFmt<'_, T> {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            let mut sink = TruncatedFormatter {
                remaining: self.1,
                finished: false,
                f,
            };
            write!(sink, "{:?}", self.0)
        }
    }

    impl<T: Display> Display for TruncatedFmt<'_, T> {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            let mut sink = TruncatedFormatter {
                remaining: self.1,
                finished: false,
                f,
            };
            write!(sink, "{}", self.0)
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            let rendered = format!("{}", TruncatedFmt(&"select '🌊';", 10));
            assert_eq!(rendered, "select '...(truncated,14)");
        }
    }
}
1469
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// `redact_sql` must replace the values of options whose keys are in the redaction set
    /// with `[REDACTED]`. This applies both in `WITH (...)` and in the `ENCODE ... (...)`
    /// options, and the rest of the statement is re-rendered in canonical form.
    #[test]
    fn test_redact_parsable_sql() {
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let redacted_keys = Arc::new(HashSet::from(["v2", "v4", "b"].map(Into::into)));
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])";

        let actual = redact_sql(sql, redacted_keys);
        assert_eq!(actual, expected);
    }
}