1use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
46use crate::net::AddressRef;
47use crate::pg_extended::ResultCache;
48use crate::pg_message::{
49 BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
50 FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
51 FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
52};
53use crate::pg_server::{Session, SessionManager, UserAuthenticator};
54use crate::types::Format;
55
/// Maximum number of characters of SQL text recorded in query-log spans.
///
/// Read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. If the variable is
/// unset or not a valid `usize`, it defaults to 65536 in debug builds and 1024 otherwise.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    // Parse the value exactly once instead of validating and then re-parsing it.
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        .and_then(|len| len.parse::<usize>().ok())
        .unwrap_or(if cfg!(debug_assertions) { 65536 } else { 1024 })
});
69
tokio::task_local! {
    /// The session whose message is currently being processed, with its concrete type erased
    /// (the session type depends on the `SessionManager`). `PgProtocol::do_process` scopes it
    /// around message processing once a session exists. It is held as a `Weak` reference, so
    /// it does not keep the session alive.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
74
/// Server-side state machine for one PostgreSQL wire-protocol connection.
///
/// Owns the (possibly TLS-upgraded) client stream, the session created at startup, and all
/// extended-query state: prepared statements, portals, and partially consumed results.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// The client connection.
    stream: PgStream<S>,
    /// Whether the next message is read in startup format or as a typed regular message.
    state: PgProtocolState,
    /// Set by `Terminate`, cancel requests and health checks; makes `process` report that
    /// the connection should be closed.
    is_terminate: bool,

    /// Creates, cancels and ends sessions.
    session_mgr: Arc<SM>,
    /// Session opened by the startup message; `None` until startup has been handled.
    session: Option<Arc<SM::Session>>,

    /// Results of portals that an `Execute` with a row limit did not fully consume, keyed by
    /// portal name, so the next `Execute` can resume them.
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    /// The unnamed prepared statement (statement name `""`).
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    /// Named prepared statements.
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    /// The unnamed portal (portal name `""`); it is never stored in `portal_store`.
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    /// Named portals.
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    /// Statement name -> names of the portals bound from it. Closing a statement also closes
    /// these portals.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    /// TLS context built from `tls_config`. `None` if TLS is not configured or the context
    /// could not be built; `SSLRequest`s are then declined.
    tls_context: Option<SslContext>,

    /// TLS configuration; `enforce_ssl` is checked when the startup message arrives.
    tls_config: Option<TlsConfig>,

    /// After an extended-query error, every message except `Sync` is ignored while this is
    /// set ("until" is misspelled in the name).
    ignore_util_sync: bool,

    /// Address of the connected client, passed to the session manager on connect.
    peer_addr: AddressRef,

    /// Keywords passed to `redact_sql` when SQL text is recorded into tracing spans.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Charged with the payload length of every regular message. It may reject a message,
    /// which is then turned into `FeMessage::ServerThrottle`.
    message_memory_manager: MessageMemoryManagerRef,
}
117
/// TLS settings for the pgwire server.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Certificate setting, read from `RW_SSL_CERT` by `new_default`. Presumably a path to a
    /// PEM file; TODO confirm in `build_ssl_ctx_from_config`.
    pub cert: String,
    /// Private-key setting, read from `RW_SSL_KEY` by `new_default`. Presumably a path; TODO
    /// confirm.
    pub key: String,
    /// When true, startup is rejected unless the connection has been upgraded to TLS.
    pub enforce_ssl: bool,
}
128
129impl TlsConfig {
130 pub fn new_default() -> Option<Self> {
131 let cert = std::env::var("RW_SSL_CERT").ok()?;
132 let key = std::env::var("RW_SSL_KEY").ok()?;
133 let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
134 tracing::info!(
135 "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
136 cert,
137 key,
138 enforce_ssl
139 );
140 Some(Self {
141 cert,
142 key,
143 enforce_ssl,
144 })
145 }
146}
147
148impl<S, SM> Drop for PgProtocol<S, SM>
149where
150 SM: SessionManager,
151{
152 fn drop(&mut self) {
153 if let Some(session) = &self.session {
154 self.session_mgr.end_session(session);
156 }
157 }
158}
159
/// Framing used for the next incoming message (see `PgProtocol::read_message`).
enum PgProtocolState {
    /// Before startup is handled: messages are read with `FeStartupMessage::read`.
    Startup,
    /// After startup: messages are read as a typed header followed by a body.
    Regular,
}
165
166pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
170 let without_null = if b.last() == Some(&0) {
171 &b[..b.len() - 1]
172 } else {
173 &b[..]
174 };
175 std::str::from_utf8(without_null)
176}
177
178fn record_sql_in_span(
180 sql: &str,
181 redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
182) -> String {
183 let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
184 && !keywords.is_empty()
185 {
186 redact_sql(sql, keywords)
187 } else {
188 sql.to_owned()
189 };
190 let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
191 tracing::Span::current().record("sql", tracing::field::display(&truncated));
192 truncated.to_string()
193}
194
195fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
197 match Parser::parse_sql(sql) {
198 Ok(sqls) => sqls
199 .into_iter()
200 .map(|sql| sql.to_redacted_string(keywords.clone()))
201 .join(";"),
202 Err(_) => sql.to_owned(),
203 }
204}
205
/// Per-server settings handed to every new `PgProtocol` (see `PgProtocol::new`).
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS settings; `None` disables TLS.
    pub tls_config: Option<TlsConfig>,
    /// Keywords used to redact SQL before it is recorded in tracing spans.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Memory accounting for incoming messages.
    pub message_memory_manager: MessageMemoryManagerRef,
}
212
213impl<S, SM> PgProtocol<S, SM>
214where
215 S: PgByteStream,
216 SM: SessionManager,
217{
218 pub fn new(
219 stream: S,
220 session_mgr: Arc<SM>,
221 peer_addr: AddressRef,
222 context: ConnectionContext,
223 ) -> Self {
224 let ConnectionContext {
225 tls_config,
226 redact_sql_option_keywords,
227 message_memory_manager,
228 } = context;
229 Self {
230 stream: PgStream::new(stream),
231 is_terminate: false,
232 state: PgProtocolState::Startup,
233 session_mgr,
234 session: None,
235 tls_context: tls_config
236 .as_ref()
237 .and_then(|e| build_ssl_ctx_from_config(e).ok()),
238 tls_config,
239 result_cache: Default::default(),
240 unnamed_prepare_statement: Default::default(),
241 prepare_statement_store: Default::default(),
242 unnamed_portal: Default::default(),
243 portal_store: Default::default(),
244 statement_portal_dependency: Default::default(),
245 ignore_util_sync: false,
246 peer_addr,
247 redact_sql_option_keywords,
248 message_memory_manager,
249 }
250 }
251
    /// Drives the connection until it terminates.
    ///
    /// Each loop iteration reads one message and processes it. Once a session exists, a
    /// background future is created (only once). It forwards the session's notices to the
    /// client through a clone of the stream and is polled alongside message processing via
    /// `select!`; it never completes on its own. The loop ends when reading a message fails
    /// or `process` reports that the connection should close.
    pub async fn run(&mut self) {
        let mut notice_fut = None;

        loop {
            // Start the notice forwarder as soon as a session has been established.
            if notice_fut.is_none()
                && let Some(session) = self.session.clone()
            {
                let mut stream = self.stream.clone();
                notice_fut = Some(Box::pin(async move {
                    loop {
                        let notice = session.next_notice().await;
                        // Best effort: log the failure and keep forwarding.
                        if let Err(e) = stream.write(BeMessage::NoticeResponse(&notice)).await {
                            tracing::error!(error = %e.as_report(), notice, "failed to send notice");
                        }
                    }
                }));
            }

            let process = std::pin::pin!(async {
                // `_memory_guard` stays alive until this message has been processed.
                let (msg, _memory_guard) = match self.read_message().await {
                    Ok(msg) => msg,
                    Err(e) => {
                        tracing::error!(error = %e.as_report(), "error when reading message");
                        // Any read failure ends the connection.
                        return true;
                    }
                };
                tracing::trace!(?msg, "received message");
                self.process(msg).await
            });

            let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
                tokio::select! {
                    // The forwarder loops forever, so only `process` can finish.
                    _ = notice_fut => unreachable!(),
                    terminated = process => terminated,
                }
            } else {
                process.await
            };

            if terminated {
                break;
            }
        }
    }
299
300 pub async fn process(&mut self, msg: FeMessage) -> bool {
302 self.do_process(msg).await.is_none() || self.is_terminate
303 }
304
305 fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
313 let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
314 return tracing::Span::none();
315 };
316
317 let mode = match msg {
318 FeMessage::Query(_) => "simple query",
319 FeMessage::Parse(_) => "extended query parse",
320 FeMessage::Execute(_) => "extended query execute",
321 _ => return tracing::Span::none(),
322 };
323
324 tracing::info_span!(
325 target: PGWIRE_ROOT_SPAN_TARGET,
326 "handle_query",
327 mode,
328 session_id,
329 sql = tracing::field::Empty, )
331 }
332
    /// Processes one message and turns any error into a client-visible response.
    ///
    /// `do_process_inner` is wrapped in layers, innermost first:
    /// 1. the `CURRENT_SESSION` task-local scope (when a session exists);
    /// 2. panic catching, which converts a panic into `PsqlError::Panic`;
    /// 3. an info log on `PGWIRE_SLOW_QUERY_LOG` after every `SLOW_QUERY_LOG_PERIOD` the
    ///    message is still running;
    /// 4. start/finish logs on `PGWIRE_QUERY_LOG` (only inside a real span) plus error
    ///    logging;
    /// 5. instrumentation with the span from `root_span_for_msg`.
    ///
    /// Returns `None` when the connection should be closed, and `Some(())` otherwise. The
    /// connection is closed on unexpected EOF, TLS errors, startup/password errors,
    /// idle-in-transaction timeouts, panics, or a failure to write the error response.
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        let fut = Box::pin(self.do_process_inner(msg));

        // Layer 1: expose the (type-erased) session to code running inside the handler.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Layer 2: turn a panic inside the handler into a regular error.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Layer 3: log "slow query" each time another `period` elapses without completion.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            let mut elapsed = Duration::ZERO;

            loop {
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Layer 4: query start/finish logging and error logging.
        let fut = async move {
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = "started",
                );
            }

            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            if let Err(error) = &result {
                // Local debug builds log the report with `Debug` formatting (`?`); CI and
                // release builds use `Display` (`%`).
                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
                    tracing::error!(error = ?error.as_report(), "error when process message");
                } else {
                    tracing::error!(error = %error.as_report(), "error when process message");
                }
            }

            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        // Layer 5: run everything inside the root span.
        let fut = fut.instrument(span);

        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    // EOF means the client went away: close silently. Other I/O errors fall
                    // through to the flush below and keep the connection open.
                    PsqlError::IoError(io_err) => {
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                    }

                    // TLS failures close the connection without an error response.
                    PsqlError::SslError(_) => {
                        return None;
                    }

                    // Startup/authentication failures: report the error, then close.
                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    // Simple-query errors and throttling: report the error and queue
                    // ReadyForQuery, keeping the connection.
                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                        self.ready_for_query().ok()?;
                    }

                    // Idle-in-transaction timeout or a caught panic: report, then close.
                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                        let _ = self.stream.flush().await;

                        return None;
                    }

                    // Report only; no ReadyForQuery is queued here. In the extended
                    // protocol it is sent when the client's `Sync` arrives.
                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                    }
                }
                // Best-effort flush of the queued error response; the connection stays open.
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
495
496 async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
497 if self.ignore_util_sync {
499 if let FeMessage::Sync = msg {
500 } else {
501 tracing::trace!("ignore message {:?} until sync.", msg);
502 return Ok(());
503 }
504 }
505
506 match msg {
507 FeMessage::Gss => self.process_gss_msg().await?,
508 FeMessage::Ssl => self.process_ssl_msg().await?,
509 FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
510 FeMessage::Password(msg) => self.process_password_msg(msg).await?,
511 FeMessage::Query(query_msg) => {
512 let sql = Arc::from(query_msg.get_sql()?);
513 drop(query_msg);
515 self.process_query_msg(sql).await?
516 }
517 FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
518 FeMessage::Terminate => self.process_terminate(),
519 FeMessage::Parse(m) => {
520 if let Err(err) = self.process_parse_msg(m).await {
521 self.ignore_util_sync = true;
522 return Err(err);
523 }
524 }
525 FeMessage::Bind(m) => {
526 if let Err(err) = self.process_bind_msg(m) {
527 self.ignore_util_sync = true;
528 return Err(err);
529 }
530 }
531 FeMessage::Execute(m) => {
532 if let Err(err) = self.process_execute_msg(m).await {
533 self.ignore_util_sync = true;
534 return Err(err);
535 }
536 }
537 FeMessage::Describe(m) => {
538 if let Err(err) = self.process_describe_msg(m) {
539 self.ignore_util_sync = true;
540 return Err(err);
541 }
542 }
543 FeMessage::Sync => {
544 self.ignore_util_sync = false;
545 self.ready_for_query()?
546 }
547 FeMessage::Close(m) => {
548 if let Err(err) = self.process_close_msg(m) {
549 self.ignore_util_sync = true;
550 return Err(err);
551 }
552 }
553 FeMessage::Flush => {
554 if let Err(err) = self.stream.flush().await {
555 self.ignore_util_sync = true;
556 return Err(err.into());
557 }
558 }
559 FeMessage::HealthCheck => self.process_health_check(),
560 FeMessage::ServerThrottle(reason) => match reason {
561 ServerThrottleReason::TooLargeMessage => {
562 return Err(PsqlError::ServerThrottle(format!(
563 "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
564 self.message_memory_manager.max_filter_bytes
565 )));
566 }
567 ServerThrottleReason::TooManyMemoryUsage => {
568 return Err(PsqlError::ServerThrottle(format!(
569 "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
570 self.message_memory_manager.max_running_bytes
571 )));
572 }
573 },
574 }
575 self.stream.flush().await?;
576 Ok(())
577 }
578
579 pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
580 match self.state {
581 PgProtocolState::Startup => self
582 .stream
583 .read_startup()
584 .await
585 .map(|message: FeMessage| (message, None)),
586 PgProtocolState::Regular => {
587 self.stream.read_header().await?;
588 let guard = if let Some(ref header) = self.stream.read_header {
589 let payload_len = std::cmp::max(header.payload_len, 0) as u64;
590 let (reason, guard) = self.message_memory_manager.add(payload_len);
591 if let Some(reason) = reason {
592 drop(guard);
594 self.stream.skip_body().await?;
595 return Ok((FeMessage::ServerThrottle(reason), None));
596 }
597 guard
598 } else {
599 None
600 };
601 let message = self.stream.read_body().await?;
602 Ok((message, guard))
603 }
604 }
605 }
606
607 fn ready_for_query(&mut self) -> io::Result<()> {
609 self.stream.write_no_flush(BeMessage::ReadyForQuery(
610 self.session
611 .as_ref()
612 .map(|s| s.transaction_status())
613 .unwrap_or(TransactionStatus::Idle),
614 ))
615 }
616
617 async fn process_gss_msg(&mut self) -> PsqlResult<()> {
618 self.stream.write(BeMessage::EncryptionResponseNo).await?;
620 Ok(())
621 }
622
623 async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
624 if let Some(context) = self.tls_context.as_ref() {
625 self.stream.write(BeMessage::EncryptionResponseSsl).await?;
628 self.stream.upgrade_to_ssl(context).await?;
629 } else {
630 self.stream.write(BeMessage::EncryptionResponseNo).await?;
632 }
633
634 Ok(())
635 }
636
637 async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
638 if let Some(ref tls_config) = self.tls_config
640 && tls_config.enforce_ssl
641 && !self.stream.is_ssl_connection().await
642 {
643 return Err(PsqlError::StartupError(
644 "SSL connection is required but not established".into(),
645 ));
646 }
647
648 let db_name = msg
649 .config
650 .get("database")
651 .cloned()
652 .unwrap_or_else(|| "dev".to_owned());
653 let user_name = msg
654 .config
655 .get("user")
656 .cloned()
657 .unwrap_or_else(|| "root".to_owned());
658
659 let session = self
660 .session_mgr
661 .connect(&db_name, &user_name, self.peer_addr.clone())
662 .map_err(PsqlError::StartupError)?;
663
664 let application_name = msg.config.get("application_name");
665 if let Some(application_name) = application_name {
666 session
667 .set_config("application_name", application_name.clone())
668 .map_err(PsqlError::StartupError)?;
669 }
670
671 match session.user_authenticator() {
672 UserAuthenticator::None => {
673 self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
674
675 self.stream
678 .write_no_flush(BeMessage::BackendKeyData(session.id()))?;
679
680 self.stream.write_no_flush(BeMessage::ParameterStatus(
681 BeParameterStatusMessage::TimeZone(&session.get_config("timezone")?),
682 ))?;
683 self.stream
684 .write_parameter_status_msg_no_flush(&ParameterStatus {
685 application_name: application_name.cloned(),
686 })?;
687 self.ready_for_query()?;
688 }
689 UserAuthenticator::ClearText(_)
690 | UserAuthenticator::OAuth { .. }
691 | UserAuthenticator::Ldap(..) => {
692 self.stream
693 .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
694 }
695 UserAuthenticator::Md5WithSalt { salt, .. } => {
696 self.stream
697 .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
698 }
699 }
700
701 self.session = Some(session);
702 self.state = PgProtocolState::Regular;
703 Ok(())
704 }
705
706 async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
707 let session = self.session.as_ref().unwrap();
708 let authenticator = session.user_authenticator();
709 authenticator.authenticate(&msg.password).await?;
710 self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
711 self.stream.write_no_flush(BeMessage::ParameterStatus(
712 BeParameterStatusMessage::TimeZone(&session.get_config("timezone")?),
713 ))?;
714 self.stream
715 .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
716 self.ready_for_query()?;
717 self.state = PgProtocolState::Regular;
718 Ok(())
719 }
720
721 fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
722 let session_id = (m.target_process_id, m.target_secret_key);
723 tracing::trace!("cancel query in session: {:?}", session_id);
724 self.session_mgr.cancel_queries_in_session(session_id);
725 self.session_mgr.cancel_creating_jobs_in_session(session_id);
726 self.is_terminate = true;
727 Ok(())
728 }
729
730 async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
731 let truncated_sql = record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
732 let session = self.session.clone().unwrap();
733
734 session.check_idle_in_transaction_timeout()?;
735 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
737 self.inner_process_query_msg(sql, session.clone()).await
738 }
739
740 async fn inner_process_query_msg(
741 &mut self,
742 sql: Arc<str>,
743 session: Arc<SM::Session>,
744 ) -> PsqlResult<()> {
745 let stmts =
747 Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
748 drop(sql);
750 if stmts.is_empty() {
751 self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
752 }
753
754 for stmt in stmts {
756 self.inner_process_query_msg_one_stmt(stmt, session.clone())
757 .await?;
758 }
759 self.ready_for_query()?;
762 Ok(())
763 }
764
765 async fn inner_process_query_msg_one_stmt(
766 &mut self,
767 stmt: Statement,
768 session: Arc<SM::Session>,
769 ) -> PsqlResult<()> {
770 let session = session.clone();
771
772 let res = session.clone().run_one_query(stmt, Format::Text).await;
774
775 while let Some(notice) = session.next_notice().now_or_never() {
777 self.stream
778 .write_no_flush(BeMessage::NoticeResponse(¬ice))?;
779 }
780
781 let mut res = res.map_err(PsqlError::SimpleQueryError)?;
782
783 for notice in res.notices() {
784 self.stream
785 .write_no_flush(BeMessage::NoticeResponse(notice))?;
786 }
787
788 let status = res.status();
789 if let Some(ref application_name) = status.application_name {
790 self.stream.write_no_flush(BeMessage::ParameterStatus(
791 BeParameterStatusMessage::ApplicationName(application_name),
792 ))?;
793 }
794
795 if res.is_copy_query_to_stdout() {
796 self.stream
797 .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
798 let mut count = 0;
799 while let Some(row_set) = res.values_stream().next().await {
800 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
801 for row in row_set {
802 self.stream.write_no_flush(BeMessage::CopyData(&row))?;
803 count += 1;
804 }
805 }
806
807 self.stream.write_no_flush(BeMessage::CopyDone)?;
808
809 res.run_callback().await?;
811
812 self.stream
813 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
814 stmt_type: res.stmt_type(),
815 rows_cnt: count,
816 }))?;
817 } else if res.is_query() {
818 self.stream
819 .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;
820
821 let mut rows_cnt = 0;
822
823 while let Some(row_set) = res.values_stream().next().await {
824 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
825 for row in row_set {
826 self.stream.write_no_flush(BeMessage::DataRow(&row))?;
827 rows_cnt += 1;
828 }
829 }
830
831 res.run_callback().await?;
833
834 self.stream
835 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
836 stmt_type: res.stmt_type(),
837 rows_cnt,
838 }))?;
839 } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
840 let first_row_set = res.values_stream().next().await;
841 let first_row_set = match first_row_set {
842 None => {
843 return Err(PsqlError::Uncategorized(
844 anyhow::anyhow!("no affected rows in output").into(),
845 ));
846 }
847 Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
848 };
849 let affected_rows_str = first_row_set[0].values()[0]
850 .as_ref()
851 .expect("compute node should return affected rows in output");
852
853 assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
854 let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
855 .unwrap()
856 .parse()
857 .unwrap_or_default();
858
859 res.run_callback().await?;
861
862 self.stream
863 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
864 stmt_type: res.stmt_type(),
865 rows_cnt: affected_rows_cnt,
866 }))?;
867 } else {
868 res.run_callback().await?;
870
871 self.stream
872 .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
873 stmt_type: res.stmt_type(),
874 rows_cnt: 0,
875 }))?;
876 }
877
878 Ok(())
879 }
880
881 fn process_terminate(&mut self) {
882 self.is_terminate = true;
883 }
884
885 fn process_health_check(&mut self) {
886 tracing::debug!("health check");
887 self.is_terminate = true;
888 }
889
890 async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
891 let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
892 record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
893 let session = self.session.clone().unwrap();
894 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
895 let type_ids = std::mem::take(&mut msg.type_ids);
896 drop(msg);
898 self.inner_process_parse_msg(session, sql, statement_name, type_ids)
899 .await?;
900 Ok(())
901 }
902
903 async fn inner_process_parse_msg(
904 &mut self,
905 session: Arc<SM::Session>,
906 sql: Arc<str>,
907 statement_name: String,
908 type_ids: Vec<i32>,
909 ) -> PsqlResult<()> {
910 if statement_name.is_empty() {
911 self.unnamed_prepare_statement.take();
914 } else if self.prepare_statement_store.contains_key(&statement_name) {
915 return Err(PsqlError::ExtendedPrepareError(
916 "Duplicated statement name".into(),
917 ));
918 }
919
920 let stmt = {
921 let stmts = Parser::parse_sql(&sql)
922 .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
923 drop(sql);
924 if stmts.len() > 1 {
925 return Err(PsqlError::ExtendedPrepareError(
926 "Only one statement is allowed in extended query mode".into(),
927 ));
928 }
929
930 stmts.into_iter().next()
931 };
932
933 let param_types: Vec<Option<DataType>> = type_ids
934 .iter()
935 .map(|&id| {
936 if id == 0 {
939 Ok(None)
940 } else {
941 DataType::from_oid(id)
942 .map(Some)
943 .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
944 }
945 })
946 .try_collect()?;
947
948 let prepare_statement = session
949 .parse(stmt, param_types)
950 .await
951 .map_err(PsqlError::ExtendedPrepareError)?;
952
953 if statement_name.is_empty() {
954 self.unnamed_prepare_statement.replace(prepare_statement);
955 } else {
956 self.prepare_statement_store
957 .insert(statement_name.clone(), prepare_statement);
958 }
959
960 self.statement_portal_dependency
961 .entry(statement_name)
962 .or_default()
963 .clear();
964
965 self.stream.write_no_flush(BeMessage::ParseComplete)?;
966 Ok(())
967 }
968
969 fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
970 let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
971 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
972 let session = self.session.clone().unwrap();
973
974 if self.portal_store.contains_key(&portal_name) {
975 return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
976 }
977
978 let prepare_statement = self.get_statement(&statement_name)?;
979
980 let result_formats = msg
981 .result_format_codes
982 .iter()
983 .map(|&format_code| Format::from_i16(format_code))
984 .try_collect()?;
985 let param_formats = msg
986 .param_format_codes
987 .iter()
988 .map(|&format_code| Format::from_i16(format_code))
989 .try_collect()?;
990
991 let portal = session
992 .bind(prepare_statement, msg.params, param_formats, result_formats)
993 .map_err(PsqlError::Uncategorized)?;
994
995 if portal_name.is_empty() {
996 self.result_cache.remove(&portal_name);
997 self.unnamed_portal.replace(portal);
998 } else {
999 assert!(
1000 !self.result_cache.contains_key(&portal_name),
1001 "Named portal never can be overridden."
1002 );
1003 self.portal_store.insert(portal_name.clone(), portal);
1004 }
1005
1006 self.statement_portal_dependency
1007 .get_mut(&statement_name)
1008 .unwrap()
1009 .push(portal_name);
1010
1011 self.stream.write_no_flush(BeMessage::BindComplete)?;
1012 Ok(())
1013 }
1014
1015 async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1016 let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1017 let row_max = msg.max_rows as usize;
1018 drop(msg);
1019 let session = self.session.clone().unwrap();
1020
1021 match self.result_cache.remove(&portal_name) {
1022 Some(mut result_cache) => {
1023 assert!(self.portal_store.contains_key(&portal_name));
1024
1025 let is_consume_completed =
1026 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1027
1028 if !is_consume_completed {
1029 self.result_cache.insert(portal_name, result_cache);
1030 }
1031 }
1032 _ => {
1033 let portal = self.get_portal(&portal_name)?;
1034 let sql = format!("{}", portal);
1035 let truncated_sql =
1036 record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
1037 drop(sql);
1038
1039 session.check_idle_in_transaction_timeout()?;
1040 let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1042 let result = session.clone().execute(portal).await;
1043
1044 let pg_response = result.map_err(PsqlError::ExtendedExecuteError)?;
1045 let mut result_cache = ResultCache::new(pg_response);
1046 let is_consume_completed =
1047 result_cache.consume::<S>(row_max, &mut self.stream).await?;
1048 if !is_consume_completed {
1049 self.result_cache.insert(portal_name, result_cache);
1050 }
1051 }
1052 }
1053
1054 Ok(())
1055 }
1056
1057 fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1058 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1059 let session = self.session.clone().unwrap();
1060 assert!(msg.kind == b'S' || msg.kind == b'P');
1064 if msg.kind == b'S' {
1065 let prepare_statement = self.get_statement(&name)?;
1066
1067 let (param_types, row_descriptions) = self
1068 .session
1069 .clone()
1070 .unwrap()
1071 .describe_statement(prepare_statement)
1072 .map_err(PsqlError::Uncategorized)?;
1073 self.stream.write_no_flush(BeMessage::ParameterDescription(
1074 ¶m_types.iter().map(|t| t.to_oid()).collect_vec(),
1075 ))?;
1076
1077 if row_descriptions.is_empty() {
1078 self.stream.write_no_flush(BeMessage::NoData)?;
1081 } else {
1082 self.stream
1083 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1084 }
1085 } else if msg.kind == b'P' {
1086 let portal = self.get_portal(&name)?;
1087
1088 let row_descriptions = session
1089 .describe_portal(portal)
1090 .map_err(PsqlError::Uncategorized)?;
1091
1092 if row_descriptions.is_empty() {
1093 self.stream.write_no_flush(BeMessage::NoData)?;
1096 } else {
1097 self.stream
1098 .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1099 }
1100 }
1101 Ok(())
1102 }
1103
1104 fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1105 let name = cstr_to_str(&msg.name).unwrap().to_owned();
1106 assert!(msg.kind == b'S' || msg.kind == b'P');
1107 if msg.kind == b'S' {
1108 if name.is_empty() {
1109 self.unnamed_prepare_statement = None;
1110 } else {
1111 self.prepare_statement_store.remove(&name);
1112 }
1113 for portal_name in self
1114 .statement_portal_dependency
1115 .remove(&name)
1116 .unwrap_or_default()
1117 {
1118 self.remove_portal(&portal_name);
1119 }
1120 } else if msg.kind == b'P' {
1121 self.remove_portal(&name);
1122 }
1123 self.stream.write_no_flush(BeMessage::CloseComplete)?;
1124 Ok(())
1125 }
1126
1127 fn remove_portal(&mut self, portal_name: &str) {
1128 if portal_name.is_empty() {
1129 self.unnamed_portal = None;
1130 } else {
1131 self.portal_store.remove(portal_name);
1132 }
1133 self.result_cache.remove(portal_name);
1134 }
1135
1136 fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1137 if portal_name.is_empty() {
1138 Ok(self
1139 .unnamed_portal
1140 .as_ref()
1141 .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1142 .clone())
1143 } else {
1144 Ok(self
1145 .portal_store
1146 .get(portal_name)
1147 .ok_or_else(|| {
1148 PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1149 })?
1150 .clone())
1151 }
1152 }
1153
1154 fn get_statement(
1155 &self,
1156 statement_name: &str,
1157 ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1158 if statement_name.is_empty() {
1159 Ok(self
1160 .unnamed_prepare_statement
1161 .as_ref()
1162 .ok_or_else(|| {
1163 PsqlError::Uncategorized("unnamed prepare statement not found".into())
1164 })?
1165 .clone())
1166 } else {
1167 Ok(self
1168 .prepare_statement_store
1169 .get(statement_name)
1170 .ok_or_else(|| {
1171 PsqlError::Uncategorized(
1172 format!("Prepare statement {} not found", statement_name).into(),
1173 )
1174 })?
1175 .clone())
1176 }
1177 }
1178}
1179
/// The underlying client connection, either plain or TLS-wrapped.
enum PgStreamInner<S> {
    /// Never expected by the readers in this file (they hit `unreachable!()`). Presumably a
    /// transient value left in place while the plain stream is moved into an `SslStream`;
    /// TODO confirm in `upgrade_to_ssl`.
    Placeholder,
    /// Plain, unencrypted stream.
    Unencrypted(S),
    /// Stream upgraded to TLS.
    Ssl(SslStream<S>),
}
1188
/// A byte stream usable as a pgwire connection: async read + write, `Unpin`, `Send`,
/// `'static`.
pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
// Blanket impl: every type meeting the bounds is a `PgByteStream`.
impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1192
/// Cloneable handle to a client connection with its own write buffer.
///
/// The connection is shared behind `Arc<Mutex<..>>`, so clones (such as the notice
/// forwarder in `PgProtocol::run`) write to the same socket. Each clone starts with its own
/// empty write buffer.
pub struct PgStream<S> {
    /// The shared connection.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// Outgoing bytes, presumably accumulated by `write_no_flush` until the next flush.
    write_buf: BytesMut,
    /// Header stored by `read_header` and taken by `read_body` / `skip_body`.
    read_header: Option<FeMessageHeader>,
}
1204
1205impl<S> PgStream<S> {
1206 pub fn new(stream: S) -> Self {
1208 const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1209
1210 Self {
1211 stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1212 write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1213 read_header: None,
1214 }
1215 }
1216
1217 async fn is_ssl_connection(&self) -> bool {
1219 let stream = self.stream.lock().await;
1220 matches!(*stream, PgStreamInner::Ssl(_))
1221 }
1222}
1223
1224impl<S> Clone for PgStream<S> {
1225 fn clone(&self) -> Self {
1226 Self {
1227 stream: Arc::clone(&self.stream),
1228 write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1229 read_header: self.read_header.clone(),
1230 }
1231 }
1232}
1233
/// Connection parameters passed to `PgStream::write_parameter_status_msg_no_flush` and
/// reported to the client after authentication.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// `application_name` from the startup message, if the client sent one.
    pub application_name: Option<String>,
}
1254
1255impl<S> PgStream<S>
1256where
1257 S: PgByteStream,
1258{
1259 async fn read_startup(&mut self) -> io::Result<FeMessage> {
1260 let mut stream = self.stream.lock().await;
1261 match &mut *stream {
1262 PgStreamInner::Placeholder => unreachable!(),
1263 PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1264 PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1265 }
1266 }
1267
1268 async fn read_header(&mut self) -> io::Result<()> {
1269 let mut stream = self.stream.lock().await;
1270 match &mut *stream {
1271 PgStreamInner::Placeholder => unreachable!(),
1272 PgStreamInner::Unencrypted(stream) => {
1273 self.read_header = Some(FeMessage::read_header(stream).await?);
1274 Ok(())
1275 }
1276 PgStreamInner::Ssl(ssl_stream) => {
1277 self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1278 Ok(())
1279 }
1280 }
1281 }
1282
1283 async fn read_body(&mut self) -> io::Result<FeMessage> {
1284 let mut stream = self.stream.lock().await;
1285 let header = self
1286 .read_header
1287 .take()
1288 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1289 match &mut *stream {
1290 PgStreamInner::Placeholder => unreachable!(),
1291 PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1292 PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1293 }
1294 }
1295
1296 async fn skip_body(&mut self) -> io::Result<()> {
1297 let mut stream = self.stream.lock().await;
1298 let header = self
1299 .read_header
1300 .take()
1301 .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1302 match &mut *stream {
1303 PgStreamInner::Placeholder => unreachable!(),
1304 PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1305 PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1306 }
1307 }
1308
1309 fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1310 self.write_no_flush(BeMessage::ParameterStatus(
1311 BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1312 ))?;
1313 self.write_no_flush(BeMessage::ParameterStatus(
1314 BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1315 ))?;
1316 self.write_no_flush(BeMessage::ParameterStatus(
1317 BeParameterStatusMessage::ServerVersion(PG_VERSION),
1318 ))?;
1319 if let Some(application_name) = &status.application_name {
1320 self.write_no_flush(BeMessage::ParameterStatus(
1321 BeParameterStatusMessage::ApplicationName(application_name),
1322 ))?;
1323 }
1324 Ok(())
1325 }
1326
1327 pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1328 BeMessage::write(&mut self.write_buf, message)
1329 }
1330
1331 async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1332 self.write_no_flush(message)?;
1333 self.flush().await?;
1334 Ok(())
1335 }
1336
1337 async fn flush(&mut self) -> io::Result<()> {
1338 let mut stream = self.stream.lock().await;
1339 match &mut *stream {
1340 PgStreamInner::Placeholder => unreachable!(),
1341 PgStreamInner::Unencrypted(stream) => {
1342 stream.write_all(&self.write_buf).await?;
1343 stream.flush().await?;
1344 }
1345 PgStreamInner::Ssl(ssl_stream) => {
1346 ssl_stream.write_all(&self.write_buf).await?;
1347 ssl_stream.flush().await?;
1348 }
1349 }
1350 self.write_buf.clear();
1351 Ok(())
1352 }
1353}
1354
1355impl<S> PgStream<S>
1356where
1357 S: PgByteStream,
1358{
1359 async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1361 let mut stream = self.stream.lock().await;
1362
1363 match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1364 PgStreamInner::Unencrypted(unencrypted_stream) => {
1365 let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1366 let mut ssl_stream =
1367 tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1368
1369 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1370 tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1371 let _ = ssl_stream.shutdown().await;
1372 return Err(e.into());
1373 }
1374
1375 *stream = PgStreamInner::Ssl(ssl_stream);
1376 }
1377 PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1378 PgStreamInner::Placeholder => unreachable!(),
1379 }
1380
1381 Ok(())
1382 }
1383}
1384
1385fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1386 let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1387
1388 let key_path = &tls_config.key;
1389 let cert_path = &tls_config.cert;
1390
1391 acceptor
1394 .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1395 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1396 acceptor
1397 .set_ca_file(cert_path)
1398 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1399 acceptor
1400 .set_certificate_chain_file(cert_path)
1401 .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1402 let acceptor = acceptor.build();
1403
1404 Ok(acceptor.into_context())
1405}
1406
1407pub mod truncated_fmt {
1408 use std::fmt::*;
1409
1410 struct TruncatedFormatter<'a, 'b> {
1411 remaining: usize,
1412 finished: bool,
1413 f: &'a mut Formatter<'b>,
1414 }
1415 impl Write for TruncatedFormatter<'_, '_> {
1416 fn write_str(&mut self, s: &str) -> Result {
1417 if self.finished {
1418 return Ok(());
1419 }
1420
1421 if self.remaining < s.len() {
1422 let actual = s.floor_char_boundary(self.remaining);
1423 self.f.write_str(&s[0..actual])?;
1424 self.remaining -= actual;
1425 self.f.write_str(&format!("...(truncated,{})", s.len()))?;
1426 self.finished = true; } else {
1428 self.f.write_str(s)?;
1429 self.remaining -= s.len();
1430 }
1431 Ok(())
1432 }
1433 }
1434
1435 pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);
1436
1437 impl<T> Debug for TruncatedFmt<'_, T>
1438 where
1439 T: Debug,
1440 {
1441 fn fmt(&self, f: &mut Formatter<'_>) -> Result {
1442 TruncatedFormatter {
1443 remaining: self.1,
1444 finished: false,
1445 f,
1446 }
1447 .write_fmt(format_args!("{:?}", self.0))
1448 }
1449 }
1450
1451 impl<T> Display for TruncatedFmt<'_, T>
1452 where
1453 T: Display,
1454 {
1455 fn fmt(&self, f: &mut Formatter<'_>) -> Result {
1456 TruncatedFormatter {
1457 remaining: self.1,
1458 finished: false,
1459 f,
1460 }
1461 .write_fmt(format_args!("{}", self.0))
1462 }
1463 }
1464
1465 #[cfg(test)]
1466 mod tests {
1467 use super::*;
1468
1469 #[test]
1470 fn test_trunc_utf8() {
1471 assert_eq!(
1472 format!("{}", TruncatedFmt(&"select '🌊';", 10)),
1473 "select '...(truncated,14)",
1474 );
1475 }
1476 }
1477}
1478
1479#[cfg(test)]
1480mod tests {
1481 use std::collections::HashSet;
1482
1483 use super::*;
1484
1485 #[test]
1486 fn test_redact_parsable_sql() {
1487 let keywords = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()]));
1488 let sql = r"
1489 create source temp (k bigint, v varchar) with (
1490 connector = 'datagen',
1491 v1 = 123,
1492 v2 = 'with',
1493 v3 = false,
1494 v4 = '',
1495 ) FORMAT plain ENCODE json (a='1',b='2')
1496 ";
1497 assert_eq!(
1498 redact_sql(sql, keywords),
1499 "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])"
1500 );
1501 }
1502}