1use std::collections::HashMap;
16use std::ffi::CStr;
17use std::io::{Error, ErrorKind, IoSlice, Result, Write};
18
19use anyhow::anyhow;
20use byteorder::{BigEndian, ByteOrder};
21use bytes::{Buf, BufMut, Bytes, BytesMut};
23use peekable::tokio::AsyncPeekable;
24use tokio::io::{AsyncRead, AsyncReadExt};
25
26use crate::error_or_notice::{ErrorOrNoticeMessage, Severity};
27use crate::pg_field_descriptor::PgFieldDescriptor;
28use crate::pg_response::StatementType;
29use crate::types::{Format, Row};
30
/// A message received from the frontend (client), as decoded by this module.
///
/// Regular (tagged) messages are produced by `FeMessage::read_body`; the
/// startup-phase variants are produced by `FeStartupMessage::read`.
#[derive(Debug)]
pub enum FeMessage {
    /// SSL negotiation request (startup protocol number 80877103).
    Ssl,
    /// GSS encryption negotiation request (startup protocol number 80877104).
    Gss,
    /// Startup message (protocol number 196608) carrying the client's configuration pairs.
    Startup(FeStartupMessage),
    /// Simple query, tag `b'Q'`.
    Query(FeQueryMessage),
    /// Parse, tag `b'P'`.
    Parse(FeParseMessage),
    /// Password response, tag `b'p'`.
    Password(FePasswordMessage),
    /// Describe, tag `b'D'`.
    Describe(FeDescribeMessage),
    /// Bind, tag `b'B'`.
    Bind(FeBindMessage),
    /// Execute, tag `b'E'`.
    Execute(FeExecuteMessage),
    /// Close, tag `b'C'`.
    Close(FeCloseMessage),
    /// Sync, tag `b'S'`; carries no payload.
    Sync,
    /// Cancel request (startup protocol number 80877102).
    CancelQuery(FeCancelMessage),
    /// Terminate, tag `b'X'`; carries no payload.
    Terminate,
    /// Flush, tag `b'H'`; carries no payload.
    Flush,
    /// Returned by `FeStartupMessage::read` when the stream hits EOF before a
    /// single byte of a startup message could be peeked.
    HealthCheck,
    /// Not constructed by any decoder in this file.
    /// NOTE(review): presumably produced by the caller when it refuses to
    /// read a message body — confirm against the call site.
    ServerThrottle(ServerThrottleReason),
}
53
54impl FeMessage {
55 pub fn get_sql(&self) -> Result<Option<&str>> {
56 match self {
57 FeMessage::Query(q) => Ok(Some(q.get_sql()?)),
58 _ => Ok(None),
59 }
60 }
61}
62
/// Reason attached to `FeMessage::ServerThrottle`.
///
/// NOTE(review): nothing in this file constructs these variants; the
/// descriptions below are inferred from the names only — confirm with the
/// code that produces them.
#[derive(Debug)]
pub enum ServerThrottleReason {
    /// Presumably: a single message exceeded the allowed size.
    TooLargeMessage,
    /// Presumably: the server is using too much memory to accept the message.
    TooManyMemoryUsage,
}
68
/// Startup message: the configuration pairs sent by the client when it connects.
#[derive(Debug)]
pub struct FeStartupMessage {
    /// Key/value pairs decoded from the startup payload by
    /// `FeStartupMessage::build_with_payload`.
    pub config: HashMap<String, String>,
}
73
74impl FeStartupMessage {
75 pub fn build_with_payload(payload: &[u8]) -> Result<Self> {
76 let config = match std::str::from_utf8(payload) {
77 Ok(v) => Ok(v.trim_end_matches('\0')),
78 Err(err) => Err(Error::new(
79 ErrorKind::InvalidInput,
80 anyhow!(err).context("Input end error"),
81 )),
82 }?;
83 let mut map = HashMap::new();
84 let config: Vec<&str> = config.split('\0').collect();
85 if config.len() % 2 == 1 {
86 return Err(Error::new(
87 ErrorKind::InvalidInput,
88 "Invalid input config: odd number of config pairs",
89 ));
90 }
91 config.chunks(2).for_each(|chunk| {
92 map.insert(chunk[0].to_owned(), chunk[1].to_owned());
93 });
94 Ok(FeStartupMessage { config: map })
95 }
96}
97
/// Simple query message (tag `b'Q'`).
#[derive(Debug)]
pub struct FeQueryMessage {
    /// The raw message payload exactly as read from the stream. `get_sql`
    /// expects it to end with a single NUL terminator.
    pub sql_bytes: Bytes,
}
103
/// Bind message (tag `b'B'`), as decoded by `FeBindMessage::parse`.
#[derive(Debug)]
pub struct FeBindMessage {
    /// Format codes for the parameters, in the order sent by the client.
    pub param_format_codes: Vec<i16>,
    /// Format codes for the result columns, in the order sent by the client.
    pub result_format_codes: Vec<i16>,

    /// Parameter values; `None` marks a value whose length was sent as `-1`.
    pub params: Vec<Option<Bytes>>,
    /// Destination portal name, without the NUL terminator.
    pub portal_name: Bytes,
    /// Source prepared-statement name, without the NUL terminator.
    pub statement_name: Bytes,
}
113
/// Execute message (tag `b'E'`), as decoded by `FeExecuteMessage::parse`.
#[derive(Debug)]
pub struct FeExecuteMessage {
    /// Name of the portal to execute, without the NUL terminator.
    pub portal_name: Bytes,
    /// Row-limit field read as a big-endian `i32` right after the portal name.
    /// NOTE(review): the PostgreSQL protocol defines 0 as "no limit" —
    /// confirm how the consumer interprets it.
    pub max_rows: i32,
}
119
/// Parse message (tag `b'P'`), as decoded by `FeParseMessage::parse`.
#[derive(Debug)]
pub struct FeParseMessage {
    /// Name of the prepared statement, without the NUL terminator.
    pub statement_name: Bytes,
    /// Query text, without the NUL terminator.
    pub sql_bytes: Bytes,
    /// Parameter type ids, one big-endian `i32` per declared parameter.
    pub type_ids: Vec<i32>,
}
126
/// Password response message (tag `b'p'`).
#[derive(Debug)]
pub struct FePasswordMessage {
    /// The NUL-terminated string sent by the client, without the terminator.
    pub password: Bytes,
}
131
/// Describe message (tag `b'D'`), as decoded by `FeDescribeMessage::parse`.
#[derive(Debug)]
pub struct FeDescribeMessage {
    /// First payload byte, selecting what is being described.
    /// NOTE(review): per the PostgreSQL protocol this should be `b'S'`
    /// (statement) or `b'P'` (portal); it is not validated here.
    pub kind: u8,
    /// Name of the statement or portal, without the NUL terminator.
    pub name: Bytes,
}
138
/// Close message (tag `b'C'`), as decoded by `FeCloseMessage::parse`.
#[derive(Debug)]
pub struct FeCloseMessage {
    /// First payload byte, selecting what is being closed.
    /// NOTE(review): per the PostgreSQL protocol this should be `b'S'`
    /// (statement) or `b'P'` (portal); it is not validated here.
    pub kind: u8,
    /// Name of the statement or portal, without the NUL terminator.
    pub name: Bytes,
}
144
/// Cancel request (startup protocol number 80877102), as decoded by
/// `FeCancelMessage::parse`.
#[derive(Debug)]
pub struct FeCancelMessage {
    /// First big-endian `i32` of the payload.
    /// NOTE(review): presumably the process id previously sent to the client
    /// in `BeMessage::BackendKeyData` — confirm.
    pub target_process_id: i32,
    /// Second big-endian `i32` of the payload.
    /// NOTE(review): presumably the secret key from `BackendKeyData` — confirm.
    pub target_secret_key: i32,
}
150
151impl FeCancelMessage {
152 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
153 let target_process_id = buf.get_i32();
154 let target_secret_key = buf.get_i32();
155 Ok(FeMessage::CancelQuery(Self {
156 target_process_id,
157 target_secret_key,
158 }))
159 }
160}
161
162impl FeDescribeMessage {
163 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
164 let kind = buf.get_u8();
165 let name = read_null_terminated(&mut buf)?;
166
167 Ok(FeMessage::Describe(FeDescribeMessage { kind, name }))
168 }
169}
170
171impl FeBindMessage {
172 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
190 let portal_name = read_null_terminated(&mut buf)?;
191 let statement_name = read_null_terminated(&mut buf)?;
192
193 let len = buf.get_i16();
194 let param_format_codes = (0..len).map(|_| buf.get_i16()).collect();
195
196 let len = buf.get_i16();
198 let params = (0..len)
199 .map(|_| {
200 let val_len = buf.get_i32();
201 if val_len == -1 {
202 None
203 } else {
204 Some(buf.copy_to_bytes(val_len as usize))
205 }
206 })
207 .collect();
208
209 let len = buf.get_i16();
210 let result_format_codes = (0..len).map(|_| buf.get_i16()).collect();
211
212 Ok(FeMessage::Bind(FeBindMessage {
213 param_format_codes,
214 result_format_codes,
215 params,
216 portal_name,
217 statement_name,
218 }))
219 }
220}
221
222impl FeExecuteMessage {
223 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
224 let portal_name = read_null_terminated(&mut buf)?;
225 let max_rows = buf.get_i32();
226
227 Ok(FeMessage::Execute(FeExecuteMessage {
228 portal_name,
229 max_rows,
230 }))
231 }
232}
233
234impl FeParseMessage {
235 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
236 let statement_name = read_null_terminated(&mut buf)?;
237 let sql_bytes = read_null_terminated(&mut buf)?;
238 let nparams = buf.get_i16();
239
240 let type_ids: Vec<i32> = (0..nparams).map(|_| buf.get_i32()).collect();
241
242 Ok(FeMessage::Parse(FeParseMessage {
243 statement_name,
244 sql_bytes,
245 type_ids,
246 }))
247 }
248}
249
250impl FePasswordMessage {
251 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
252 let password = read_null_terminated(&mut buf)?;
253
254 Ok(FeMessage::Password(FePasswordMessage { password }))
255 }
256}
257
258impl FeQueryMessage {
259 pub fn get_sql(&self) -> Result<&str> {
260 match CStr::from_bytes_with_nul(&self.sql_bytes) {
261 Ok(cstr) => cstr.to_str().map_err(|err| {
262 Error::new(
263 ErrorKind::InvalidInput,
264 anyhow!(err).context("Invalid UTF-8 sequence"),
265 )
266 }),
267 Err(err) => Err(Error::new(
268 ErrorKind::InvalidInput,
269 anyhow!(err).context("Input end error"),
270 )),
271 }
272 }
273}
274
275impl FeCloseMessage {
276 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
277 let kind = buf.get_u8();
278 let name = read_null_terminated(&mut buf)?;
279 Ok(FeMessage::Close(FeCloseMessage { kind, name }))
280 }
281}
282
/// Header of a regular (tagged) frontend message, as read by
/// `FeMessage::read_header`.
#[derive(Clone)]
pub struct FeMessageHeader {
    /// The one-byte message type tag.
    pub tag: u8,
    /// Number of payload bytes that follow the header: the length word read
    /// from the stream minus the 4 bytes of the length word itself.
    pub payload_len: i32,
}
288
289impl FeMessage {
290 pub async fn read_header(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessageHeader> {
292 let tag = stream.read_u8().await?;
293 let len = stream.read_i32().await?;
294
295 let payload_len = len - 4;
296 Ok(FeMessageHeader { tag, payload_len })
297 }
298
299 pub async fn read_body(
301 stream: &mut (impl AsyncRead + Unpin),
302 header: FeMessageHeader,
303 ) -> Result<FeMessage> {
304 let FeMessageHeader { tag, payload_len } = header;
305 let mut payload: Vec<u8> = vec![0; payload_len as usize];
306 if payload_len > 0 {
307 stream.read_exact(&mut payload).await?;
308 }
309 let sql_bytes = Bytes::from(payload);
310 match tag {
311 b'Q' => Ok(FeMessage::Query(FeQueryMessage { sql_bytes })),
312 b'P' => FeParseMessage::parse(sql_bytes),
313 b'D' => FeDescribeMessage::parse(sql_bytes),
314 b'B' => FeBindMessage::parse(sql_bytes),
315 b'E' => FeExecuteMessage::parse(sql_bytes),
316 b'S' => Ok(FeMessage::Sync),
317 b'X' => Ok(FeMessage::Terminate),
318 b'C' => FeCloseMessage::parse(sql_bytes),
319 b'p' => FePasswordMessage::parse(sql_bytes),
320 b'H' => Ok(FeMessage::Flush),
321 _ => Err(std::io::Error::new(
322 ErrorKind::InvalidInput,
323 format!("Unsupported tag of regular message: {}", tag),
324 )),
325 }
326 }
327
328 pub async fn skip_body(
329 stream: &mut (impl AsyncRead + Unpin),
330 header: FeMessageHeader,
331 ) -> Result<()> {
332 let FeMessageHeader {
333 tag: _,
334 payload_len,
335 } = header;
336
337 if payload_len > 0 {
338 const BUF_SIZE: usize = 1024;
340 let mut buf: Vec<u8> = vec![0; BUF_SIZE];
341 for _ in 0..(payload_len as usize) / BUF_SIZE {
342 stream.read_exact(&mut buf).await?;
343 }
344 let remain = (payload_len as usize) % BUF_SIZE;
345 if remain > 0 {
346 buf.truncate(remain);
347 stream.read_exact(&mut buf).await?;
348 }
349 }
350 Ok(())
351 }
352}
353
354impl FeStartupMessage {
355 pub async fn read(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessage> {
357 let mut stream = AsyncPeekable::new(stream);
358
359 if let Err(err) = stream.peek_exact(&mut [0; 1]).await {
360 if err.kind() == ErrorKind::UnexpectedEof {
362 return Ok(FeMessage::HealthCheck);
363 } else {
364 return Err(err);
365 }
366 }
367
368 let len = stream.read_i32().await?;
369 let protocol_num = stream.read_i32().await?;
370 let payload_len = (len - 8) as usize;
371 if payload_len >= isize::MAX as usize {
372 return Err(std::io::Error::new(
373 ErrorKind::InvalidInput,
374 format!("Payload length has exceed usize::MAX {:?}", payload_len),
375 ));
376 }
377 let mut payload = vec![0; payload_len];
378 if payload_len > 0 {
379 stream.read_exact(&mut payload).await?;
380 }
381 match protocol_num {
382 196608 => Ok(FeMessage::Startup(FeStartupMessage::build_with_payload(
384 &payload,
385 )?)),
386 80877104 => Ok(FeMessage::Gss),
387 80877103 => Ok(FeMessage::Ssl),
388 80877102 => FeCancelMessage::parse(Bytes::from(payload)),
390 _ => Err(std::io::Error::new(
391 ErrorKind::InvalidInput,
392 format!(
393 "Unsupported protocol number in start up msg {:?}",
394 protocol_num
395 ),
396 )),
397 }
398 }
399}
400
401fn read_null_terminated(buf: &mut Bytes) -> Result<Bytes> {
403 let mut result = BytesMut::new();
404
405 loop {
406 if !buf.has_remaining() {
407 panic!("no null-terminator in string");
408 }
409
410 let byte = buf.get_u8();
411
412 if byte == 0 {
413 break;
414 }
415 result.put_u8(byte);
416 }
417 Ok(result.freeze())
418}
419
/// A message sent by the backend (server) to the client.
///
/// Serialized by `BeMessage::write`; the tag bytes quoted below are the ones
/// that function emits.
#[derive(Debug, Clone, Copy)]
pub enum BeMessage<'a> {
    /// Tag `b'R'`, length 8, authentication code 0.
    AuthenticationOk,
    /// Tag `b'R'`, length 8, authentication code 3.
    AuthenticationCleartextPassword,
    /// Tag `b'R'`, length 12, authentication code 5, followed by the 4-byte salt.
    AuthenticationMd5Password(&'a [u8; 4]),
    /// Tag `b'C'`; the body is the command tag string.
    CommandComplete(BeCommandCompleteMessage),
    /// Tag `b'N'`; the notice text.
    NoticeResponse(&'a str),
    /// A bare `b'S'` byte with no length word.
    EncryptionResponseSsl,
    /// A bare `b'G'` byte with no length word.
    EncryptionResponseGss,
    /// A bare `b'N'` byte with no length word.
    EncryptionResponseNo,
    /// Tag `b'I'` with an empty body.
    EmptyQueryResponse,
    /// Tag `b'1'` with an empty body.
    ParseComplete,
    /// Tag `b'2'` with an empty body.
    BindComplete,
    /// Tag `b's'` with an empty body.
    PortalSuspended,
    /// Tag `b't'`; an `i16` count followed by one `i32` per entry.
    ParameterDescription(&'a [i32]),
    /// Tag `b'n'` with an empty body.
    NoData,
    /// Tag `b'D'`; one row of column values.
    DataRow(&'a Row),
    /// Tag `b'S'`; a NUL-terminated name followed by a NUL-terminated value.
    ParameterStatus(BeParameterStatusMessage<'a>),
    /// Tag `b'Z'`, length 5, followed by one transaction-status byte.
    ReadyForQuery(TransactionStatus),
    /// Tag `b'T'`; an `i16` count followed by one descriptor per field.
    RowDescription(&'a [PgFieldDescriptor]),
    /// Tag `b'E'`; built from `error` via `ErrorOrNoticeMessage::error`, or
    /// `error_with_severity` when `severity` is set.
    ErrorResponse {
        error: &'a (dyn std::error::Error + Send + Sync + 'static),
        // NOTE(review): forwarded untouched to `ErrorOrNoticeMessage`; its
        // exact effect on formatting is defined there.
        pretty: bool,
        severity: Option<Severity>,
    },
    /// Tag `b'3'` with an empty body.
    CloseComplete,

    /// Tag `b'H'`; the payload is the number of columns, all announced in text format.
    CopyOutResponse(usize),
    /// Tag `b'd'`; one row rendered as tab-separated, newline-terminated text.
    CopyData(&'a Row),
    /// Tag `b'c'` with an empty body.
    CopyDone,

    /// Tag `b'K'`; the tuple is `(process_id, secret_key)`.
    BackendKeyData((i32, i32)),
}
460
/// The run-time parameters that can be reported through
/// `BeMessage::ParameterStatus`. Each variant carries the value; the parameter
/// name put on the wire is noted per variant.
#[derive(Debug, Copy, Clone)]
pub enum BeParameterStatusMessage<'a> {
    /// Sent as `client_encoding`.
    ClientEncoding(&'a str),
    /// Sent as `standard_conforming_strings`.
    StandardConformingString(&'a str),
    /// Sent as `server_version`.
    ServerVersion(&'a str),
    /// Sent as `application_name`.
    ApplicationName(&'a str),
    /// Sent as `TimeZone`.
    TimeZone(&'a str),
}
469
/// Payload of `BeMessage::CommandComplete`.
#[derive(Debug, Copy, Clone)]
pub struct BeCommandCompleteMessage {
    /// Statement kind; its string form starts the command tag.
    pub stmt_type: StatementType,
    /// Row count, appended to the command tag when `stmt_type.is_command()` is true.
    pub rows_cnt: i32,
}
475
/// Transaction state reported in `BeMessage::ReadyForQuery`.
#[derive(Debug, Clone, Copy)]
pub enum TransactionStatus {
    /// Written as `b'I'`.
    Idle,
    /// Written as `b'T'`.
    InTransaction,
    /// Written as `b'E'`.
    InFailedTransaction,
}
482
483impl BeMessage<'_> {
484 pub fn write(buf: &mut BytesMut, message: BeMessage<'_>) -> Result<()> {
486 match message {
487 BeMessage::AuthenticationOk => {
492 buf.put_u8(b'R');
493 buf.put_i32(8);
494 buf.put_i32(0);
495 }
496
497 BeMessage::AuthenticationCleartextPassword => {
502 buf.put_u8(b'R');
503 buf.put_i32(8);
504 buf.put_i32(3);
505 }
506
507 BeMessage::AuthenticationMd5Password(salt) => {
515 buf.put_u8(b'R');
516 buf.put_i32(12);
517 buf.put_i32(5);
518 buf.put_slice(&salt[..]);
519 }
520
521 BeMessage::ParameterStatus(param) => {
526 use BeParameterStatusMessage::*;
527 let [name, value] = match param {
528 ClientEncoding(val) => [b"client_encoding", val.as_bytes()],
529 StandardConformingString(val) => {
530 [b"standard_conforming_strings", val.as_bytes()]
531 }
532 ServerVersion(val) => [b"server_version", val.as_bytes()],
533 ApplicationName(val) => [b"application_name", val.as_bytes()],
534 TimeZone(val) => [b"TimeZone", val.as_bytes()],
536 };
537
538 let iov = &mut [name, b"\0", value, b"\0"].map(IoSlice::new);
540 let mut buffer = vec![];
541 let cnt = buffer.write_vectored(iov).unwrap();
542
543 buf.put_u8(b'S');
544 write_body(buf, |stream| {
545 stream.put_slice(&buffer[..cnt]);
546 Ok(())
547 })
548 .unwrap();
549 }
550
551 BeMessage::CommandComplete(cmd) => {
556 let rows_cnt = cmd.rows_cnt;
557 let mut stmt_type = cmd.stmt_type;
558 let mut tag = "".to_owned();
559 stmt_type = match stmt_type {
560 StatementType::INSERT_RETURNING => StatementType::INSERT,
561 StatementType::DELETE_RETURNING => StatementType::DELETE,
562 StatementType::UPDATE_RETURNING => StatementType::UPDATE,
563 s => s,
564 };
565 tag.push_str(&stmt_type.to_string());
566 if stmt_type == StatementType::INSERT {
567 tag.push_str(" 0");
568 }
569 if stmt_type.is_command() {
570 tag.push(' ');
571 tag.push_str(&rows_cnt.to_string());
572 }
573 buf.put_u8(b'C');
574 write_body(buf, |buf| {
575 write_cstr(buf, tag.as_bytes())?;
576 Ok(())
577 })?;
578 }
579
580 BeMessage::NoticeResponse(notice) => {
587 buf.put_u8(b'N');
588 write_err_or_notice(buf, &ErrorOrNoticeMessage::notice(notice))?;
589 }
590
591 BeMessage::DataRow(vals) => {
600 buf.put_u8(b'D');
601 write_body(buf, |buf| {
602 buf.put_u16(vals.len() as u16); for val_opt in vals.values() {
604 if let Some(val) = val_opt {
605 buf.put_u32(val.len() as u32);
606 buf.put_slice(val);
607 } else {
608 buf.put_i32(-1);
609 }
610 }
611 Ok(())
612 })
613 .unwrap();
614 }
615
616 BeMessage::RowDescription(row_descs) => {
630 buf.put_u8(b'T');
631 write_body(buf, |buf| {
632 buf.put_i16(row_descs.len() as i16); for pg_field in row_descs {
634 write_cstr(buf, pg_field.get_name().as_bytes())?;
635 buf.put_i32(pg_field.get_table_oid()); buf.put_i16(pg_field.get_col_attr_num()); buf.put_i32(pg_field.get_type_oid());
638 buf.put_i16(pg_field.get_type_len());
639 buf.put_i32(pg_field.get_type_modifier()); buf.put_i16(pg_field.get_format_code()); }
642 Ok(())
643 })?;
644 }
645 BeMessage::ReadyForQuery(txn_status) => {
650 buf.put_u8(b'Z');
651 buf.put_i32(5);
652 buf.put_u8(match txn_status {
654 TransactionStatus::Idle => b'I',
655 TransactionStatus::InTransaction => b'T',
656 TransactionStatus::InFailedTransaction => b'E',
657 });
658 }
659
660 BeMessage::ParseComplete => {
661 buf.put_u8(b'1');
662 write_body(buf, |_| Ok(()))?;
663 }
664
665 BeMessage::BindComplete => {
666 buf.put_u8(b'2');
667 write_body(buf, |_| Ok(()))?;
668 }
669
670 BeMessage::CloseComplete => {
671 buf.put_u8(b'3');
672 write_body(buf, |_| Ok(()))?;
673 }
674
675 BeMessage::PortalSuspended => {
676 buf.put_u8(b's');
677 write_body(buf, |_| Ok(()))?;
678 }
679 BeMessage::ParameterDescription(para_descs) => {
684 buf.put_u8(b't');
685 write_body(buf, |buf| {
686 buf.put_i16(para_descs.len() as i16);
687 for oid in para_descs {
688 buf.put_i32(*oid);
689 }
690 Ok(())
691 })?;
692 }
693
694 BeMessage::NoData => {
695 buf.put_u8(b'n');
696 write_body(buf, |_| Ok(())).unwrap();
697 }
698
699 BeMessage::EncryptionResponseSsl => {
700 buf.put_u8(b'S');
701 }
702
703 BeMessage::EncryptionResponseGss => {
704 buf.put_u8(b'G');
705 }
706
707 BeMessage::EncryptionResponseNo => {
708 buf.put_u8(b'N');
709 }
710
711 BeMessage::EmptyQueryResponse => {
716 buf.put_u8(b'I');
717 buf.put_i32(4);
718 }
719
720 BeMessage::ErrorResponse {
721 error,
722 pretty,
723 severity,
724 } => {
725 buf.put_u8(b'E');
727 let error_message = match severity {
729 Some(severity) => {
730 ErrorOrNoticeMessage::error_with_severity(error, pretty, severity)
731 }
732 None => ErrorOrNoticeMessage::error(error, pretty),
733 };
734 write_err_or_notice(buf, &error_message)?;
735 }
736
737 BeMessage::BackendKeyData((process_id, secret_key)) => {
738 buf.put_u8(b'K');
739 write_body(buf, |buf| {
740 buf.put_i32(process_id);
741 buf.put_i32(secret_key);
742 Ok(())
743 })?;
744 }
745 BeMessage::CopyOutResponse(col_num) => {
746 buf.put_u8(b'H');
747 write_body(buf, |buf| {
748 buf.put_i8(Format::Text.to_i8());
749 buf.put_i16(col_num as _);
750 for _ in 0..col_num {
751 buf.put_i16(Format::Text.to_i8() as _);
752 }
753 Ok(())
754 })?;
755 }
756 BeMessage::CopyData(row) => {
757 buf.put_u8(b'd');
758 write_body(buf, |buf| {
760 fn write_str_bytes(
761 buf: &mut BytesMut,
762 str_bytes: &Option<Bytes>,
763 ) -> Result<()> {
764 let Some(str_bytes) = str_bytes else {
765 return Ok(());
766 };
767 let s = String::from_utf8_lossy(str_bytes);
768 for c in s.as_str().chars() {
769 match c {
772 '\t' => {
773 buf.put_slice(b"\\t");
774 }
775 '\n' => {
776 buf.put_slice(b"\\n");
777 }
778 '\r' => {
779 buf.put_slice(b"\\r");
780 }
781 '\\' => {
782 buf.put_slice(b"\\\\");
783 }
784 _ => {
785 std::fmt::Write::write_char(buf, c).map_err(|_| {
786 Error::other(anyhow!("failed to write_char [{c}]"))
787 })?;
788 }
789 }
790 }
791 Ok(())
792 }
793 match row.values() {
794 [] => {}
795 [first, rest @ ..] => {
796 write_str_bytes(buf, first)?;
797
798 for rest in rest {
799 buf.put_u8(b'\t');
800 write_str_bytes(buf, rest)?;
801 }
802 }
803 }
804 buf.put_u8(b'\n');
805 Ok(())
806 })?;
807 }
808 BeMessage::CopyDone => {
809 buf.put_u8(b'c');
810 write_body(buf, |_| Ok(()))?;
811 }
812 }
813
814 Ok(())
815 }
816}
817
/// Checked conversion from `usize` into a narrower integer used on the wire.
trait FromUsize: Sized {
    /// Converts `x`, failing with `ErrorKind::InvalidInput` when it does not
    /// fit into `Self`.
    fn from_usize(x: usize) -> Result<Self>;
}

/// Implements [`FromUsize`] for the given integer type.
macro_rules! from_usize {
    ($t:ty) => {
        impl FromUsize for $t {
            #[inline]
            fn from_usize(x: usize) -> Result<$t> {
                match x {
                    fits if fits <= <$t>::MAX as usize => Ok(fits as $t),
                    _ => Err(Error::new(
                        ErrorKind::InvalidInput,
                        "value too large to transmit",
                    )),
                }
            }
        }
    };
}

from_usize!(i32);
839
840fn write_body<F>(buf: &mut BytesMut, f: F) -> Result<()>
844where
845 F: FnOnce(&mut BytesMut) -> Result<()>,
846{
847 let base = buf.len();
848 buf.extend_from_slice(&[0; 4]);
849
850 f(buf)?;
851
852 let size = i32::from_usize(buf.len() - base)?;
853 BigEndian::write_i32(&mut buf[base..], size);
854 Ok(())
855}
856
857fn write_cstr(buf: &mut BytesMut, s: &[u8]) -> Result<()> {
859 if s.contains(&0) {
860 return Err(Error::new(
861 ErrorKind::InvalidInput,
862 "string contains embedded null",
863 ));
864 }
865 buf.put_slice(s);
866 buf.put_u8(0);
867 Ok(())
868}
869
870fn write_err_or_notice(buf: &mut BytesMut, msg: &ErrorOrNoticeMessage<'_>) -> Result<()> {
872 write_body(buf, |buf| {
873 buf.put_u8(b'S'); write_cstr(buf, msg.severity.as_str().as_bytes())?;
875
876 buf.put_u8(b'C'); write_cstr(buf, msg.error_code.sqlstate().as_bytes())?;
878
879 buf.put_u8(b'M'); write_cstr(buf, msg.message.as_bytes())?;
881
882 buf.put_u8(0); Ok(())
884 })
885}
886
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use crate::pg_message::FeQueryMessage;

    /// `get_sql` must reject malformed payloads and decode a well-formed one.
    #[test]
    fn test_get_sql() {
        // NUL-terminated, but the bytes before the terminator are not valid UTF-8.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![255, 255, 255, 255, 255, 255, 0]),
        };
        assert!(
            fe.get_sql().is_err(),
            "a payload with invalid UTF-8 must be rejected"
        );

        // No NUL terminator at all.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![1, 2, 3, 4, 5, 6, 7, 8]),
        };
        assert!(
            fe.get_sql().is_err(),
            "a payload without a NUL terminator must be rejected"
        );

        // Well-formed payload: the terminator is stripped from the result.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from_static(b"select 1\0"),
        };
        assert_eq!(fe.get_sql().unwrap(), "select 1");
    }
}