1use std::collections::HashMap;
16use std::ffi::CStr;
17use std::io::{Error, ErrorKind, IoSlice, Result, Write};
18
19use anyhow::anyhow;
20use byteorder::{BigEndian, ByteOrder};
21use bytes::{Buf, BufMut, Bytes, BytesMut};
23use peekable::tokio::AsyncPeekable;
24use tokio::io::{AsyncRead, AsyncReadExt};
25
26use crate::error_or_notice::ErrorOrNoticeMessage;
27use crate::pg_field_descriptor::PgFieldDescriptor;
28use crate::pg_response::StatementType;
29use crate::types::{Format, Row};
30
/// A message received from a PostgreSQL frontend (client).
///
/// The untagged packets sent before the regular protocol starts are produced by
/// [`FeStartupMessage::read`]; tagged messages are produced by [`FeMessage::read_body`].
#[derive(Debug)]
pub enum FeMessage {
    /// SSL negotiation request (protocol number 80877103 in the first packet).
    Ssl,
    /// GSS encryption request (protocol number 80877104 in the first packet).
    Gss,
    /// Startup message (protocol number 196608, i.e. version 3.0) with its parameters.
    Startup(FeStartupMessage),
    /// Simple query (tag `Q`).
    Query(FeQueryMessage),
    /// Parse (tag `P`).
    Parse(FeParseMessage),
    /// Password message (tag `p`).
    Password(FePasswordMessage),
    /// Describe (tag `D`).
    Describe(FeDescribeMessage),
    /// Bind (tag `B`).
    Bind(FeBindMessage),
    /// Execute (tag `E`).
    Execute(FeExecuteMessage),
    /// Close (tag `C`).
    Close(FeCloseMessage),
    /// Sync (tag `S`).
    Sync,
    /// Cancel request (protocol number 80877102 in the first packet).
    CancelQuery(FeCancelMessage),
    /// Terminate (tag `X`).
    Terminate,
    /// Flush (tag `H`).
    Flush,
    /// The peer closed the connection before sending any byte of its first packet;
    /// reported as a health-check probe rather than as an error.
    HealthCheck,
    /// Not constructed in this file. NOTE(review): presumably produced by the connection
    /// handler when it refuses to process a message — confirm against the caller.
    ServerThrottle(ServerThrottleReason),
}
53
/// Reason carried by [`FeMessage::ServerThrottle`].
#[derive(Debug)]
pub enum ServerThrottleReason {
    /// Presumably: the incoming message exceeded a size limit (TODO confirm at the caller).
    TooLargeMessage,
    /// Presumably: server memory usage is too high to accept the message (TODO confirm).
    TooManyMemoryUsage,
}
59
/// Parameters sent in a startup message.
#[derive(Debug)]
pub struct FeStartupMessage {
    /// Parameter name -> value pairs parsed from the startup packet payload.
    pub config: HashMap<String, String>,
}
64
65impl FeStartupMessage {
66 pub fn build_with_payload(payload: &[u8]) -> Result<Self> {
67 let config = match std::str::from_utf8(payload) {
68 Ok(v) => Ok(v.trim_end_matches('\0')),
69 Err(err) => Err(Error::new(
70 ErrorKind::InvalidInput,
71 anyhow!(err).context("Input end error"),
72 )),
73 }?;
74 let mut map = HashMap::new();
75 let config: Vec<&str> = config.split('\0').collect();
76 if config.len() % 2 == 1 {
77 return Err(Error::new(
78 ErrorKind::InvalidInput,
79 "Invalid input config: odd number of config pairs",
80 ));
81 }
82 config.chunks(2).for_each(|chunk| {
83 map.insert(chunk[0].to_owned(), chunk[1].to_owned());
84 });
85 Ok(FeStartupMessage { config: map })
86 }
87}
88
/// A simple-query (tag `Q`) message.
#[derive(Debug)]
pub struct FeQueryMessage {
    /// Raw message body: the query text *including* its trailing NUL terminator
    /// (see [`FeQueryMessage::get_sql`]).
    pub sql_bytes: Bytes,
}

/// A Bind (tag `B`) message.
#[derive(Debug)]
pub struct FeBindMessage {
    /// Parameter format codes, as sent by the client.
    pub param_format_codes: Vec<i16>,
    /// Result-column format codes, as sent by the client.
    pub result_format_codes: Vec<i16>,

    /// Parameter values; `None` is SQL NULL (wire length `-1`).
    pub params: Vec<Option<Bytes>>,
    /// Destination portal name, without its NUL terminator.
    pub portal_name: Bytes,
    /// Source prepared-statement name, without its NUL terminator.
    pub statement_name: Bytes,
}

/// An Execute (tag `E`) message.
#[derive(Debug)]
pub struct FeExecuteMessage {
    /// Portal to execute, without its NUL terminator.
    pub portal_name: Bytes,
    /// Row limit as sent by the client; not interpreted here.
    pub max_rows: i32,
}

/// A Parse (tag `P`) message.
#[derive(Debug)]
pub struct FeParseMessage {
    /// Name of the prepared statement to create, without its NUL terminator.
    pub statement_name: Bytes,
    /// Query text *without* its NUL terminator (unlike `FeQueryMessage::sql_bytes`).
    pub sql_bytes: Bytes,
    /// Parameter type OIDs pre-specified by the client.
    pub type_ids: Vec<i32>,
}

/// A password (tag `p`) message.
#[derive(Debug)]
pub struct FePasswordMessage {
    /// Password payload, without its NUL terminator.
    pub password: Bytes,
}

/// A Describe (tag `D`) message.
#[derive(Debug)]
pub struct FeDescribeMessage {
    /// Kind byte of the described object; not validated here (per the PostgreSQL
    /// protocol, `b'S'` is a prepared statement and `b'P'` a portal).
    pub kind: u8,
    /// Name of the statement or portal, without its NUL terminator.
    pub name: Bytes,
}

/// A Close (tag `C`) message.
#[derive(Debug)]
pub struct FeCloseMessage {
    /// Kind byte of the object to close; not validated here.
    pub kind: u8,
    /// Name of the statement or portal, without its NUL terminator.
    pub name: Bytes,
}

/// A cancel request, sent in place of a startup packet.
#[derive(Debug)]
pub struct FeCancelMessage {
    /// Process ID of the backend whose query should be cancelled.
    pub target_process_id: i32,
    /// Secret key; presumably compared with the one sent in `BackendKeyData`
    /// (verify at the caller).
    pub target_secret_key: i32,
}
141
142impl FeCancelMessage {
143 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
144 let target_process_id = buf.get_i32();
145 let target_secret_key = buf.get_i32();
146 Ok(FeMessage::CancelQuery(Self {
147 target_process_id,
148 target_secret_key,
149 }))
150 }
151}
152
153impl FeDescribeMessage {
154 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
155 let kind = buf.get_u8();
156 let name = read_null_terminated(&mut buf)?;
157
158 Ok(FeMessage::Describe(FeDescribeMessage { kind, name }))
159 }
160}
161
162impl FeBindMessage {
163 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
181 let portal_name = read_null_terminated(&mut buf)?;
182 let statement_name = read_null_terminated(&mut buf)?;
183
184 let len = buf.get_i16();
185 let param_format_codes = (0..len).map(|_| buf.get_i16()).collect();
186
187 let len = buf.get_i16();
189 let params = (0..len)
190 .map(|_| {
191 let val_len = buf.get_i32();
192 if val_len == -1 {
193 None
194 } else {
195 Some(buf.copy_to_bytes(val_len as usize))
196 }
197 })
198 .collect();
199
200 let len = buf.get_i16();
201 let result_format_codes = (0..len).map(|_| buf.get_i16()).collect();
202
203 Ok(FeMessage::Bind(FeBindMessage {
204 param_format_codes,
205 result_format_codes,
206 params,
207 portal_name,
208 statement_name,
209 }))
210 }
211}
212
213impl FeExecuteMessage {
214 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
215 let portal_name = read_null_terminated(&mut buf)?;
216 let max_rows = buf.get_i32();
217
218 Ok(FeMessage::Execute(FeExecuteMessage {
219 portal_name,
220 max_rows,
221 }))
222 }
223}
224
225impl FeParseMessage {
226 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
227 let statement_name = read_null_terminated(&mut buf)?;
228 let sql_bytes = read_null_terminated(&mut buf)?;
229 let nparams = buf.get_i16();
230
231 let type_ids: Vec<i32> = (0..nparams).map(|_| buf.get_i32()).collect();
232
233 Ok(FeMessage::Parse(FeParseMessage {
234 statement_name,
235 sql_bytes,
236 type_ids,
237 }))
238 }
239}
240
241impl FePasswordMessage {
242 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
243 let password = read_null_terminated(&mut buf)?;
244
245 Ok(FeMessage::Password(FePasswordMessage { password }))
246 }
247}
248
249impl FeQueryMessage {
250 pub fn get_sql(&self) -> Result<&str> {
251 match CStr::from_bytes_with_nul(&self.sql_bytes) {
252 Ok(cstr) => cstr.to_str().map_err(|err| {
253 Error::new(
254 ErrorKind::InvalidInput,
255 anyhow!(err).context("Invalid UTF-8 sequence"),
256 )
257 }),
258 Err(err) => Err(Error::new(
259 ErrorKind::InvalidInput,
260 anyhow!(err).context("Input end error"),
261 )),
262 }
263 }
264}
265
266impl FeCloseMessage {
267 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
268 let kind = buf.get_u8();
269 let name = read_null_terminated(&mut buf)?;
270 Ok(FeMessage::Close(FeCloseMessage { kind, name }))
271 }
272}
273
/// Header of a regular (tagged) frontend message, as returned by
/// [`FeMessage::read_header`].
#[derive(Clone)]
pub struct FeMessageHeader {
    /// Message type byte, e.g. `b'Q'` for a simple query.
    pub tag: u8,
    /// Number of body bytes after the header: the wire length minus the 4 bytes of the
    /// length field itself.
    pub payload_len: i32,
}
279
280impl FeMessage {
281 pub async fn read_header(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessageHeader> {
283 let tag = stream.read_u8().await?;
284 let len = stream.read_i32().await?;
285
286 let payload_len = len - 4;
287 Ok(FeMessageHeader { tag, payload_len })
288 }
289
290 pub async fn read_body(
292 stream: &mut (impl AsyncRead + Unpin),
293 header: FeMessageHeader,
294 ) -> Result<FeMessage> {
295 let FeMessageHeader { tag, payload_len } = header;
296 let mut payload: Vec<u8> = vec![0; payload_len as usize];
297 if payload_len > 0 {
298 stream.read_exact(&mut payload).await?;
299 }
300 let sql_bytes = Bytes::from(payload);
301 match tag {
302 b'Q' => Ok(FeMessage::Query(FeQueryMessage { sql_bytes })),
303 b'P' => FeParseMessage::parse(sql_bytes),
304 b'D' => FeDescribeMessage::parse(sql_bytes),
305 b'B' => FeBindMessage::parse(sql_bytes),
306 b'E' => FeExecuteMessage::parse(sql_bytes),
307 b'S' => Ok(FeMessage::Sync),
308 b'X' => Ok(FeMessage::Terminate),
309 b'C' => FeCloseMessage::parse(sql_bytes),
310 b'p' => FePasswordMessage::parse(sql_bytes),
311 b'H' => Ok(FeMessage::Flush),
312 _ => Err(std::io::Error::new(
313 ErrorKind::InvalidInput,
314 format!("Unsupported tag of regular message: {}", tag),
315 )),
316 }
317 }
318
319 pub async fn skip_body(
320 stream: &mut (impl AsyncRead + Unpin),
321 header: FeMessageHeader,
322 ) -> Result<()> {
323 let FeMessageHeader {
324 tag: _,
325 payload_len,
326 } = header;
327
328 if payload_len > 0 {
329 const BUF_SIZE: usize = 1024;
331 let mut buf: Vec<u8> = vec![0; BUF_SIZE];
332 for _ in 0..(payload_len as usize) / BUF_SIZE {
333 stream.read_exact(&mut buf).await?;
334 }
335 let remain = (payload_len as usize) % BUF_SIZE;
336 if remain > 0 {
337 buf.truncate(remain);
338 stream.read_exact(&mut buf).await?;
339 }
340 }
341 Ok(())
342 }
343}
344
345impl FeStartupMessage {
346 pub async fn read(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessage> {
348 let mut stream = AsyncPeekable::new(stream);
349
350 if let Err(err) = stream.peek_exact(&mut [0; 1]).await {
351 if err.kind() == ErrorKind::UnexpectedEof {
353 return Ok(FeMessage::HealthCheck);
354 } else {
355 return Err(err);
356 }
357 }
358
359 let len = stream.read_i32().await?;
360 let protocol_num = stream.read_i32().await?;
361 let payload_len = (len - 8) as usize;
362 if payload_len >= isize::MAX as usize {
363 return Err(std::io::Error::new(
364 ErrorKind::InvalidInput,
365 format!("Payload length has exceed usize::MAX {:?}", payload_len),
366 ));
367 }
368 let mut payload = vec![0; payload_len];
369 if payload_len > 0 {
370 stream.read_exact(&mut payload).await?;
371 }
372 match protocol_num {
373 196608 => Ok(FeMessage::Startup(FeStartupMessage::build_with_payload(
375 &payload,
376 )?)),
377 80877104 => Ok(FeMessage::Gss),
378 80877103 => Ok(FeMessage::Ssl),
379 80877102 => FeCancelMessage::parse(Bytes::from(payload)),
381 _ => Err(std::io::Error::new(
382 ErrorKind::InvalidInput,
383 format!(
384 "Unsupported protocol number in start up msg {:?}",
385 protocol_num
386 ),
387 )),
388 }
389 }
390}
391
392fn read_null_terminated(buf: &mut Bytes) -> Result<Bytes> {
394 let mut result = BytesMut::new();
395
396 loop {
397 if !buf.has_remaining() {
398 panic!("no null-terminator in string");
399 }
400
401 let byte = buf.get_u8();
402
403 if byte == 0 {
404 break;
405 }
406 result.put_u8(byte);
407 }
408 Ok(result.freeze())
409}
410
/// A message sent from the server to a PostgreSQL frontend, serialized by
/// [`BeMessage::write`]. Payloads are borrowed for the lifetime `'a`.
#[derive(Debug, Clone, Copy)]
pub enum BeMessage<'a> {
    /// `R` with authentication code 0.
    AuthenticationOk,
    /// `R` with authentication code 3 (request a cleartext password).
    AuthenticationCleartextPassword,
    /// `R` with authentication code 5, followed by the given 4-byte salt.
    AuthenticationMd5Password(&'a [u8; 4]),
    /// `C` with a command tag built from the statement type and row count.
    CommandComplete(BeCommandCompleteMessage),
    /// `N` notice carrying the given message text.
    NoticeResponse(&'a str),
    /// Single byte `S`, written without a length (reply to an SSL request).
    EncryptionResponseSsl,
    /// Single byte `G`, written without a length (reply to a GSS request).
    EncryptionResponseGss,
    /// Single byte `N`, written without a length (encryption declined).
    EncryptionResponseNo,
    /// `I` with an empty body.
    EmptyQueryResponse,
    /// `1` with an empty body.
    ParseComplete,
    /// `2` with an empty body.
    BindComplete,
    /// `s` with an empty body.
    PortalSuspended,
    /// `t` listing the given parameter type OIDs.
    ParameterDescription(&'a [i32]),
    /// `n` with an empty body.
    NoData,
    /// `D` carrying one row's column values (`None` written as length -1).
    DataRow(&'a Row),
    /// `S` reporting one runtime parameter as `name\0value\0`.
    ParameterStatus(BeParameterStatusMessage<'a>),
    /// `Z` with the given transaction status byte.
    ReadyForQuery(TransactionStatus),
    /// `T` describing each result column.
    RowDescription(&'a [PgFieldDescriptor]),
    /// `E` with severity, SQLSTATE and message fields derived from the error.
    ErrorResponse(&'a (dyn std::error::Error + Send + Sync + 'static)),
    /// `3` with an empty body.
    CloseComplete,

    /// `H` announcing text-format COPY output with the given number of columns.
    CopyOutResponse(usize),
    /// `d` carrying one row as a tab-separated, newline-terminated text line.
    CopyData(&'a Row),
    /// `c` with an empty body.
    CopyDone,

    /// `K` carrying `(process_id, secret_key)`.
    BackendKeyData((i32, i32)),
}
447
/// A runtime parameter reported via `ParameterStatus`; each variant fixes the parameter
/// name written on the wire, and the payload is its value.
#[derive(Debug, Copy, Clone)]
pub enum BeParameterStatusMessage<'a> {
    /// Written as `client_encoding`.
    ClientEncoding(&'a str),
    /// Written as `standard_conforming_strings`.
    StandardConformingString(&'a str),
    /// Written as `server_version`.
    ServerVersion(&'a str),
    /// Written as `application_name`.
    ApplicationName(&'a str),
    /// Written as `TimeZone`.
    TimeZone(&'a str),
}

/// Inputs for the `CommandComplete` tag.
#[derive(Debug, Copy, Clone)]
pub struct BeCommandCompleteMessage {
    /// Statement kind; `*_RETURNING` variants are reported as their plain counterparts.
    pub stmt_type: StatementType,
    /// Row count appended to the tag when `stmt_type.is_command()` is true.
    pub rows_cnt: i32,
}

/// Transaction state reported in `ReadyForQuery`.
#[derive(Debug, Clone, Copy)]
pub enum TransactionStatus {
    /// Written as `I`.
    Idle,
    /// Written as `T`.
    InTransaction,
    /// Written as `E`.
    InFailedTransaction,
}
469
470impl BeMessage<'_> {
471 pub fn write(buf: &mut BytesMut, message: BeMessage<'_>) -> Result<()> {
473 match message {
474 BeMessage::AuthenticationOk => {
479 buf.put_u8(b'R');
480 buf.put_i32(8);
481 buf.put_i32(0);
482 }
483
484 BeMessage::AuthenticationCleartextPassword => {
489 buf.put_u8(b'R');
490 buf.put_i32(8);
491 buf.put_i32(3);
492 }
493
494 BeMessage::AuthenticationMd5Password(salt) => {
502 buf.put_u8(b'R');
503 buf.put_i32(12);
504 buf.put_i32(5);
505 buf.put_slice(&salt[..]);
506 }
507
508 BeMessage::ParameterStatus(param) => {
513 use BeParameterStatusMessage::*;
514 let [name, value] = match param {
515 ClientEncoding(val) => [b"client_encoding", val.as_bytes()],
516 StandardConformingString(val) => {
517 [b"standard_conforming_strings", val.as_bytes()]
518 }
519 ServerVersion(val) => [b"server_version", val.as_bytes()],
520 ApplicationName(val) => [b"application_name", val.as_bytes()],
521 TimeZone(val) => [b"TimeZone", val.as_bytes()],
523 };
524
525 let iov = &mut [name, b"\0", value, b"\0"].map(IoSlice::new);
527 let mut buffer = vec![];
528 let cnt = buffer.write_vectored(iov).unwrap();
529
530 buf.put_u8(b'S');
531 write_body(buf, |stream| {
532 stream.put_slice(&buffer[..cnt]);
533 Ok(())
534 })
535 .unwrap();
536 }
537
538 BeMessage::CommandComplete(cmd) => {
543 let rows_cnt = cmd.rows_cnt;
544 let mut stmt_type = cmd.stmt_type;
545 let mut tag = "".to_owned();
546 stmt_type = match stmt_type {
547 StatementType::INSERT_RETURNING => StatementType::INSERT,
548 StatementType::DELETE_RETURNING => StatementType::DELETE,
549 StatementType::UPDATE_RETURNING => StatementType::UPDATE,
550 s => s,
551 };
552 tag.push_str(&stmt_type.to_string());
553 if stmt_type == StatementType::INSERT {
554 tag.push_str(" 0");
555 }
556 if stmt_type.is_command() {
557 tag.push(' ');
558 tag.push_str(&rows_cnt.to_string());
559 }
560 buf.put_u8(b'C');
561 write_body(buf, |buf| {
562 write_cstr(buf, tag.as_bytes())?;
563 Ok(())
564 })?;
565 }
566
567 BeMessage::NoticeResponse(notice) => {
574 buf.put_u8(b'N');
575 write_err_or_notice(buf, &ErrorOrNoticeMessage::notice(notice))?;
576 }
577
578 BeMessage::DataRow(vals) => {
587 buf.put_u8(b'D');
588 write_body(buf, |buf| {
589 buf.put_u16(vals.len() as u16); for val_opt in vals.values() {
591 if let Some(val) = val_opt {
592 buf.put_u32(val.len() as u32);
593 buf.put_slice(val);
594 } else {
595 buf.put_i32(-1);
596 }
597 }
598 Ok(())
599 })
600 .unwrap();
601 }
602
603 BeMessage::RowDescription(row_descs) => {
617 buf.put_u8(b'T');
618 write_body(buf, |buf| {
619 buf.put_i16(row_descs.len() as i16); for pg_field in row_descs {
621 write_cstr(buf, pg_field.get_name().as_bytes())?;
622 buf.put_i32(pg_field.get_table_oid()); buf.put_i16(pg_field.get_col_attr_num()); buf.put_i32(pg_field.get_type_oid());
625 buf.put_i16(pg_field.get_type_len());
626 buf.put_i32(pg_field.get_type_modifier()); buf.put_i16(pg_field.get_format_code()); }
629 Ok(())
630 })?;
631 }
632 BeMessage::ReadyForQuery(txn_status) => {
637 buf.put_u8(b'Z');
638 buf.put_i32(5);
639 buf.put_u8(match txn_status {
641 TransactionStatus::Idle => b'I',
642 TransactionStatus::InTransaction => b'T',
643 TransactionStatus::InFailedTransaction => b'E',
644 });
645 }
646
647 BeMessage::ParseComplete => {
648 buf.put_u8(b'1');
649 write_body(buf, |_| Ok(()))?;
650 }
651
652 BeMessage::BindComplete => {
653 buf.put_u8(b'2');
654 write_body(buf, |_| Ok(()))?;
655 }
656
657 BeMessage::CloseComplete => {
658 buf.put_u8(b'3');
659 write_body(buf, |_| Ok(()))?;
660 }
661
662 BeMessage::PortalSuspended => {
663 buf.put_u8(b's');
664 write_body(buf, |_| Ok(()))?;
665 }
666 BeMessage::ParameterDescription(para_descs) => {
671 buf.put_u8(b't');
672 write_body(buf, |buf| {
673 buf.put_i16(para_descs.len() as i16);
674 for oid in para_descs {
675 buf.put_i32(*oid);
676 }
677 Ok(())
678 })?;
679 }
680
681 BeMessage::NoData => {
682 buf.put_u8(b'n');
683 write_body(buf, |_| Ok(())).unwrap();
684 }
685
686 BeMessage::EncryptionResponseSsl => {
687 buf.put_u8(b'S');
688 }
689
690 BeMessage::EncryptionResponseGss => {
691 buf.put_u8(b'G');
692 }
693
694 BeMessage::EncryptionResponseNo => {
695 buf.put_u8(b'N');
696 }
697
698 BeMessage::EmptyQueryResponse => {
703 buf.put_u8(b'I');
704 buf.put_i32(4);
705 }
706
707 BeMessage::ErrorResponse(error) => {
708 buf.put_u8(b'E');
710 write_err_or_notice(buf, &ErrorOrNoticeMessage::error(error))?;
712 }
713
714 BeMessage::BackendKeyData((process_id, secret_key)) => {
715 buf.put_u8(b'K');
716 write_body(buf, |buf| {
717 buf.put_i32(process_id);
718 buf.put_i32(secret_key);
719 Ok(())
720 })?;
721 }
722 BeMessage::CopyOutResponse(col_num) => {
723 buf.put_u8(b'H');
724 write_body(buf, |buf| {
725 buf.put_i8(Format::Text.to_i8());
726 buf.put_i16(col_num as _);
727 for _ in 0..col_num {
728 buf.put_i16(Format::Text.to_i8() as _);
729 }
730 Ok(())
731 })?;
732 }
733 BeMessage::CopyData(row) => {
734 buf.put_u8(b'd');
735 write_body(buf, |buf| {
737 fn write_str_bytes(
738 buf: &mut BytesMut,
739 str_bytes: &Option<Bytes>,
740 ) -> Result<()> {
741 let Some(str_bytes) = str_bytes else {
742 return Ok(());
743 };
744 let s = String::from_utf8_lossy(str_bytes);
745 for c in s.as_str().chars() {
746 match c {
749 '\t' => {
750 buf.put_slice(b"\\t");
751 }
752 '\n' => {
753 buf.put_slice(b"\\n");
754 }
755 '\r' => {
756 buf.put_slice(b"\\r");
757 }
758 '\\' => {
759 buf.put_slice(b"\\\\");
760 }
761 _ => {
762 std::fmt::Write::write_char(buf, c).map_err(|_| {
763 Error::other(anyhow!("failed to write_char [{c}]"))
764 })?;
765 }
766 }
767 }
768 Ok(())
769 }
770 match row.values() {
771 [] => {}
772 [first, rest @ ..] => {
773 write_str_bytes(buf, first)?;
774
775 for rest in rest {
776 buf.put_u8(b'\t');
777 write_str_bytes(buf, rest)?;
778 }
779 }
780 }
781 buf.put_u8(b'\n');
782 Ok(())
783 })?;
784 }
785 BeMessage::CopyDone => {
786 buf.put_u8(b'c');
787 write_body(buf, |_| Ok(()))?;
788 }
789 }
790
791 Ok(())
792 }
793}
794
/// Checked conversion from `usize` into the integer types used for wire-format lengths.
trait FromUsize: Sized {
    /// Converts `x`, failing with `InvalidInput` when it exceeds `Self::MAX`.
    fn from_usize(x: usize) -> Result<Self>;
}

/// Implements [`FromUsize`] for a primitive integer type by delegating to `TryFrom<usize>`.
macro_rules! from_usize {
    ($t:ty) => {
        impl FromUsize for $t {
            #[inline]
            fn from_usize(x: usize) -> Result<$t> {
                <$t as std::convert::TryFrom<usize>>::try_from(x).map_err(|_| {
                    Error::new(ErrorKind::InvalidInput, "value too large to transmit")
                })
            }
        }
    };
}

from_usize!(i32);
816
817fn write_body<F>(buf: &mut BytesMut, f: F) -> Result<()>
821where
822 F: FnOnce(&mut BytesMut) -> Result<()>,
823{
824 let base = buf.len();
825 buf.extend_from_slice(&[0; 4]);
826
827 f(buf)?;
828
829 let size = i32::from_usize(buf.len() - base)?;
830 BigEndian::write_i32(&mut buf[base..], size);
831 Ok(())
832}
833
834fn write_cstr(buf: &mut BytesMut, s: &[u8]) -> Result<()> {
836 if s.contains(&0) {
837 return Err(Error::new(
838 ErrorKind::InvalidInput,
839 "string contains embedded null",
840 ));
841 }
842 buf.put_slice(s);
843 buf.put_u8(0);
844 Ok(())
845}
846
847fn write_err_or_notice(buf: &mut BytesMut, msg: &ErrorOrNoticeMessage<'_>) -> Result<()> {
849 write_body(buf, |buf| {
850 buf.put_u8(b'S'); write_cstr(buf, msg.severity.as_str().as_bytes())?;
852
853 buf.put_u8(b'C'); write_cstr(buf, msg.error_code.sqlstate().as_bytes())?;
855
856 buf.put_u8(b'M'); write_cstr(buf, msg.message.as_bytes())?;
858
859 buf.put_u8(0); Ok(())
861 })
862}
863
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use crate::pg_message::{read_null_terminated, FeQueryMessage, FeStartupMessage};

    #[test]
    fn test_get_sql() {
        // Invalid UTF-8 before the terminator is rejected.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![255, 255, 255, 255, 255, 255, 0]),
        };
        assert!(fe.get_sql().is_err(), "invalid UTF-8 must be rejected");
        // A query without a NUL terminator is rejected.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![1, 2, 3, 4, 5, 6, 7, 8]),
        };
        assert!(
            fe.get_sql().is_err(),
            "a query without a NUL terminator must be rejected"
        );
        // A well-formed query is returned without its terminator.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from_static(b"select 1\0"),
        };
        assert_eq!(fe.get_sql().unwrap(), "select 1");
    }

    #[test]
    fn test_read_null_terminated() {
        let mut buf = Bytes::from_static(b"portal\0rest");
        assert_eq!(
            read_null_terminated(&mut buf).unwrap(),
            Bytes::from_static(b"portal")
        );
        // The terminator is consumed, the remainder is left in place.
        assert_eq!(buf, Bytes::from_static(b"rest"));
    }

    #[test]
    fn test_startup_payload() {
        let msg =
            FeStartupMessage::build_with_payload(b"user\0root\0database\0dev\0\0").unwrap();
        assert_eq!(msg.config.len(), 2);
        assert_eq!(msg.config["user"], "root");
        assert_eq!(msg.config["database"], "dev");
    }
}