1use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
46use crate::net::AddressRef;
47use crate::pg_extended::ResultCache;
48use crate::pg_message::{
49 BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
50 FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
51 FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
52};
53use crate::pg_server::{Session, SessionManager, UserAuthenticator};
54use crate::types::Format;
55
/// Maximum length used when truncating SQL text recorded for query logging.
///
/// Read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. When the variable is
/// unset or is not a valid `usize`, the value falls back to 65536 in builds with debug
/// assertions enabled and to 1024 otherwise.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        // Parse exactly once; an unparsable value is treated the same as an unset variable.
        .and_then(|raw| raw.parse::<usize>().ok())
        .unwrap_or(if cfg!(debug_assertions) { 65536 } else { 1024 })
});
69
tokio::task_local! {
    /// Weak, type-erased handle to the session whose message is currently being processed.
    ///
    /// `PgProtocol::do_process` scopes this value around message handling whenever a session
    /// has been established for the connection.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
74
/// Per-connection state machine for the PostgreSQL wire protocol.
///
/// `S` is the underlying byte stream and `SM` is the session manager used to create and end
/// sessions.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Wrapper around the client connection, used for reading and writing messages.
    stream: PgStream<S>,
    /// Current protocol phase; decides how `read_message` reads the next message.
    state: PgProtocolState,
    /// Set when a Terminate, CancelQuery or HealthCheck message has been handled; makes
    /// `process` report that the connection should be closed.
    is_terminate: bool,

    /// Session manager used to connect sessions, cancel queries and end the session on drop.
    session_mgr: Arc<SM>,
    /// The session of this connection; `None` until a startup message has been processed.
    session: Option<Arc<SM::Session>>,

    /// Partially consumed execution results, keyed by portal name.
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    /// The prepared statement registered under the empty name.
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    /// Named prepared statements, keyed by statement name.
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    /// The portal registered under the empty name.
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    /// Named portals, keyed by portal name.
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    /// Maps a statement name to the names of the portals bound from it, so that closing the
    /// statement can also remove those portals.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    /// SSL context built from `tls_config` in `new`; `None` when no TLS config was given or
    /// when building the context failed.
    tls_context: Option<SslContext>,

    /// TLS settings; consulted at startup to enforce SSL when `enforce_ssl` is set.
    tls_config: Option<TlsConfig>,

    /// When `true`, every message except Sync is skipped. Set after an extended-query message
    /// fails and cleared when Sync is processed.
    ignore_util_sync: bool,

    /// Address of the connected peer, handed to the session manager on connect.
    peer_addr: AddressRef,

    /// Option keywords to redact before SQL text is recorded in the tracing span.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts for the payload size of incoming messages and may throttle them.
    message_memory_manager: MessageMemoryManagerRef,
}
117
/// TLS settings for incoming connections.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Certificate setting, taken from `RW_SSL_CERT` by `new_default`.
    /// NOTE(review): presumably a file path — confirm against `build_ssl_ctx_from_config`.
    pub cert: String,
    /// Private key setting, taken from `RW_SSL_KEY` by `new_default`.
    /// NOTE(review): presumably a file path — confirm against `build_ssl_ctx_from_config`.
    pub key: String,
    /// When `true`, startup is rejected unless the connection has been upgraded to SSL.
    pub enforce_ssl: bool,
}
128
129impl TlsConfig {
130 pub fn new_default() -> Option<Self> {
131 let cert = std::env::var("RW_SSL_CERT").ok()?;
132 let key = std::env::var("RW_SSL_KEY").ok()?;
133 let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
134 tracing::info!(
135 "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
136 cert,
137 key,
138 enforce_ssl
139 );
140 Some(Self {
141 cert,
142 key,
143 enforce_ssl,
144 })
145 }
146}
147
148impl<S, SM> Drop for PgProtocol<S, SM>
149where
150 SM: SessionManager,
151{
152 fn drop(&mut self) {
153 if let Some(session) = &self.session {
154 self.session_mgr.end_session(session);
156 }
157 }
158}
159
/// Phase of the connection, used by `read_message` to pick how the next message is read.
enum PgProtocolState {
    /// Initial phase: messages are read with the startup reader.
    Startup,
    /// Entered once a startup message has been processed: messages are read as header + body.
    Regular,
}
165
166pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
170 let without_null = if b.last() == Some(&0) {
171 &b[..b.len() - 1]
172 } else {
173 &b[..]
174 };
175 std::str::from_utf8(without_null)
176}
177
178fn record_sql_in_span(
180 sql: &str,
181 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
182) -> String {
183 let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
184 && !keywords.is_empty()
185 {
186 redact_sql(sql, keywords)
187 } else {
188 sql.to_owned()
189 };
190 let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
191 tracing::Span::current().record("sql", tracing::field::display(&truncated));
192 truncated.to_string()
193}
194
195fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
197 match Parser::parse_sql(sql) {
198 Ok(sqls) => sqls
199 .into_iter()
200 .map(|sql| sql.to_redacted_string(keywords.clone()))
201 .join(";"),
202 Err(_) => sql.to_owned(),
203 }
204}
205
/// Per-connection settings handed to `PgProtocol::new`.
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS settings; `None` means no SSL context is built for the connection.
    pub tls_config: Option<TlsConfig>,
    /// Option keywords to redact before SQL text is recorded in the tracing span.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Memory accounting used when reading incoming messages.
    pub message_memory_manager: MessageMemoryManagerRef,
}
212
213impl<S, SM> PgProtocol<S, SM>
214where
215 S: PgByteStream,
216 SM: SessionManager,
217{
218 pub fn new(
219 stream: S,
220 session_mgr: Arc<SM>,
221 peer_addr: AddressRef,
222 context: ConnectionContext,
223 ) -> Self {
224 let ConnectionContext {
225 tls_config,
226 redact_sql_option_keywords,
227 message_memory_manager,
228 } = context;
229 Self {
230 stream: PgStream::new(stream),
231 is_terminate: false,
232 state: PgProtocolState::Startup,
233 session_mgr,
234 session: None,
235 tls_context: tls_config
236 .as_ref()
237 .and_then(|e| build_ssl_ctx_from_config(e).ok()),
238 tls_config,
239 result_cache: Default::default(),
240 unnamed_prepare_statement: Default::default(),
241 prepare_statement_store: Default::default(),
242 unnamed_portal: Default::default(),
243 portal_store: Default::default(),
244 statement_portal_dependency: Default::default(),
245 ignore_util_sync: false,
246 peer_addr,
247 redact_sql_option_keywords,
248 message_memory_manager,
249 }
250 }
251
252 pub async fn run(&mut self) {
254 let mut notice_fut = None;
255
256 loop {
257 if notice_fut.is_none()
259 && let Some(session) = self.session.clone()
260 {
261 let mut stream = self.stream.clone();
262 notice_fut = Some(Box::pin(async move {
263 loop {
264 let notice = session.next_notice().await;
265 if let Err(e) = stream.write(BeMessage::NoticeResponse(¬ice)).await {
266 tracing::error!(error = %e.as_report(), notice, "failed to send notice");
267 }
268 }
269 }));
270 }
271
272 let process = std::pin::pin!(async {
274 let (msg, _memory_guard) = match self.read_message().await {
275 Ok(msg) => msg,
276 Err(e) => {
277 tracing::error!(error = %e.as_report(), "error when reading message");
278 return true; }
280 };
281 tracing::trace!(?msg, "received message");
282 self.process(msg).await
283 });
284
285 let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
286 tokio::select! {
287 _ = notice_fut => unreachable!(),
288 terminated = process => terminated,
289 }
290 } else {
291 process.await
292 };
293
294 if terminated {
295 break;
296 }
297 }
298 }
299
300 pub async fn process(&mut self, msg: FeMessage) -> bool {
302 self.do_process(msg).await.is_none() || self.is_terminate
303 }
304
305 fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
313 let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
314 return tracing::Span::none();
315 };
316
317 let mode = match msg {
318 FeMessage::Query(_) => "simple query",
319 FeMessage::Parse(_) => "extended query parse",
320 FeMessage::Execute(_) => "extended query execute",
321 _ => return tracing::Span::none(),
322 };
323
324 tracing::info_span!(
325 target: PGWIRE_ROOT_SPAN_TARGET,
326 "handle_query",
327 mode,
328 session_id,
329 sql = tracing::field::Empty, )
331 }
332
    /// Processes one message with all cross-cutting concerns applied, then reports errors to
    /// the client.
    ///
    /// The inner handler is wrapped, from innermost to outermost, with: the `CURRENT_SESSION`
    /// task-local scope, panic catching, periodic slow-query logging, completion/error logging
    /// and the root tracing span.
    ///
    /// Returns `Some(())` when the connection can keep going and `None` when it must be closed.
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        // Boxed and pinned so it can be moved into the wrapping futures below.
        let fut = Box::pin(self.do_process_inner(msg));

        // Expose the session through the task-local while the message is handled.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Turn a panic raised while handling the message into `PsqlError::Panic`.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Emit a slow-query log line every `SLOW_QUERY_LOG_PERIOD` until the handler finishes.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            let mut elapsed = Duration::ZERO;

            loop {
                // Polling by `&mut` keeps the handler alive across timeouts.
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Log the error (if any) and, when inside a real span, the final status and duration.
        let fut = async move {
            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            if let Err(error) = &result {
                // Debug formatting of the report outside CI in debug builds, display otherwise.
                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
                    tracing::error!(error = ?error.as_report(), "error when process message");
                } else {
                    tracing::error!(error = %error.as_report(), "error when process message");
                }
            }

            // Only messages that got a root span produce a query log entry.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        let fut = fut.instrument(span);

        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    // An unexpected EOF closes the connection; other I/O errors fall through
                    // to the flush below and keep it open.
                    PsqlError::IoError(io_err) => {
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                    }

                    // SSL errors close the connection without an error response.
                    PsqlError::SslError(_) => {
                        return None;
                    }

                    // Startup and password failures: report the error, then close.
                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    // Simple-query and throttling errors: report the error followed by
                    // ReadyForQuery so the client can continue.
                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                        self.ready_for_query().ok()?;
                    }

                    // Idle-in-transaction timeouts and panics: report the error, then close.
                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                        let _ = self.stream.flush().await;

                        return None;
                    }

                    // Remaining errors: report the error only; no ReadyForQuery is sent here.
                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                    }
                }
                // Best-effort flush; a failure here is ignored.
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
488
489 async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
490 if self.ignore_util_sync {
492 if let FeMessage::Sync = msg {
493 } else {
494 tracing::trace!("ignore message {:?} until sync.", msg);
495 return Ok(());
496 }
497 }
498
499 match msg {
500 FeMessage::Gss => self.process_gss_msg().await?,
501 FeMessage::Ssl => self.process_ssl_msg().await?,
502 FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
503 FeMessage::Password(msg) => self.process_password_msg(msg).await?,
504 FeMessage::Query(query_msg) => {
505 let sql = Arc::from(query_msg.get_sql()?);
506 drop(query_msg);
508 self.process_query_msg(sql).await?
509 }
510 FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
511 FeMessage::Terminate => self.process_terminate(),
512 FeMessage::Parse(m) => {
513 if let Err(err) = self.process_parse_msg(m).await {
514 self.ignore_util_sync = true;
515 return Err(err);
516 }
517 }
518 FeMessage::Bind(m) => {
519 if let Err(err) = self.process_bind_msg(m) {
520 self.ignore_util_sync = true;
521 return Err(err);
522 }
523 }
524 FeMessage::Execute(m) => {
525 if let Err(err) = self.process_execute_msg(m).await {
526 self.ignore_util_sync = true;
527 return Err(err);
528 }
529 }
530 FeMessage::Describe(m) => {
531 if let Err(err) = self.process_describe_msg(m) {
532 self.ignore_util_sync = true;
533 return Err(err);
534 }
535 }
536 FeMessage::Sync => {
537 self.ignore_util_sync = false;
538 self.ready_for_query()?
539 }
540 FeMessage::Close(m) => {
541 if let Err(err) = self.process_close_msg(m) {
542 self.ignore_util_sync = true;
543 return Err(err);
544 }
545 }
546 FeMessage::Flush => {
547 if let Err(err) = self.stream.flush().await {
548 self.ignore_util_sync = true;
549 return Err(err.into());
550 }
551 }
552 FeMessage::HealthCheck => self.process_health_check(),
553 FeMessage::ServerThrottle(reason) => match reason {
554 ServerThrottleReason::TooLargeMessage => {
555 return Err(PsqlError::ServerThrottle(format!(
556 "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
557 self.message_memory_manager.max_filter_bytes
558 )));
559 }
560 ServerThrottleReason::TooManyMemoryUsage => {
561 return Err(PsqlError::ServerThrottle(format!(
562 "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
563 self.message_memory_manager.max_running_bytes
564 )));
565 }
566 },
567 }
568 self.stream.flush().await?;
569 Ok(())
570 }
571
572 pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
573 match self.state {
574 PgProtocolState::Startup => self
575 .stream
576 .read_startup()
577 .await
578 .map(|message: FeMessage| (message, None)),
579 PgProtocolState::Regular => {
580 self.stream.read_header().await?;
581 let guard = if let Some(ref header) = self.stream.read_header {
582 let payload_len = std::cmp::max(header.payload_len, 0) as u64;
583 let (reason, guard) = self.message_memory_manager.add(payload_len);
584 if let Some(reason) = reason {
585 drop(guard);
587 self.stream.skip_body().await?;
588 return Ok((FeMessage::ServerThrottle(reason), None));
589 }
590 guard
591 } else {
592 None
593 };
594 let message = self.stream.read_body().await?;
595 Ok((message, guard))
596 }
597 }
598 }
599
600 fn ready_for_query(&mut self) -> io::Result<()> {
602 self.stream.write_no_flush(BeMessage::ReadyForQuery(
603 self.session
604 .as_ref()
605 .map(|s| s.transaction_status())
606 .unwrap_or(TransactionStatus::Idle),
607 ))
608 }
609
610 async fn process_gss_msg(&mut self) -> PsqlResult<()> {
611 self.stream.write(BeMessage::EncryptionResponseNo).await?;
613 Ok(())
614 }
615
616 async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
617 if let Some(context) = self.tls_context.as_ref() {
618 self.stream.write(BeMessage::EncryptionResponseSsl).await?;
621 self.stream.upgrade_to_ssl(context).await?;
622 } else {
623 self.stream.write(BeMessage::EncryptionResponseNo).await?;
625 }
626
627 Ok(())
628 }
629
630 async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
631 if let Some(ref tls_config) = self.tls_config
633 && tls_config.enforce_ssl
634 && !self.stream.is_ssl_connection().await
635 {
636 return Err(PsqlError::StartupError(
637 "SSL connection is required but not established".into(),
638 ));
639 }
640
641 let db_name = msg
642 .config
643 .get("database")
644 .cloned()
645 .unwrap_or_else(|| "dev".to_owned());
646 let user_name = msg
647 .config
648 .get("user")
649 .cloned()
650 .unwrap_or_else(|| "root".to_owned());
651
652 let session = self
653 .session_mgr
654 .connect(&db_name, &user_name, self.peer_addr.clone())
655 .map_err(PsqlError::StartupError)?;
656
657 let application_name = msg.config.get("application_name");
658 if let Some(application_name) = application_name {
659 session
660 .set_config("application_name", application_name.clone())
661 .map_err(PsqlError::StartupError)?;
662 }
663
664 match session.user_authenticator() {
665 UserAuthenticator::None => {
666 self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
667
668 self.stream
671 .write_no_flush(BeMessage::BackendKeyData(session.id()))?;
672
673 self.stream.write_no_flush(BeMessage::ParameterStatus(
674 BeParameterStatusMessage::TimeZone(&session.get_config("timezone")?),
675 ))?;
676 self.stream
677 .write_parameter_status_msg_no_flush(&ParameterStatus {
678 application_name: application_name.cloned(),
679 })?;
680 self.ready_for_query()?;
681 }
682 UserAuthenticator::ClearText(_)
683 | UserAuthenticator::OAuth { .. }
684 | UserAuthenticator::Ldap(..) => {
685 self.stream
686 .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
687 }
688 UserAuthenticator::Md5WithSalt { salt, .. } => {
689 self.stream
690 .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
691 }
692 }
693
694 self.session = Some(session);
695 self.state = PgProtocolState::Regular;
696 Ok(())
697 }
698
699 async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
700 let session = self.session.as_ref().unwrap();
701 let authenticator = session.user_authenticator();
702 authenticator.authenticate(&msg.password).await?;
703 self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
704 self.stream.write_no_flush(BeMessage::ParameterStatus(
705 BeParameterStatusMessage::TimeZone(&session.get_config("timezone")?),
706 ))?;
707 self.stream
708 .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
709 self.ready_for_query()?;
710 self.state = PgProtocolState::Regular;
711 Ok(())
712 }
713
714 fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
715 let session_id = (m.target_process_id, m.target_secret_key);
716 tracing::trace!("cancel query in session: {:?}", session_id);
717 self.session_mgr.cancel_queries_in_session(session_id);
718 self.session_mgr.cancel_creating_jobs_in_session(session_id);
719 self.is_terminate = true;
720 Ok(())
721 }
722
723 async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
724 let truncated_sql = record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
725 let session = self.session.clone().unwrap();
726
727 session.check_idle_in_transaction_timeout()?;
728 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
730 self.inner_process_query_msg(sql, session.clone()).await
731 }
732
733 async fn inner_process_query_msg(
734 &mut self,
735 sql: Arc<str>,
736 session: Arc<SM::Session>,
737 ) -> PsqlResult<()> {
738 let stmts =
740 Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
741 drop(sql);
743 if stmts.is_empty() {
744 self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
745 }
746
747 for stmt in stmts {
749 self.inner_process_query_msg_one_stmt(stmt, session.clone())
750 .await?;
751 }
752 self.ready_for_query()?;
755 Ok(())
756 }
757
758 async fn inner_process_query_msg_one_stmt(
759 &mut self,
760 stmt: Statement,
761 session: Arc<SM::Session>,
762 ) -> PsqlResult<()> {
763 let session = session.clone();
764
765 let res = session.clone().run_one_query(stmt, Format::Text).await;
767
768 while let Some(notice) = session.next_notice().now_or_never() {
770 self.stream
771 .write_no_flush(BeMessage::NoticeResponse(¬ice))?;
772 }
773
774 let mut res = res.map_err(PsqlError::SimpleQueryError)?;
775
776 for notice in res.notices() {
777 self.stream
778 .write_no_flush(BeMessage::NoticeResponse(notice))?;
779 }
780
781 let status = res.status();
782 if let Some(ref application_name) = status.application_name {
783 self.stream.write_no_flush(BeMessage::ParameterStatus(
784 BeParameterStatusMessage::ApplicationName(application_name),
785 ))?;
786 }
787
788 if res.is_copy_query_to_stdout() {
789 self.stream
790 .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
791 let mut count = 0;
792 while let Some(row_set) = res.values_stream().next().await {
793 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
794 for row in row_set {
795 self.stream.write_no_flush(BeMessage::CopyData(&row))?;
796 count += 1;
797 }
798 }
799
800 self.stream.write_no_flush(BeMessage::CopyDone)?;
801
802 res.run_callback().await?;
804
805 self.stream
806 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
807 stmt_type: res.stmt_type(),
808 rows_cnt: count,
809 }))?;
810 } else if res.is_query() {
811 self.stream
812 .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;
813
814 let mut rows_cnt = 0;
815
816 while let Some(row_set) = res.values_stream().next().await {
817 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
818 for row in row_set {
819 self.stream.write_no_flush(BeMessage::DataRow(&row))?;
820 rows_cnt += 1;
821 }
822 }
823
824 res.run_callback().await?;
826
827 self.stream
828 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
829 stmt_type: res.stmt_type(),
830 rows_cnt,
831 }))?;
832 } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
833 let first_row_set = res.values_stream().next().await;
834 let first_row_set = match first_row_set {
835 None => {
836 return Err(PsqlError::Uncategorized(
837 anyhow::anyhow!("no affected rows in output").into(),
838 ));
839 }
840 Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
841 };
842 let affected_rows_str = first_row_set[0].values()[0]
843 .as_ref()
844 .expect("compute node should return affected rows in output");
845
846 assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
847 let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
848 .unwrap()
849 .parse()
850 .unwrap_or_default();
851
852 res.run_callback().await?;
854
855 self.stream
856 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
857 stmt_type: res.stmt_type(),
858 rows_cnt: affected_rows_cnt,
859 }))?;
860 } else {
861 res.run_callback().await?;
863
864 self.stream
865 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
866 stmt_type: res.stmt_type(),
867 rows_cnt: 0,
868 }))?;
869 }
870
871 Ok(())
872 }
873
    /// Handles a Terminate message by marking the connection for termination.
    fn process_terminate(&mut self) {
        self.is_terminate = true;
    }
877
    /// Handles a health-check message: logs it at debug level and marks the connection for
    /// termination.
    fn process_health_check(&mut self) {
        tracing::debug!("health check");
        self.is_terminate = true;
    }
882
883 async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
884 let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
885 record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
886 let session = self.session.clone().unwrap();
887 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
888 let type_ids = std::mem::take(&mut msg.type_ids);
889 drop(msg);
891 self.inner_process_parse_msg(session, sql, statement_name, type_ids)
892 .await?;
893 Ok(())
894 }
895
896 async fn inner_process_parse_msg(
897 &mut self,
898 session: Arc<SM::Session>,
899 sql: Arc<str>,
900 statement_name: String,
901 type_ids: Vec<i32>,
902 ) -> PsqlResult<()> {
903 if statement_name.is_empty() {
904 self.unnamed_prepare_statement.take();
907 } else if self.prepare_statement_store.contains_key(&statement_name) {
908 return Err(PsqlError::ExtendedPrepareError(
909 "Duplicated statement name".into(),
910 ));
911 }
912
913 let stmt = {
914 let stmts = Parser::parse_sql(&sql)
915 .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
916 drop(sql);
917 if stmts.len() > 1 {
918 return Err(PsqlError::ExtendedPrepareError(
919 "Only one statement is allowed in extended query mode".into(),
920 ));
921 }
922
923 stmts.into_iter().next()
924 };
925
926 let param_types: Vec<Option<DataType>> = type_ids
927 .iter()
928 .map(|&id| {
929 if id == 0 {
932 Ok(None)
933 } else {
934 DataType::from_oid(id)
935 .map(Some)
936 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
937 }
938 })
939 .try_collect()?;
940
941 let prepare_statement = session
942 .parse(stmt, param_types)
943 .await
944 .map_err(PsqlError::ExtendedPrepareError)?;
945
946 if statement_name.is_empty() {
947 self.unnamed_prepare_statement.replace(prepare_statement);
948 } else {
949 self.prepare_statement_store
950 .insert(statement_name.clone(), prepare_statement);
951 }
952
953 self.statement_portal_dependency
954 .entry(statement_name)
955 .or_default()
956 .clear();
957
958 self.stream.write_no_flush(BeMessage::ParseComplete)?;
959 Ok(())
960 }
961
962 fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
963 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
964 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
965 let session = self.session.clone().unwrap();
966
967 if self.portal_store.contains_key(&portal_name) {
968 return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
969 }
970
971 let prepare_statement = self.get_statement(&statement_name)?;
972
973 let result_formats = msg
974 .result_format_codes
975 .iter()
976 .map(|&format_code| Format::from_i16(format_code))
977 .try_collect()?;
978 let param_formats = msg
979 .param_format_codes
980 .iter()
981 .map(|&format_code| Format::from_i16(format_code))
982 .try_collect()?;
983
984 let portal = session
985 .bind(prepare_statement, msg.params, param_formats, result_formats)
986 .map_err(PsqlError::Uncategorized)?;
987
988 if portal_name.is_empty() {
989 self.result_cache.remove(&portal_name);
990 self.unnamed_portal.replace(portal);
991 } else {
992 assert!(
993 !self.result_cache.contains_key(&portal_name),
994 "Named portal never can be overridden."
995 );
996 self.portal_store.insert(portal_name.clone(), portal);
997 }
998
999 self.statement_portal_dependency
1000 .get_mut(&statement_name)
1001 .unwrap()
1002 .push(portal_name);
1003
1004 self.stream.write_no_flush(BeMessage::BindComplete)?;
1005 Ok(())
1006 }
1007
1008 async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1009 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1010 let row_max = msg.max_rows as usize;
1011 drop(msg);
1012 let session = self.session.clone().unwrap();
1013
1014 match self.result_cache.remove(&portal_name) {
1015 Some(mut result_cache) => {
1016 assert!(self.portal_store.contains_key(&portal_name));
1017
1018 let is_consume_completed =
1019 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1020
1021 if !is_consume_completed {
1022 self.result_cache.insert(portal_name, result_cache);
1023 }
1024 }
1025 _ => {
1026 let portal = self.get_portal(&portal_name)?;
1027 let sql = format!("{}", portal);
1028 let truncated_sql =
1029 record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
1030 drop(sql);
1031
1032 session.check_idle_in_transaction_timeout()?;
1033 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1035 let result = session.clone().execute(portal).await;
1036
1037 let pg_response = result.map_err(PsqlError::ExtendedExecuteError)?;
1038 let mut result_cache = ResultCache::new(pg_response);
1039 let is_consume_completed =
1040 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1041 if !is_consume_completed {
1042 self.result_cache.insert(portal_name, result_cache);
1043 }
1044 }
1045 }
1046
1047 Ok(())
1048 }
1049
1050 fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1051 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1052 let session = self.session.clone().unwrap();
1053 assert!(msg.kind == b'S' || msg.kind == b'P');
1057 if msg.kind == b'S' {
1058 let prepare_statement = self.get_statement(&name)?;
1059
1060 let (param_types, row_descriptions) = self
1061 .session
1062 .clone()
1063 .unwrap()
1064 .describe_statement(prepare_statement)
1065 .map_err(PsqlError::Uncategorized)?;
1066 self.stream.write_no_flush(BeMessage::ParameterDescription(
1067 ¶m_types.iter().map(|t| t.to_oid()).collect_vec(),
1068 ))?;
1069
1070 if row_descriptions.is_empty() {
1071 self.stream.write_no_flush(BeMessage::NoData)?;
1074 } else {
1075 self.stream
1076 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1077 }
1078 } else if msg.kind == b'P' {
1079 let portal = self.get_portal(&name)?;
1080
1081 let row_descriptions = session
1082 .describe_portal(portal)
1083 .map_err(PsqlError::Uncategorized)?;
1084
1085 if row_descriptions.is_empty() {
1086 self.stream.write_no_flush(BeMessage::NoData)?;
1089 } else {
1090 self.stream
1091 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1092 }
1093 }
1094 Ok(())
1095 }
1096
1097 fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1098 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1099 assert!(msg.kind == b'S' || msg.kind == b'P');
1100 if msg.kind == b'S' {
1101 if name.is_empty() {
1102 self.unnamed_prepare_statement = None;
1103 } else {
1104 self.prepare_statement_store.remove(&name);
1105 }
1106 for portal_name in self
1107 .statement_portal_dependency
1108 .remove(&name)
1109 .unwrap_or_default()
1110 {
1111 self.remove_portal(&portal_name);
1112 }
1113 } else if msg.kind == b'P' {
1114 self.remove_portal(&name);
1115 }
1116 self.stream.write_no_flush(BeMessage::CloseComplete)?;
1117 Ok(())
1118 }
1119
1120 fn remove_portal(&mut self, portal_name: &str) {
1121 if portal_name.is_empty() {
1122 self.unnamed_portal = None;
1123 } else {
1124 self.portal_store.remove(portal_name);
1125 }
1126 self.result_cache.remove(portal_name);
1127 }
1128
1129 fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1130 if portal_name.is_empty() {
1131 Ok(self
1132 .unnamed_portal
1133 .as_ref()
1134 .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1135 .clone())
1136 } else {
1137 Ok(self
1138 .portal_store
1139 .get(portal_name)
1140 .ok_or_else(|| {
1141 PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1142 })?
1143 .clone())
1144 }
1145 }
1146
1147 fn get_statement(
1148 &self,
1149 statement_name: &str,
1150 ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1151 if statement_name.is_empty() {
1152 Ok(self
1153 .unnamed_prepare_statement
1154 .as_ref()
1155 .ok_or_else(|| {
1156 PsqlError::Uncategorized("unnamed prepare statement not found".into())
1157 })?
1158 .clone())
1159 } else {
1160 Ok(self
1161 .prepare_statement_store
1162 .get(statement_name)
1163 .ok_or_else(|| {
1164 PsqlError::Uncategorized(
1165 format!("Prepare statement {} not found", statement_name).into(),
1166 )
1167 })?
1168 .clone())
1169 }
1170 }
1171}
1172
/// The transport underneath a `PgStream`.
enum PgStreamInner<S> {
    /// Never expected while reading: the read paths treat this variant as unreachable.
    /// NOTE(review): presumably a temporary value used while upgrading to SSL — confirm
    /// against `upgrade_to_ssl`.
    Placeholder,
    /// Plain, unencrypted stream; the initial variant of every `PgStream`.
    Unencrypted(S),
    /// Stream wrapped in an SSL session.
    Ssl(SslStream<S>),
}
1181
/// Shorthand for the bounds required of a byte stream carrying the protocol: asynchronous
/// read and write, `Unpin`, `Send` and `'static`.
pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
// Blanket implementation: every type satisfying the bounds is a `PgByteStream`.
impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1185
/// Message-level wrapper around the client connection.
pub struct PgStream<S> {
    /// The underlying transport, shared between clones of this `PgStream` and protected by an
    /// async mutex.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// Per-instance write buffer; a clone gets its own empty buffer of the same capacity.
    /// NOTE(review): presumably filled by `write_no_flush` and drained by `flush` — confirm.
    write_buf: BytesMut,
    /// Header stored by `read_header` and taken by `read_body` / `skip_body`.
    read_header: Option<FeMessageHeader>,
}
1197
1198impl<S> PgStream<S> {
1199 pub fn new(stream: S) -> Self {
1201 const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1202
1203 Self {
1204 stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1205 write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1206 read_header: None,
1207 }
1208 }
1209
1210 async fn is_ssl_connection(&self) -> bool {
1212 let stream = self.stream.lock().await;
1213 matches!(*stream, PgStreamInner::Ssl(_))
1214 }
1215}
1216
1217impl<S> Clone for PgStream<S> {
1218 fn clone(&self) -> Self {
1219 Self {
1220 stream: Arc::clone(&self.stream),
1221 write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1222 read_header: self.read_header.clone(),
1223 }
1224 }
1225}
1226
/// Session parameters passed to `write_parameter_status_msg_no_flush` during startup.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// The `application_name` supplied in the startup message, if any.
    pub application_name: Option<String>,
}
1247
1248impl<S> PgStream<S>
1249where
1250 S: PgByteStream,
1251{
1252 async fn read_startup(&mut self) -> io::Result<FeMessage> {
1253 let mut stream = self.stream.lock().await;
1254 match &mut *stream {
1255 PgStreamInner::Placeholder => unreachable!(),
1256 PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1257 PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1258 }
1259 }
1260
1261 async fn read_header(&mut self) -> io::Result<()> {
1262 let mut stream = self.stream.lock().await;
1263 match &mut *stream {
1264 PgStreamInner::Placeholder => unreachable!(),
1265 PgStreamInner::Unencrypted(stream) => {
1266 self.read_header = Some(FeMessage::read_header(stream).await?);
1267 Ok(())
1268 }
1269 PgStreamInner::Ssl(ssl_stream) => {
1270 self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1271 Ok(())
1272 }
1273 }
1274 }
1275
1276 async fn read_body(&mut self) -> io::Result<FeMessage> {
1277 let mut stream = self.stream.lock().await;
1278 let header = self
1279 .read_header
1280 .take()
1281 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1282 match &mut *stream {
1283 PgStreamInner::Placeholder => unreachable!(),
1284 PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1285 PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1286 }
1287 }
1288
1289 async fn skip_body(&mut self) -> io::Result<()> {
1290 let mut stream = self.stream.lock().await;
1291 let header = self
1292 .read_header
1293 .take()
1294 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1295 match &mut *stream {
1296 PgStreamInner::Placeholder => unreachable!(),
1297 PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1298 PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1299 }
1300 }
1301
1302 fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1303 self.write_no_flush(BeMessage::ParameterStatus(
1304 BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1305 ))?;
1306 self.write_no_flush(BeMessage::ParameterStatus(
1307 BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1308 ))?;
1309 self.write_no_flush(BeMessage::ParameterStatus(
1310 BeParameterStatusMessage::ServerVersion(PG_VERSION),
1311 ))?;
1312 if let Some(application_name) = &status.application_name {
1313 self.write_no_flush(BeMessage::ParameterStatus(
1314 BeParameterStatusMessage::ApplicationName(application_name),
1315 ))?;
1316 }
1317 Ok(())
1318 }
1319
1320 pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1321 BeMessage::write(&mut self.write_buf, message)
1322 }
1323
1324 async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1325 self.write_no_flush(message)?;
1326 self.flush().await?;
1327 Ok(())
1328 }
1329
1330 async fn flush(&mut self) -> io::Result<()> {
1331 let mut stream = self.stream.lock().await;
1332 match &mut *stream {
1333 PgStreamInner::Placeholder => unreachable!(),
1334 PgStreamInner::Unencrypted(stream) => {
1335 stream.write_all(&self.write_buf).await?;
1336 stream.flush().await?;
1337 }
1338 PgStreamInner::Ssl(ssl_stream) => {
1339 ssl_stream.write_all(&self.write_buf).await?;
1340 ssl_stream.flush().await?;
1341 }
1342 }
1343 self.write_buf.clear();
1344 Ok(())
1345 }
1346}
1347
1348impl<S> PgStream<S>
1349where
1350 S: PgByteStream,
1351{
1352 async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1354 let mut stream = self.stream.lock().await;
1355
1356 match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1357 PgStreamInner::Unencrypted(unencrypted_stream) => {
1358 let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1359 let mut ssl_stream =
1360 tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1361
1362 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1363 tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1364 let _ = ssl_stream.shutdown().await;
1365 return Err(e.into());
1366 }
1367
1368 *stream = PgStreamInner::Ssl(ssl_stream);
1369 }
1370 PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1371 PgStreamInner::Placeholder => unreachable!(),
1372 }
1373
1374 Ok(())
1375 }
1376}
1377
1378fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1379 let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1380
1381 let key_path = &tls_config.key;
1382 let cert_path = &tls_config.cert;
1383
1384 acceptor
1387 .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1388 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1389 acceptor
1390 .set_ca_file(cert_path)
1391 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1392 acceptor
1393 .set_certificate_chain_file(cert_path)
1394 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1395 let acceptor = acceptor.build();
1396
1397 Ok(acceptor.into_context())
1398}
1399
/// Formatting wrappers that cap how many bytes of a value's `Debug` / `Display`
/// output are written.
pub mod truncated_fmt {
    use std::fmt::*;

    /// A [`Write`] sink that forwards at most `remaining` bytes to the wrapped
    /// [`Formatter`] and then emits one truncation marker.
    struct TruncatedFormatter<'a, 'b> {
        /// Number of bytes that may still be forwarded.
        remaining: usize,
        /// Set once the truncation marker has been written; later writes are dropped.
        finished: bool,
        /// Destination of the (possibly truncated) output.
        f: &'a mut Formatter<'b>,
    }

    impl Write for TruncatedFormatter<'_, '_> {
        /// Forwards `s` if it fits in the remaining budget. Otherwise writes
        /// the longest prefix that fits and ends on a character boundary,
        /// followed by `...(truncated,N)`, where `N` is the byte length of `s`
        /// (the chunk that hit the limit).
        fn write_str(&mut self, s: &str) -> Result {
            if self.finished {
                return Ok(());
            }

            if self.remaining < s.len() {
                // Largest char boundary within the budget, so the slice below
                // never splits a multi-byte UTF-8 sequence. Index 0 is always
                // a boundary, so the search cannot come up empty.
                let actual = (0..=self.remaining)
                    .rev()
                    .find(|&i| s.is_char_boundary(i))
                    .unwrap_or(0);
                self.f.write_str(&s[..actual])?;
                self.remaining -= actual;
                // Write the marker straight into the formatter instead of
                // building a temporary `String` first.
                write!(self.f, "...(truncated,{})", s.len())?;
                // Ensures the marker is printed exactly once.
                self.finished = true;
            } else {
                self.f.write_str(s)?;
                self.remaining -= s.len();
            }
            Ok(())
        }
    }

    /// Wraps a reference to a value together with a byte limit. Formatting the
    /// wrapper with `{}` or `{:?}` formats the value the same way, truncated to
    /// roughly that many bytes plus a truncation marker.
    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);

    impl<T> Debug for TruncatedFmt<'_, T>
    where
        T: Debug,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            TruncatedFormatter {
                remaining: self.1,
                finished: false,
                f,
            }
            .write_fmt(format_args!("{:?}", self.0))
        }
    }

    impl<T> Display for TruncatedFmt<'_, T>
    where
        T: Display,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            TruncatedFormatter {
                remaining: self.1,
                finished: false,
                f,
            }
            .write_fmt(format_args!("{}", self.0))
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            // The limit of 10 bytes falls inside the 4-byte emoji, so the
            // output is cut back to the preceding character boundary.
            assert_eq!(
                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
                "select '...(truncated,14)",
            );
        }

        #[test]
        fn test_no_truncation_within_limit() {
            assert_eq!(format!("{}", TruncatedFmt(&"select 1;", 9)), "select 1;");
        }
    }
}
1471
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Checks `redact_sql` on a parsable statement: the values of the listed
    /// option keywords are replaced with `[REDACTED]`, both in the `WITH`
    /// clause and in the `ENCODE` options, while other options keep their
    /// values and the statement comes back in normalized form.
    #[test]
    fn test_redact_parsable_sql() {
        // Option names whose values must not appear in the output.
        let keywords = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()]));
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        assert_eq!(
            redact_sql(sql, keywords),
            "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])"
        );
    }
}