1use std::collections::HashMap;
16use std::ffi::CStr;
17use std::io::{Error, ErrorKind, IoSlice, Result, Write};
18
19use anyhow::anyhow;
20use byteorder::{BigEndian, ByteOrder};
21use bytes::{Buf, BufMut, Bytes, BytesMut};
23use peekable::tokio::AsyncPeekable;
24use tokio::io::{AsyncRead, AsyncReadExt};
25
26use crate::error_or_notice::ErrorOrNoticeMessage;
27use crate::pg_field_descriptor::PgFieldDescriptor;
28use crate::pg_response::StatementType;
29use crate::pg_server::BoxedError;
30use crate::types::Row;
31
/// Messages that a PostgreSQL client (frontend) can send to the server (backend).
///
/// Regular, tagged messages are produced by `FeMessage::read_body`; the startup-phase
/// variants (`Ssl`, `Gss`, `Startup`, `CancelQuery`, `HealthCheck`) are produced by
/// `FeStartupMessage::read`.
#[derive(Debug)]
pub enum FeMessage {
    /// SSL negotiation request (startup request code 80877103).
    Ssl,
    /// GSS encryption negotiation request (startup request code 80877104).
    Gss,
    /// Startup message (protocol code 196608) with the client's connection parameters.
    Startup(FeStartupMessage),
    /// Simple query (tag `Q`).
    Query(FeQueryMessage),
    /// Parse (tag `P`) of the extended query protocol.
    Parse(FeParseMessage),
    /// Password response (tag `p`).
    Password(FePasswordMessage),
    /// Describe (tag `D`).
    Describe(FeDescribeMessage),
    /// Bind (tag `B`).
    Bind(FeBindMessage),
    /// Execute (tag `E`).
    Execute(FeExecuteMessage),
    /// Close (tag `C`).
    Close(FeCloseMessage),
    /// Sync (tag `S`).
    Sync,
    /// Cancel request (startup request code 80877102).
    CancelQuery(FeCancelMessage),
    /// Terminate (tag `X`).
    Terminate,
    /// Flush (tag `H`).
    Flush,
    /// The connection reached EOF before sending a single byte; presumably a TCP-level
    /// health probe.
    HealthCheck,
    /// NOTE(review): not constructed in this file; presumably produced by the connection
    /// loop when it declines to process a message — confirm.
    ServerThrottle(ServerThrottleReason),
}
54
/// Reason carried by `FeMessage::ServerThrottle`.
///
/// NOTE(review): the thresholds behind these reasons are decided by the caller, not here.
#[derive(Debug)]
pub enum ServerThrottleReason {
    /// The incoming message is presumably larger than the server is willing to buffer.
    TooLargeMessage,
    /// Presumably the server's memory usage is too high to accept the message — confirm.
    TooManyMemoryUsage,
}
60
/// Startup message: the key/value connection parameters sent by the client.
#[derive(Debug)]
pub struct FeStartupMessage {
    /// Parameter name → value, as parsed by `FeStartupMessage::build_with_payload`.
    pub config: HashMap<String, String>,
}
65
66impl FeStartupMessage {
67 pub fn build_with_payload(payload: &[u8]) -> Result<Self> {
68 let config = match std::str::from_utf8(payload) {
69 Ok(v) => Ok(v.trim_end_matches('\0')),
70 Err(err) => Err(Error::new(
71 ErrorKind::InvalidInput,
72 anyhow!(err).context("Input end error"),
73 )),
74 }?;
75 let mut map = HashMap::new();
76 let config: Vec<&str> = config.split('\0').collect();
77 if config.len() % 2 == 1 {
78 return Err(Error::new(
79 ErrorKind::InvalidInput,
80 "Invalid input config: odd number of config pairs",
81 ));
82 }
83 config.chunks(2).for_each(|chunk| {
84 map.insert(chunk[0].to_owned(), chunk[1].to_owned());
85 });
86 Ok(FeStartupMessage { config: map })
87 }
88}
89
/// Simple query (tag `Q`) message.
#[derive(Debug)]
pub struct FeQueryMessage {
    /// Raw message body. `get_sql` expects it to be the SQL text followed by exactly one
    /// trailing null terminator.
    pub sql_bytes: Bytes,
}
95
/// Bind (tag `B`) message of the extended query protocol.
#[derive(Debug)]
pub struct FeBindMessage {
    /// Format code of each parameter, in wire order.
    pub param_format_codes: Vec<i16>,
    /// Format code of each result column, in wire order.
    pub result_format_codes: Vec<i16>,

    /// Parameter values in wire order; `None` means the client sent length `-1` (NULL).
    pub params: Vec<Option<Bytes>>,
    /// Target portal name, without its null terminator.
    pub portal_name: Bytes,
    /// Source prepared-statement name, without its null terminator.
    pub statement_name: Bytes,
}
105
/// Execute (tag `E`) message.
#[derive(Debug)]
pub struct FeExecuteMessage {
    /// Name of the portal to execute, without its null terminator.
    pub portal_name: Bytes,
    /// Row limit exactly as sent by the client.
    /// NOTE(review): in PostgreSQL 0 means "no limit"; interpretation is up to the caller.
    pub max_rows: i32,
}
111
/// Parse (tag `P`) message of the extended query protocol.
#[derive(Debug)]
pub struct FeParseMessage {
    /// Prepared-statement name, without its null terminator.
    pub statement_name: Bytes,
    /// Query text, without its null terminator.
    pub sql_bytes: Bytes,
    /// Parameter type ids given by the client, in parameter order.
    pub type_ids: Vec<i32>,
}
118
/// Password response (tag `p`).
#[derive(Debug)]
pub struct FePasswordMessage {
    /// The client's response, without its null terminator. Whether this is cleartext or an
    /// MD5 digest depends on which authentication request the server sent.
    pub password: Bytes,
}
123
/// Describe (tag `D`) message.
#[derive(Debug)]
pub struct FeDescribeMessage {
    /// Raw kind byte, not validated here. Per the PostgreSQL protocol it should be `b'S'`
    /// (prepared statement) or `b'P'` (portal).
    pub kind: u8,
    /// Name of the statement or portal, without its null terminator.
    pub name: Bytes,
}
130
/// Close (tag `C`) message.
#[derive(Debug)]
pub struct FeCloseMessage {
    /// Raw kind byte, not validated here. Per the PostgreSQL protocol it should be `b'S'`
    /// (prepared statement) or `b'P'` (portal).
    pub kind: u8,
    /// Name of the statement or portal to close, without its null terminator.
    pub name: Bytes,
}
136
/// Cancel request: asks the server to cancel the query running on another connection.
#[derive(Debug)]
pub struct FeCancelMessage {
    /// Process id identifying the target connection.
    pub target_process_id: i32,
    /// Secret key for the target connection; presumably compared against the key handed
    /// out via `BeMessage::BackendKeyData` — confirm in the caller.
    pub target_secret_key: i32,
}
142
143impl FeCancelMessage {
144 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
145 let target_process_id = buf.get_i32();
146 let target_secret_key = buf.get_i32();
147 Ok(FeMessage::CancelQuery(Self {
148 target_process_id,
149 target_secret_key,
150 }))
151 }
152}
153
154impl FeDescribeMessage {
155 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
156 let kind = buf.get_u8();
157 let name = read_null_terminated(&mut buf)?;
158
159 Ok(FeMessage::Describe(FeDescribeMessage { kind, name }))
160 }
161}
162
163impl FeBindMessage {
164 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
182 let portal_name = read_null_terminated(&mut buf)?;
183 let statement_name = read_null_terminated(&mut buf)?;
184
185 let len = buf.get_i16();
186 let param_format_codes = (0..len).map(|_| buf.get_i16()).collect();
187
188 let len = buf.get_i16();
190 let params = (0..len)
191 .map(|_| {
192 let val_len = buf.get_i32();
193 if val_len == -1 {
194 None
195 } else {
196 Some(buf.copy_to_bytes(val_len as usize))
197 }
198 })
199 .collect();
200
201 let len = buf.get_i16();
202 let result_format_codes = (0..len).map(|_| buf.get_i16()).collect();
203
204 Ok(FeMessage::Bind(FeBindMessage {
205 param_format_codes,
206 result_format_codes,
207 params,
208 portal_name,
209 statement_name,
210 }))
211 }
212}
213
214impl FeExecuteMessage {
215 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
216 let portal_name = read_null_terminated(&mut buf)?;
217 let max_rows = buf.get_i32();
218
219 Ok(FeMessage::Execute(FeExecuteMessage {
220 portal_name,
221 max_rows,
222 }))
223 }
224}
225
226impl FeParseMessage {
227 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
228 let statement_name = read_null_terminated(&mut buf)?;
229 let sql_bytes = read_null_terminated(&mut buf)?;
230 let nparams = buf.get_i16();
231
232 let type_ids: Vec<i32> = (0..nparams).map(|_| buf.get_i32()).collect();
233
234 Ok(FeMessage::Parse(FeParseMessage {
235 statement_name,
236 sql_bytes,
237 type_ids,
238 }))
239 }
240}
241
242impl FePasswordMessage {
243 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
244 let password = read_null_terminated(&mut buf)?;
245
246 Ok(FeMessage::Password(FePasswordMessage { password }))
247 }
248}
249
250impl FeQueryMessage {
251 pub fn get_sql(&self) -> Result<&str> {
252 match CStr::from_bytes_with_nul(&self.sql_bytes) {
253 Ok(cstr) => cstr.to_str().map_err(|err| {
254 Error::new(
255 ErrorKind::InvalidInput,
256 anyhow!(err).context("Invalid UTF-8 sequence"),
257 )
258 }),
259 Err(err) => Err(Error::new(
260 ErrorKind::InvalidInput,
261 anyhow!(err).context("Input end error"),
262 )),
263 }
264 }
265}
266
267impl FeCloseMessage {
268 pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
269 let kind = buf.get_u8();
270 let name = read_null_terminated(&mut buf)?;
271 Ok(FeMessage::Close(FeCloseMessage { kind, name }))
272 }
273}
274
/// Header of a regular (tagged) frontend message, as returned by `FeMessage::read_header`.
#[derive(Clone)]
pub struct FeMessageHeader {
    /// Message type byte, e.g. `b'Q'` for a simple query.
    pub tag: u8,
    /// Body length in bytes: the wire length minus the 4 bytes of the length field itself.
    pub payload_len: i32,
}
280
281impl FeMessage {
282 pub async fn read_header(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessageHeader> {
284 let tag = stream.read_u8().await?;
285 let len = stream.read_i32().await?;
286
287 let payload_len = len - 4;
288 Ok(FeMessageHeader { tag, payload_len })
289 }
290
291 pub async fn read_body(
293 stream: &mut (impl AsyncRead + Unpin),
294 header: FeMessageHeader,
295 ) -> Result<FeMessage> {
296 let FeMessageHeader { tag, payload_len } = header;
297 let mut payload: Vec<u8> = vec![0; payload_len as usize];
298 if payload_len > 0 {
299 stream.read_exact(&mut payload).await?;
300 }
301 let sql_bytes = Bytes::from(payload);
302 match tag {
303 b'Q' => Ok(FeMessage::Query(FeQueryMessage { sql_bytes })),
304 b'P' => FeParseMessage::parse(sql_bytes),
305 b'D' => FeDescribeMessage::parse(sql_bytes),
306 b'B' => FeBindMessage::parse(sql_bytes),
307 b'E' => FeExecuteMessage::parse(sql_bytes),
308 b'S' => Ok(FeMessage::Sync),
309 b'X' => Ok(FeMessage::Terminate),
310 b'C' => FeCloseMessage::parse(sql_bytes),
311 b'p' => FePasswordMessage::parse(sql_bytes),
312 b'H' => Ok(FeMessage::Flush),
313 _ => Err(std::io::Error::new(
314 ErrorKind::InvalidInput,
315 format!("Unsupported tag of regular message: {}", tag),
316 )),
317 }
318 }
319
320 pub async fn skip_body(
321 stream: &mut (impl AsyncRead + Unpin),
322 header: FeMessageHeader,
323 ) -> Result<()> {
324 let FeMessageHeader {
325 tag: _,
326 payload_len,
327 } = header;
328
329 if payload_len > 0 {
330 const BUF_SIZE: usize = 1024;
332 let mut buf: Vec<u8> = vec![0; BUF_SIZE];
333 for _ in 0..(payload_len as usize) / BUF_SIZE {
334 stream.read_exact(&mut buf).await?;
335 }
336 let remain = (payload_len as usize) % BUF_SIZE;
337 if remain > 0 {
338 buf.truncate(remain);
339 stream.read_exact(&mut buf).await?;
340 }
341 }
342 Ok(())
343 }
344}
345
346impl FeStartupMessage {
347 pub async fn read(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessage> {
349 let mut stream = AsyncPeekable::new(stream);
350
351 if let Err(err) = stream.peek_exact(&mut [0; 1]).await {
352 if err.kind() == ErrorKind::UnexpectedEof {
354 return Ok(FeMessage::HealthCheck);
355 } else {
356 return Err(err);
357 }
358 }
359
360 let len = stream.read_i32().await?;
361 let protocol_num = stream.read_i32().await?;
362 let payload_len = (len - 8) as usize;
363 if payload_len >= isize::MAX as usize {
364 return Err(std::io::Error::new(
365 ErrorKind::InvalidInput,
366 format!("Payload length has exceed usize::MAX {:?}", payload_len),
367 ));
368 }
369 let mut payload = vec![0; payload_len];
370 if payload_len > 0 {
371 stream.read_exact(&mut payload).await?;
372 }
373 match protocol_num {
374 196608 => Ok(FeMessage::Startup(FeStartupMessage::build_with_payload(
376 &payload,
377 )?)),
378 80877104 => Ok(FeMessage::Gss),
379 80877103 => Ok(FeMessage::Ssl),
380 80877102 => FeCancelMessage::parse(Bytes::from(payload)),
382 _ => Err(std::io::Error::new(
383 ErrorKind::InvalidInput,
384 format!(
385 "Unsupported protocol number in start up msg {:?}",
386 protocol_num
387 ),
388 )),
389 }
390 }
391}
392
393fn read_null_terminated(buf: &mut Bytes) -> Result<Bytes> {
395 let mut result = BytesMut::new();
396
397 loop {
398 if !buf.has_remaining() {
399 panic!("no null-terminator in string");
400 }
401
402 let byte = buf.get_u8();
403
404 if byte == 0 {
405 break;
406 }
407 result.put_u8(byte);
408 }
409 Ok(result.freeze())
410}
411
/// Messages the server (backend) sends to a PostgreSQL client; encoded by
/// `BeMessage::write`.
#[derive(Debug)]
pub enum BeMessage<'a> {
    /// Authentication success (tag `R`, auth code 0).
    AuthenticationOk,
    /// Request for a cleartext password (tag `R`, auth code 3).
    AuthenticationCleartextPassword,
    /// Request for an MD5-hashed password with the given 4-byte salt (tag `R`, auth code 5).
    AuthenticationMd5Password(&'a [u8; 4]),
    /// Command completion (tag `C`); the command tag is derived from the statement type.
    CommandComplete(BeCommandCompleteMessage),
    /// Notice (tag `N`) with the given message text.
    NoticeResponse(&'a str),
    /// Single byte `S`, written without a length, answering an SSL request.
    EncryptionResponseSsl,
    /// Single byte `G`, written without a length, answering a GSS encryption request.
    EncryptionResponseGss,
    /// Single byte `N`, written without a length; per the PostgreSQL protocol this declines
    /// the requested encryption.
    EncryptionResponseNo,
    /// Empty query response (tag `I`).
    EmptyQueryResponse,
    /// Parse complete (tag `1`).
    ParseComplete,
    /// Bind complete (tag `2`).
    BindComplete,
    /// Portal suspended (tag `s`).
    PortalSuspended,
    /// Parameter description (tag `t`): one type id per parameter, in order.
    ParameterDescription(&'a [i32]),
    /// No data (tag `n`).
    NoData,
    /// One result row (tag `D`).
    DataRow(&'a Row),
    /// Run-time parameter report (tag `S`).
    ParameterStatus(BeParameterStatusMessage<'a>),
    /// Ready for query (tag `Z`) with the current transaction status.
    ReadyForQuery(TransactionStatus),
    /// Row description (tag `T`): one descriptor per result column.
    RowDescription(&'a [PgFieldDescriptor]),
    /// Error (tag `E`); the error is rendered as a pretty report string.
    ErrorResponse(BoxedError),
    /// Close complete (tag `3`).
    CloseComplete,

    /// Backend key data (tag `K`): `(process_id, secret_key)`, written in that order.
    BackendKeyData((i32, i32)),
}
443
/// A run-time parameter reported in a ParameterStatus message. The variant fixes the
/// parameter name written on the wire; the payload is its value.
#[derive(Debug)]
pub enum BeParameterStatusMessage<'a> {
    /// Written as `client_encoding`.
    ClientEncoding(&'a str),
    /// Written as `standard_conforming_strings`.
    StandardConformingString(&'a str),
    /// Written as `server_version`.
    ServerVersion(&'a str),
    /// Written as `application_name`.
    ApplicationName(&'a str),
}
451
/// Payload of a CommandComplete message.
#[derive(Debug)]
pub struct BeCommandCompleteMessage {
    /// Statement kind; its display form starts the command tag. The `*_RETURNING` kinds are
    /// reported as their plain INSERT/DELETE/UPDATE counterparts.
    pub stmt_type: StatementType,
    /// Row count, appended to the tag only when `stmt_type.is_command()` holds.
    pub rows_cnt: i32,
}
457
/// Transaction state reported in a ReadyForQuery message.
#[derive(Debug, Clone, Copy)]
pub enum TransactionStatus {
    /// Not inside a transaction block; sent as `I`.
    Idle,
    /// Inside a transaction block; sent as `T`.
    InTransaction,
    /// Inside a failed transaction block; sent as `E`.
    InFailedTransaction,
}
464
465impl BeMessage<'_> {
466 pub fn write(buf: &mut BytesMut, message: &BeMessage<'_>) -> Result<()> {
468 match message {
469 BeMessage::AuthenticationOk => {
474 buf.put_u8(b'R');
475 buf.put_i32(8);
476 buf.put_i32(0);
477 }
478
479 BeMessage::AuthenticationCleartextPassword => {
484 buf.put_u8(b'R');
485 buf.put_i32(8);
486 buf.put_i32(3);
487 }
488
489 BeMessage::AuthenticationMd5Password(salt) => {
497 buf.put_u8(b'R');
498 buf.put_i32(12);
499 buf.put_i32(5);
500 buf.put_slice(&salt[..]);
501 }
502
503 BeMessage::ParameterStatus(param) => {
508 use BeParameterStatusMessage::*;
509 let [name, value] = match param {
510 ClientEncoding(val) => [b"client_encoding", val.as_bytes()],
511 StandardConformingString(val) => {
512 [b"standard_conforming_strings", val.as_bytes()]
513 }
514 ServerVersion(val) => [b"server_version", val.as_bytes()],
515 ApplicationName(val) => [b"application_name", val.as_bytes()],
516 };
517
518 let iov = &mut [name, b"\0", value, b"\0"].map(IoSlice::new);
520 let mut buffer = vec![];
521 let cnt = buffer.write_vectored(iov).unwrap();
522
523 buf.put_u8(b'S');
524 write_body(buf, |stream| {
525 stream.put_slice(&buffer[..cnt]);
526 Ok(())
527 })
528 .unwrap();
529 }
530
531 BeMessage::CommandComplete(cmd) => {
536 let rows_cnt = cmd.rows_cnt;
537 let mut stmt_type = cmd.stmt_type;
538 let mut tag = "".to_owned();
539 stmt_type = match stmt_type {
540 StatementType::INSERT_RETURNING => StatementType::INSERT,
541 StatementType::DELETE_RETURNING => StatementType::DELETE,
542 StatementType::UPDATE_RETURNING => StatementType::UPDATE,
543 s => s,
544 };
545 tag.push_str(&stmt_type.to_string());
546 if stmt_type == StatementType::INSERT {
547 tag.push_str(" 0");
548 }
549 if stmt_type.is_command() {
550 tag.push(' ');
551 tag.push_str(&rows_cnt.to_string());
552 }
553 buf.put_u8(b'C');
554 write_body(buf, |buf| {
555 write_cstr(buf, tag.as_bytes())?;
556 Ok(())
557 })?;
558 }
559
560 BeMessage::NoticeResponse(notice) => {
567 buf.put_u8(b'N');
568 write_err_or_notice(buf, &ErrorOrNoticeMessage::notice(notice))?;
569 }
570
571 BeMessage::DataRow(vals) => {
580 buf.put_u8(b'D');
581 write_body(buf, |buf| {
582 buf.put_u16(vals.len() as u16); for val_opt in vals.values() {
584 if let Some(val) = val_opt {
585 buf.put_u32(val.len() as u32);
586 buf.put_slice(val);
587 } else {
588 buf.put_i32(-1);
589 }
590 }
591 Ok(())
592 })
593 .unwrap();
594 }
595
596 BeMessage::RowDescription(row_descs) => {
610 buf.put_u8(b'T');
611 write_body(buf, |buf| {
612 buf.put_i16(row_descs.len() as i16); for pg_field in *row_descs {
614 write_cstr(buf, pg_field.get_name().as_bytes())?;
615 buf.put_i32(pg_field.get_table_oid()); buf.put_i16(pg_field.get_col_attr_num()); buf.put_i32(pg_field.get_type_oid());
618 buf.put_i16(pg_field.get_type_len());
619 buf.put_i32(pg_field.get_type_modifier()); buf.put_i16(pg_field.get_format_code()); }
622 Ok(())
623 })?;
624 }
625 BeMessage::ReadyForQuery(txn_status) => {
630 buf.put_u8(b'Z');
631 buf.put_i32(5);
632 buf.put_u8(match txn_status {
634 TransactionStatus::Idle => b'I',
635 TransactionStatus::InTransaction => b'T',
636 TransactionStatus::InFailedTransaction => b'E',
637 });
638 }
639
640 BeMessage::ParseComplete => {
641 buf.put_u8(b'1');
642 write_body(buf, |_| Ok(()))?;
643 }
644
645 BeMessage::BindComplete => {
646 buf.put_u8(b'2');
647 write_body(buf, |_| Ok(()))?;
648 }
649
650 BeMessage::CloseComplete => {
651 buf.put_u8(b'3');
652 write_body(buf, |_| Ok(()))?;
653 }
654
655 BeMessage::PortalSuspended => {
656 buf.put_u8(b's');
657 write_body(buf, |_| Ok(()))?;
658 }
659 BeMessage::ParameterDescription(para_descs) => {
664 buf.put_u8(b't');
665 write_body(buf, |buf| {
666 buf.put_i16(para_descs.len() as i16);
667 for oid in *para_descs {
668 buf.put_i32(*oid);
669 }
670 Ok(())
671 })?;
672 }
673
674 BeMessage::NoData => {
675 buf.put_u8(b'n');
676 write_body(buf, |_| Ok(())).unwrap();
677 }
678
679 BeMessage::EncryptionResponseSsl => {
680 buf.put_u8(b'S');
681 }
682
683 BeMessage::EncryptionResponseGss => {
684 buf.put_u8(b'G');
685 }
686
687 BeMessage::EncryptionResponseNo => {
688 buf.put_u8(b'N');
689 }
690
691 BeMessage::EmptyQueryResponse => {
696 buf.put_u8(b'I');
697 buf.put_i32(4);
698 }
699
700 BeMessage::ErrorResponse(error) => {
701 use thiserror_ext::AsReport;
702 buf.put_u8(b'E');
707 let msg = error.to_report_string_pretty();
709 write_err_or_notice(buf, &ErrorOrNoticeMessage::internal_error(&msg))?;
710 }
711
712 BeMessage::BackendKeyData((process_id, secret_key)) => {
713 buf.put_u8(b'K');
714 write_body(buf, |buf| {
715 buf.put_i32(*process_id);
716 buf.put_i32(*secret_key);
717 Ok(())
718 })?;
719 }
720 }
721
722 Ok(())
723 }
724}
725
/// Checked conversion from a `usize` length into a wire integer type.
trait FromUsize: Sized {
    /// Converts `x`, failing with [`ErrorKind::InvalidInput`] if it does not fit in `Self`.
    fn from_usize(x: usize) -> Result<Self>;
}

/// Implements [`FromUsize`] for an integer type through its `TryFrom<usize>` impl.
macro_rules! from_usize {
    ($t:ty) => {
        impl FromUsize for $t {
            #[inline]
            fn from_usize(x: usize) -> Result<$t> {
                // `TryFrom` is the standard checked conversion, and it already yields an
                // `io::Error` once mapped, so no further `.into()` is needed.
                <$t>::try_from(x).map_err(|_| {
                    Error::new(ErrorKind::InvalidInput, "value too large to transmit")
                })
            }
        }
    };
}

// Message lengths are written as `i32`.
from_usize!(i32);
747
748fn write_body<F>(buf: &mut BytesMut, f: F) -> Result<()>
752where
753 F: FnOnce(&mut BytesMut) -> Result<()>,
754{
755 let base = buf.len();
756 buf.extend_from_slice(&[0; 4]);
757
758 f(buf)?;
759
760 let size = i32::from_usize(buf.len() - base)?;
761 BigEndian::write_i32(&mut buf[base..], size);
762 Ok(())
763}
764
765fn write_cstr(buf: &mut BytesMut, s: &[u8]) -> Result<()> {
767 if s.contains(&0) {
768 return Err(Error::new(
769 ErrorKind::InvalidInput,
770 "string contains embedded null",
771 ));
772 }
773 buf.put_slice(s);
774 buf.put_u8(0);
775 Ok(())
776}
777
778fn write_err_or_notice(buf: &mut BytesMut, msg: &ErrorOrNoticeMessage<'_>) -> Result<()> {
780 write_body(buf, |buf| {
781 buf.put_u8(b'S'); write_cstr(buf, msg.severity.as_str().as_bytes())?;
783
784 buf.put_u8(b'C'); write_cstr(buf, msg.state.code().as_bytes())?;
786
787 buf.put_u8(b'M'); write_cstr(buf, msg.message.as_bytes())?;
789
790 buf.put_u8(0); Ok(())
792 })
793}
794
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use crate::pg_message::FeQueryMessage;

    #[test]
    fn test_get_sql() {
        // A well-formed query is returned without its null terminator.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from_static(b"SELECT 1\0"),
        };
        assert_eq!(fe.get_sql().unwrap(), "SELECT 1");

        // Null-terminated, but not valid UTF-8.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![255, 255, 255, 255, 255, 255, 0]),
        };
        assert!(fe.get_sql().is_err(), "invalid UTF-8 must be rejected");

        // No null terminator at all.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![1, 2, 3, 4, 5, 6, 7, 8]),
        };
        assert!(
            fe.get_sql().is_err(),
            "a missing null terminator must be rejected"
        );
    }
}