1use std::collections::HashMap;
16use std::ffi::CStr;
17use std::io::{Error, ErrorKind, IoSlice, Result, Write};
18
19use anyhow::anyhow;
20use byteorder::{BigEndian, ByteOrder};
21use bytes::{Buf, BufMut, Bytes, BytesMut};
23use peekable::tokio::AsyncPeekable;
24use tokio::io::{AsyncRead, AsyncReadExt};
25
26use crate::error_or_notice::ErrorOrNoticeMessage;
27use crate::pg_field_descriptor::PgFieldDescriptor;
28use crate::pg_response::StatementType;
29use crate::types::{Format, Row};
30
/// A message received from the client side of the connection, as decoded by
/// `FeStartupMessage::read` (startup phase) and `FeMessage::read_body`
/// (regular, tagged messages).
#[derive(Debug)]
pub enum FeMessage {
    /// Startup-phase request with protocol number 80877103.
    Ssl,
    /// Startup-phase request with protocol number 80877104.
    Gss,
    /// Startup packet (protocol number 196608) carrying the connection config.
    Startup(FeStartupMessage),
    /// Tag `'Q'`: the raw body is kept as-is in `FeQueryMessage::sql_bytes`.
    Query(FeQueryMessage),
    /// Tag `'P'`.
    Parse(FeParseMessage),
    /// Tag `'p'`.
    Password(FePasswordMessage),
    /// Tag `'D'`.
    Describe(FeDescribeMessage),
    /// Tag `'B'`.
    Bind(FeBindMessage),
    /// Tag `'E'`.
    Execute(FeExecuteMessage),
    /// Tag `'C'`.
    Close(FeCloseMessage),
    /// Tag `'S'`.
    Sync,
    /// Startup-phase request with protocol number 80877102.
    CancelQuery(FeCancelMessage),
    /// Tag `'X'`.
    Terminate,
    /// Tag `'H'`.
    Flush,
    /// Produced by `FeStartupMessage::read` when the stream reaches EOF before
    /// the first byte of a startup packet arrives.
    HealthCheck,
    /// Not produced by any reader in this part of the file.
    /// NOTE(review): presumably injected by the server when it refuses to read
    /// a message body — confirm against the caller.
    ServerThrottle(ServerThrottleReason),
}
53
54impl FeMessage {
55 pub fn get_sql(&self) -> Result<Option<&str>> {
56 match self {
57 FeMessage::Query(q) => Ok(Some(q.get_sql()?)),
58 _ => Ok(None),
59 }
60 }
61}
62
/// Payload of `FeMessage::ServerThrottle`.
///
/// NOTE(review): the variants are not constructed in this part of the file;
/// the meanings below are taken from the names only — confirm with the caller.
#[derive(Debug)]
pub enum ServerThrottleReason {
    /// Presumably: a single message was too large to accept.
    TooLargeMessage,
    /// Presumably: overall memory usage was too high to accept the message.
    TooManyMemoryUsage,
}
68
/// Startup packet contents, built by `FeStartupMessage::build_with_payload`.
#[derive(Debug)]
pub struct FeStartupMessage {
    /// Key/value pairs decoded from the NUL-separated startup payload.
    pub config: HashMap<String, String>,
}
73
74impl FeStartupMessage {
75 pub fn build_with_payload(payload: &[u8]) -> Result<Self> {
76 let config = match std::str::from_utf8(payload) {
77 Ok(v) => Ok(v.trim_end_matches('\0')),
78 Err(err) => Err(Error::new(
79 ErrorKind::InvalidInput,
80 anyhow!(err).context("Input end error"),
81 )),
82 }?;
83 let mut map = HashMap::new();
84 let config: Vec<&str> = config.split('\0').collect();
85 if config.len() % 2 == 1 {
86 return Err(Error::new(
87 ErrorKind::InvalidInput,
88 "Invalid input config: odd number of config pairs",
89 ));
90 }
91 config.chunks(2).for_each(|chunk| {
92 map.insert(chunk[0].to_owned(), chunk[1].to_owned());
93 });
94 Ok(FeStartupMessage { config: map })
95 }
96}
97
/// Body of a `'Q'` message.
#[derive(Debug)]
pub struct FeQueryMessage {
    /// The raw message body. `get_sql` expects it to be a single
    /// NUL-terminated UTF-8 string.
    pub sql_bytes: Bytes,
}
103
/// Decoded body of a `'B'` message (see `FeBindMessage::parse`).
#[derive(Debug)]
pub struct FeBindMessage {
    /// Format codes for the parameters, in wire order.
    pub param_format_codes: Vec<i16>,
    /// Format codes for the result columns, in wire order.
    pub result_format_codes: Vec<i16>,

    /// Parameter values; `None` is a value sent with length `-1`.
    pub params: Vec<Option<Bytes>>,
    /// Portal name, without the terminating NUL.
    pub portal_name: Bytes,
    /// Statement name, without the terminating NUL.
    pub statement_name: Bytes,
}
113
/// Decoded body of an `'E'` message (see `FeExecuteMessage::parse`).
#[derive(Debug)]
pub struct FeExecuteMessage {
    /// Portal name, without the terminating NUL.
    pub portal_name: Bytes,
    /// The 32-bit integer that follows the portal name on the wire.
    pub max_rows: i32,
}
119
/// Decoded body of a `'P'` message (see `FeParseMessage::parse`).
#[derive(Debug)]
pub struct FeParseMessage {
    /// Statement name, without the terminating NUL.
    pub statement_name: Bytes,
    /// Query text, without the terminating NUL.
    pub sql_bytes: Bytes,
    /// One 32-bit id per declared parameter, in wire order.
    pub type_ids: Vec<i32>,
}
126
/// Decoded body of a `'p'` message (see `FePasswordMessage::parse`).
#[derive(Debug)]
pub struct FePasswordMessage {
    /// The NUL-terminated string from the body, without the terminator.
    pub password: Bytes,
}
131
/// Decoded body of a `'D'` message (see `FeDescribeMessage::parse`).
#[derive(Debug)]
pub struct FeDescribeMessage {
    /// First byte of the body.
    /// NOTE(review): presumably `b'S'` (statement) or `b'P'` (portal) as in
    /// the PostgreSQL protocol — it is not validated here.
    pub kind: u8,
    /// Name of the described object, without the terminating NUL.
    pub name: Bytes,
}
138
/// Decoded body of a `'C'` message (see `FeCloseMessage::parse`).
#[derive(Debug)]
pub struct FeCloseMessage {
    /// First byte of the body.
    /// NOTE(review): presumably `b'S'` (statement) or `b'P'` (portal) as in
    /// the PostgreSQL protocol — it is not validated here.
    pub kind: u8,
    /// Name of the object to close, without the terminating NUL.
    pub name: Bytes,
}
144
/// Payload of a cancel request (startup-phase protocol number 80877102).
#[derive(Debug)]
pub struct FeCancelMessage {
    /// First 32-bit integer of the payload.
    pub target_process_id: i32,
    /// Second 32-bit integer of the payload.
    pub target_secret_key: i32,
}
150
151impl FeCancelMessage {
152 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
153 let target_process_id = buf.get_i32();
154 let target_secret_key = buf.get_i32();
155 Ok(FeMessage::CancelQuery(Self {
156 target_process_id,
157 target_secret_key,
158 }))
159 }
160}
161
162impl FeDescribeMessage {
163 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
164 let kind = buf.get_u8();
165 let name = read_null_terminated(&mut buf)?;
166
167 Ok(FeMessage::Describe(FeDescribeMessage { kind, name }))
168 }
169}
170
171impl FeBindMessage {
172 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
190 let portal_name = read_null_terminated(&mut buf)?;
191 let statement_name = read_null_terminated(&mut buf)?;
192
193 let len = buf.get_i16();
194 let param_format_codes = (0..len).map(|_| buf.get_i16()).collect();
195
196 let len = buf.get_i16();
198 let params = (0..len)
199 .map(|_| {
200 let val_len = buf.get_i32();
201 if val_len == -1 {
202 None
203 } else {
204 Some(buf.copy_to_bytes(val_len as usize))
205 }
206 })
207 .collect();
208
209 let len = buf.get_i16();
210 let result_format_codes = (0..len).map(|_| buf.get_i16()).collect();
211
212 Ok(FeMessage::Bind(FeBindMessage {
213 param_format_codes,
214 result_format_codes,
215 params,
216 portal_name,
217 statement_name,
218 }))
219 }
220}
221
222impl FeExecuteMessage {
223 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
224 let portal_name = read_null_terminated(&mut buf)?;
225 let max_rows = buf.get_i32();
226
227 Ok(FeMessage::Execute(FeExecuteMessage {
228 portal_name,
229 max_rows,
230 }))
231 }
232}
233
234impl FeParseMessage {
235 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
236 let statement_name = read_null_terminated(&mut buf)?;
237 let sql_bytes = read_null_terminated(&mut buf)?;
238 let nparams = buf.get_i16();
239
240 let type_ids: Vec<i32> = (0..nparams).map(|_| buf.get_i32()).collect();
241
242 Ok(FeMessage::Parse(FeParseMessage {
243 statement_name,
244 sql_bytes,
245 type_ids,
246 }))
247 }
248}
249
250impl FePasswordMessage {
251 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
252 let password = read_null_terminated(&mut buf)?;
253
254 Ok(FeMessage::Password(FePasswordMessage { password }))
255 }
256}
257
258impl FeQueryMessage {
259 pub fn get_sql(&self) -> Result<&str> {
260 match CStr::from_bytes_with_nul(&self.sql_bytes) {
261 Ok(cstr) => cstr.to_str().map_err(|err| {
262 Error::new(
263 ErrorKind::InvalidInput,
264 anyhow!(err).context("Invalid UTF-8 sequence"),
265 )
266 }),
267 Err(err) => Err(Error::new(
268 ErrorKind::InvalidInput,
269 anyhow!(err).context("Input end error"),
270 )),
271 }
272 }
273}
274
275impl FeCloseMessage {
276 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
277 let kind = buf.get_u8();
278 let name = read_null_terminated(&mut buf)?;
279 Ok(FeMessage::Close(FeCloseMessage { kind, name }))
280 }
281}
282
/// Header of a regular (tagged) message, produced by `FeMessage::read_header`.
#[derive(Clone)]
pub struct FeMessageHeader {
    /// The one-byte message tag.
    pub tag: u8,
    /// The wire length field minus the 4 bytes of the length field itself,
    /// i.e. the number of body bytes that follow the header.
    pub payload_len: i32,
}
288
289impl FeMessage {
290 pub async fn read_header(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessageHeader> {
292 let tag = stream.read_u8().await?;
293 let len = stream.read_i32().await?;
294
295 let payload_len = len - 4;
296 Ok(FeMessageHeader { tag, payload_len })
297 }
298
299 pub async fn read_body(
301 stream: &mut (impl AsyncRead + Unpin),
302 header: FeMessageHeader,
303 ) -> Result<FeMessage> {
304 let FeMessageHeader { tag, payload_len } = header;
305 let mut payload: Vec<u8> = vec![0; payload_len as usize];
306 if payload_len > 0 {
307 stream.read_exact(&mut payload).await?;
308 }
309 let sql_bytes = Bytes::from(payload);
310 match tag {
311 b'Q' => Ok(FeMessage::Query(FeQueryMessage { sql_bytes })),
312 b'P' => FeParseMessage::parse(sql_bytes),
313 b'D' => FeDescribeMessage::parse(sql_bytes),
314 b'B' => FeBindMessage::parse(sql_bytes),
315 b'E' => FeExecuteMessage::parse(sql_bytes),
316 b'S' => Ok(FeMessage::Sync),
317 b'X' => Ok(FeMessage::Terminate),
318 b'C' => FeCloseMessage::parse(sql_bytes),
319 b'p' => FePasswordMessage::parse(sql_bytes),
320 b'H' => Ok(FeMessage::Flush),
321 _ => Err(std::io::Error::new(
322 ErrorKind::InvalidInput,
323 format!("Unsupported tag of regular message: {}", tag),
324 )),
325 }
326 }
327
328 pub async fn skip_body(
329 stream: &mut (impl AsyncRead + Unpin),
330 header: FeMessageHeader,
331 ) -> Result<()> {
332 let FeMessageHeader {
333 tag: _,
334 payload_len,
335 } = header;
336
337 if payload_len > 0 {
338 const BUF_SIZE: usize = 1024;
340 let mut buf: Vec<u8> = vec![0; BUF_SIZE];
341 for _ in 0..(payload_len as usize) / BUF_SIZE {
342 stream.read_exact(&mut buf).await?;
343 }
344 let remain = (payload_len as usize) % BUF_SIZE;
345 if remain > 0 {
346 buf.truncate(remain);
347 stream.read_exact(&mut buf).await?;
348 }
349 }
350 Ok(())
351 }
352}
353
354impl FeStartupMessage {
355 pub async fn read(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessage> {
357 let mut stream = AsyncPeekable::new(stream);
358
359 if let Err(err) = stream.peek_exact(&mut [0; 1]).await {
360 if err.kind() == ErrorKind::UnexpectedEof {
362 return Ok(FeMessage::HealthCheck);
363 } else {
364 return Err(err);
365 }
366 }
367
368 let len = stream.read_i32().await?;
369 let protocol_num = stream.read_i32().await?;
370 let payload_len = (len - 8) as usize;
371 if payload_len >= isize::MAX as usize {
372 return Err(std::io::Error::new(
373 ErrorKind::InvalidInput,
374 format!("Payload length has exceed usize::MAX {:?}", payload_len),
375 ));
376 }
377 let mut payload = vec![0; payload_len];
378 if payload_len > 0 {
379 stream.read_exact(&mut payload).await?;
380 }
381 match protocol_num {
382 196608 => Ok(FeMessage::Startup(FeStartupMessage::build_with_payload(
384 &payload,
385 )?)),
386 80877104 => Ok(FeMessage::Gss),
387 80877103 => Ok(FeMessage::Ssl),
388 80877102 => FeCancelMessage::parse(Bytes::from(payload)),
390 _ => Err(std::io::Error::new(
391 ErrorKind::InvalidInput,
392 format!(
393 "Unsupported protocol number in start up msg {:?}",
394 protocol_num
395 ),
396 )),
397 }
398 }
399}
400
401fn read_null_terminated(buf: &mut Bytes) -> Result<Bytes> {
403 let mut result = BytesMut::new();
404
405 loop {
406 if !buf.has_remaining() {
407 panic!("no null-terminator in string");
408 }
409
410 let byte = buf.get_u8();
411
412 if byte == 0 {
413 break;
414 }
415 result.put_u8(byte);
416 }
417 Ok(result.freeze())
418}
419
/// A message sent to the client; serialized by `BeMessage::write`.
///
/// The tag bytes mentioned on each variant are the ones `write` emits.
#[derive(Debug, Clone, Copy)]
pub enum BeMessage<'a> {
    /// Tag `'R'` with code `0`.
    AuthenticationOk,
    /// Tag `'R'` with code `3`.
    AuthenticationCleartextPassword,
    /// Tag `'R'` with code `5`, followed by the 4-byte salt.
    AuthenticationMd5Password(&'a [u8; 4]),
    /// Tag `'C'`; the body is a command tag built from the statement type.
    CommandComplete(BeCommandCompleteMessage),
    /// Tag `'N'` with a notice body built from the given text.
    NoticeResponse(&'a str),
    /// The single byte `'S'` (no length field).
    EncryptionResponseSsl,
    /// The single byte `'G'` (no length field).
    EncryptionResponseGss,
    /// The single byte `'N'` (no length field).
    EncryptionResponseNo,
    /// Tag `'I'` with an empty body.
    EmptyQueryResponse,
    /// Tag `'1'` with an empty body.
    ParseComplete,
    /// Tag `'2'` with an empty body.
    BindComplete,
    /// Tag `'s'` with an empty body.
    PortalSuspended,
    /// Tag `'t'`; one `i32` per parameter.
    ParameterDescription(&'a [i32]),
    /// Tag `'n'` with an empty body.
    NoData,
    /// Tag `'D'`; one length-prefixed value per column, `-1` for `None`.
    DataRow(&'a Row),
    /// Tag `'S'`; a NUL-terminated name/value pair.
    ParameterStatus(BeParameterStatusMessage<'a>),
    /// Tag `'Z'` followed by one status byte.
    ReadyForQuery(TransactionStatus),
    /// Tag `'T'`; one field description per column.
    RowDescription(&'a [PgFieldDescriptor]),
    /// Tag `'E'` with an error body built from `error` and `pretty`.
    ErrorResponse {
        error: &'a (dyn std::error::Error + Send + Sync + 'static),
        pretty: bool,
    },
    /// Tag `'3'` with an empty body.
    CloseComplete,

    /// Tag `'H'`; the payload is the number of columns, all in text format.
    CopyOutResponse(usize),
    /// Tag `'d'`; the row rendered as one tab-separated, newline-terminated line.
    CopyData(&'a Row),
    /// Tag `'c'` with an empty body.
    CopyDone,

    /// Tag `'K'`; the tuple is `(process_id, secret_key)`.
    BackendKeyData((i32, i32)),
}
459
/// A run-time parameter reported to the client via `BeMessage::ParameterStatus`.
///
/// Each variant carries the value; `BeMessage::write` supplies the name.
#[derive(Debug, Copy, Clone)]
pub enum BeParameterStatusMessage<'a> {
    /// Sent as `client_encoding`.
    ClientEncoding(&'a str),
    /// Sent as `standard_conforming_strings`.
    StandardConformingString(&'a str),
    /// Sent as `server_version`.
    ServerVersion(&'a str),
    /// Sent as `application_name`.
    ApplicationName(&'a str),
    /// Sent as `TimeZone`.
    TimeZone(&'a str),
}
468
/// Payload of `BeMessage::CommandComplete`.
#[derive(Debug, Copy, Clone)]
pub struct BeCommandCompleteMessage {
    /// Statement kind; its string form starts the command tag.
    pub stmt_type: StatementType,
    /// Row count; appended to the tag only when `stmt_type.is_command()`.
    pub rows_cnt: i32,
}
474
/// Status byte carried by `BeMessage::ReadyForQuery`.
#[derive(Debug, Clone, Copy)]
pub enum TransactionStatus {
    /// Written as `'I'`.
    Idle,
    /// Written as `'T'`.
    InTransaction,
    /// Written as `'E'`.
    InFailedTransaction,
}
481
482impl BeMessage<'_> {
483 pub fn write(buf: &mut BytesMut, message: BeMessage<'_>) -> Result<()> {
485 match message {
486 BeMessage::AuthenticationOk => {
491 buf.put_u8(b'R');
492 buf.put_i32(8);
493 buf.put_i32(0);
494 }
495
496 BeMessage::AuthenticationCleartextPassword => {
501 buf.put_u8(b'R');
502 buf.put_i32(8);
503 buf.put_i32(3);
504 }
505
506 BeMessage::AuthenticationMd5Password(salt) => {
514 buf.put_u8(b'R');
515 buf.put_i32(12);
516 buf.put_i32(5);
517 buf.put_slice(&salt[..]);
518 }
519
520 BeMessage::ParameterStatus(param) => {
525 use BeParameterStatusMessage::*;
526 let [name, value] = match param {
527 ClientEncoding(val) => [b"client_encoding", val.as_bytes()],
528 StandardConformingString(val) => {
529 [b"standard_conforming_strings", val.as_bytes()]
530 }
531 ServerVersion(val) => [b"server_version", val.as_bytes()],
532 ApplicationName(val) => [b"application_name", val.as_bytes()],
533 TimeZone(val) => [b"TimeZone", val.as_bytes()],
535 };
536
537 let iov = &mut [name, b"\0", value, b"\0"].map(IoSlice::new);
539 let mut buffer = vec![];
540 let cnt = buffer.write_vectored(iov).unwrap();
541
542 buf.put_u8(b'S');
543 write_body(buf, |stream| {
544 stream.put_slice(&buffer[..cnt]);
545 Ok(())
546 })
547 .unwrap();
548 }
549
550 BeMessage::CommandComplete(cmd) => {
555 let rows_cnt = cmd.rows_cnt;
556 let mut stmt_type = cmd.stmt_type;
557 let mut tag = "".to_owned();
558 stmt_type = match stmt_type {
559 StatementType::INSERT_RETURNING => StatementType::INSERT,
560 StatementType::DELETE_RETURNING => StatementType::DELETE,
561 StatementType::UPDATE_RETURNING => StatementType::UPDATE,
562 s => s,
563 };
564 tag.push_str(&stmt_type.to_string());
565 if stmt_type == StatementType::INSERT {
566 tag.push_str(" 0");
567 }
568 if stmt_type.is_command() {
569 tag.push(' ');
570 tag.push_str(&rows_cnt.to_string());
571 }
572 buf.put_u8(b'C');
573 write_body(buf, |buf| {
574 write_cstr(buf, tag.as_bytes())?;
575 Ok(())
576 })?;
577 }
578
579 BeMessage::NoticeResponse(notice) => {
586 buf.put_u8(b'N');
587 write_err_or_notice(buf, &ErrorOrNoticeMessage::notice(notice))?;
588 }
589
590 BeMessage::DataRow(vals) => {
599 buf.put_u8(b'D');
600 write_body(buf, |buf| {
601 buf.put_u16(vals.len() as u16); for val_opt in vals.values() {
603 if let Some(val) = val_opt {
604 buf.put_u32(val.len() as u32);
605 buf.put_slice(val);
606 } else {
607 buf.put_i32(-1);
608 }
609 }
610 Ok(())
611 })
612 .unwrap();
613 }
614
615 BeMessage::RowDescription(row_descs) => {
629 buf.put_u8(b'T');
630 write_body(buf, |buf| {
631 buf.put_i16(row_descs.len() as i16); for pg_field in row_descs {
633 write_cstr(buf, pg_field.get_name().as_bytes())?;
634 buf.put_i32(pg_field.get_table_oid()); buf.put_i16(pg_field.get_col_attr_num()); buf.put_i32(pg_field.get_type_oid());
637 buf.put_i16(pg_field.get_type_len());
638 buf.put_i32(pg_field.get_type_modifier()); buf.put_i16(pg_field.get_format_code()); }
641 Ok(())
642 })?;
643 }
644 BeMessage::ReadyForQuery(txn_status) => {
649 buf.put_u8(b'Z');
650 buf.put_i32(5);
651 buf.put_u8(match txn_status {
653 TransactionStatus::Idle => b'I',
654 TransactionStatus::InTransaction => b'T',
655 TransactionStatus::InFailedTransaction => b'E',
656 });
657 }
658
659 BeMessage::ParseComplete => {
660 buf.put_u8(b'1');
661 write_body(buf, |_| Ok(()))?;
662 }
663
664 BeMessage::BindComplete => {
665 buf.put_u8(b'2');
666 write_body(buf, |_| Ok(()))?;
667 }
668
669 BeMessage::CloseComplete => {
670 buf.put_u8(b'3');
671 write_body(buf, |_| Ok(()))?;
672 }
673
674 BeMessage::PortalSuspended => {
675 buf.put_u8(b's');
676 write_body(buf, |_| Ok(()))?;
677 }
678 BeMessage::ParameterDescription(para_descs) => {
683 buf.put_u8(b't');
684 write_body(buf, |buf| {
685 buf.put_i16(para_descs.len() as i16);
686 for oid in para_descs {
687 buf.put_i32(*oid);
688 }
689 Ok(())
690 })?;
691 }
692
693 BeMessage::NoData => {
694 buf.put_u8(b'n');
695 write_body(buf, |_| Ok(())).unwrap();
696 }
697
698 BeMessage::EncryptionResponseSsl => {
699 buf.put_u8(b'S');
700 }
701
702 BeMessage::EncryptionResponseGss => {
703 buf.put_u8(b'G');
704 }
705
706 BeMessage::EncryptionResponseNo => {
707 buf.put_u8(b'N');
708 }
709
710 BeMessage::EmptyQueryResponse => {
715 buf.put_u8(b'I');
716 buf.put_i32(4);
717 }
718
719 BeMessage::ErrorResponse { error, pretty } => {
720 buf.put_u8(b'E');
722 write_err_or_notice(buf, &ErrorOrNoticeMessage::error(error, pretty))?;
724 }
725
726 BeMessage::BackendKeyData((process_id, secret_key)) => {
727 buf.put_u8(b'K');
728 write_body(buf, |buf| {
729 buf.put_i32(process_id);
730 buf.put_i32(secret_key);
731 Ok(())
732 })?;
733 }
734 BeMessage::CopyOutResponse(col_num) => {
735 buf.put_u8(b'H');
736 write_body(buf, |buf| {
737 buf.put_i8(Format::Text.to_i8());
738 buf.put_i16(col_num as _);
739 for _ in 0..col_num {
740 buf.put_i16(Format::Text.to_i8() as _);
741 }
742 Ok(())
743 })?;
744 }
745 BeMessage::CopyData(row) => {
746 buf.put_u8(b'd');
747 write_body(buf, |buf| {
749 fn write_str_bytes(
750 buf: &mut BytesMut,
751 str_bytes: &Option<Bytes>,
752 ) -> Result<()> {
753 let Some(str_bytes) = str_bytes else {
754 return Ok(());
755 };
756 let s = String::from_utf8_lossy(str_bytes);
757 for c in s.as_str().chars() {
758 match c {
761 '\t' => {
762 buf.put_slice(b"\\t");
763 }
764 '\n' => {
765 buf.put_slice(b"\\n");
766 }
767 '\r' => {
768 buf.put_slice(b"\\r");
769 }
770 '\\' => {
771 buf.put_slice(b"\\\\");
772 }
773 _ => {
774 std::fmt::Write::write_char(buf, c).map_err(|_| {
775 Error::other(anyhow!("failed to write_char [{c}]"))
776 })?;
777 }
778 }
779 }
780 Ok(())
781 }
782 match row.values() {
783 [] => {}
784 [first, rest @ ..] => {
785 write_str_bytes(buf, first)?;
786
787 for rest in rest {
788 buf.put_u8(b'\t');
789 write_str_bytes(buf, rest)?;
790 }
791 }
792 }
793 buf.put_u8(b'\n');
794 Ok(())
795 })?;
796 }
797 BeMessage::CopyDone => {
798 buf.put_u8(b'c');
799 write_body(buf, |_| Ok(()))?;
800 }
801 }
802
803 Ok(())
804 }
805}
806
/// Checked conversion from `usize` into a narrower integer type.
trait FromUsize: Sized {
    /// Converts `x`, failing with `ErrorKind::InvalidInput` when it does not fit.
    fn from_usize(x: usize) -> Result<Self>;
}

/// Implements [`FromUsize`] for the given integer type.
macro_rules! from_usize {
    ($t:ty) => {
        impl FromUsize for $t {
            #[inline]
            fn from_usize(x: usize) -> Result<$t> {
                // Guard clause: anything up to the target's MAX converts losslessly.
                if x <= <$t>::MAX as usize {
                    return Ok(x as $t);
                }
                Err(Error::new(
                    ErrorKind::InvalidInput,
                    "value too large to transmit",
                ))
            }
        }
    };
}

from_usize!(i32);
828
829fn write_body<F>(buf: &mut BytesMut, f: F) -> Result<()>
833where
834 F: FnOnce(&mut BytesMut) -> Result<()>,
835{
836 let base = buf.len();
837 buf.extend_from_slice(&[0; 4]);
838
839 f(buf)?;
840
841 let size = i32::from_usize(buf.len() - base)?;
842 BigEndian::write_i32(&mut buf[base..], size);
843 Ok(())
844}
845
846fn write_cstr(buf: &mut BytesMut, s: &[u8]) -> Result<()> {
848 if s.contains(&0) {
849 return Err(Error::new(
850 ErrorKind::InvalidInput,
851 "string contains embedded null",
852 ));
853 }
854 buf.put_slice(s);
855 buf.put_u8(0);
856 Ok(())
857}
858
859fn write_err_or_notice(buf: &mut BytesMut, msg: &ErrorOrNoticeMessage<'_>) -> Result<()> {
861 write_body(buf, |buf| {
862 buf.put_u8(b'S'); write_cstr(buf, msg.severity.as_str().as_bytes())?;
864
865 buf.put_u8(b'C'); write_cstr(buf, msg.error_code.sqlstate().as_bytes())?;
867
868 buf.put_u8(b'M'); write_cstr(buf, msg.message.as_bytes())?;
870
871 buf.put_u8(0); Ok(())
873 })
874}
875
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use crate::pg_message::FeQueryMessage;

    /// `get_sql` must reject bodies that are not a NUL-terminated UTF-8 string.
    #[test]
    fn test_get_sql() {
        // NUL-terminated, but 0xFF bytes are never valid UTF-8.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![255, 255, 255, 255, 255, 255, 0]),
        };
        assert!(fe.get_sql().is_err(), "{}", true);
        // No terminating NUL byte at all.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![1, 2, 3, 4, 5, 6, 7, 8]),
        };
        assert!(fe.get_sql().is_err(), "{}", true);
    }
}