pgwire/
pg_message.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::ffi::CStr;
17use std::io::{Error, ErrorKind, IoSlice, Result, Write};
18
19use anyhow::anyhow;
20use byteorder::{BigEndian, ByteOrder};
21/// Part of code learned from <https://github.com/zenithdb/zenith/blob/main/zenith_utils/src/pq_proto.rs>.
22use bytes::{Buf, BufMut, Bytes, BytesMut};
23use peekable::tokio::AsyncPeekable;
24use tokio::io::{AsyncRead, AsyncReadExt};
25
26use crate::error_or_notice::{ErrorOrNoticeMessage, Severity};
27use crate::pg_field_descriptor::PgFieldDescriptor;
28use crate::pg_response::StatementType;
29use crate::types::{Format, Row};
30
/// Messages that can be sent from a pg client to the server.
///
/// Regular messages are decoded with [`FeMessage::read_header`] followed by
/// [`FeMessage::read_body`]; the first message of a connection is decoded by
/// [`FeStartupMessage::read`].
#[derive(Debug)]
pub enum FeMessage {
    /// SSL encryption request (startup code 80877103).
    Ssl,
    /// GSSAPI encryption request (startup code 80877104).
    Gss,
    /// Startup message (protocol code 196608).
    Startup(FeStartupMessage),
    /// Simple query, tag `'Q'`.
    Query(FeQueryMessage),
    /// Extended-protocol Parse, tag `'P'`.
    Parse(FeParseMessage),
    /// Password response, tag `'p'`.
    Password(FePasswordMessage),
    /// Extended-protocol Describe, tag `'D'`.
    Describe(FeDescribeMessage),
    /// Extended-protocol Bind, tag `'B'`.
    Bind(FeBindMessage),
    /// Extended-protocol Execute, tag `'E'`.
    Execute(FeExecuteMessage),
    /// Extended-protocol Close, tag `'C'`.
    Close(FeCloseMessage),
    /// Sync, tag `'S'` (no body is parsed).
    Sync,
    /// Cancel request (startup code 80877102).
    CancelQuery(FeCancelMessage),
    /// Terminate, tag `'X'` (no body is parsed).
    Terminate,
    /// Flush, tag `'H'` (no body is parsed).
    Flush,
    /// Special message used to detect health checks: the client closed the connection cleanly
    /// without sending any data.
    HealthCheck,
    /// The original message has been rejected due to server throttling. This is a placeholder
    /// message generated by the server.
    ServerThrottle(ServerThrottleReason),
}
53
54impl FeMessage {
55    pub fn get_sql(&self) -> Result<Option<&str>> {
56        match self {
57            FeMessage::Query(q) => Ok(Some(q.get_sql()?)),
58            FeMessage::Parse(p) => Ok(Some(p.get_sql()?)),
59            _ => Ok(None),
60        }
61    }
62}
63
/// Why the server substituted [`FeMessage::ServerThrottle`] for a client message.
#[derive(Debug)]
pub enum ServerThrottleReason {
    // NOTE(review): presumably the incoming message exceeded a size limit — confirm at the
    // construction site.
    TooLargeMessage,
    // NOTE(review): presumably the server was over its memory budget — confirm at the
    // construction site.
    TooManyMemoryUsage,
}
69
/// Startup message (protocol code 196608) sent as the first regular packet of a connection.
#[derive(Debug)]
pub struct FeStartupMessage {
    /// Parameter name/value pairs from the startup payload (for example `user` or `options`),
    /// as built by `build_with_payload`.
    pub config: HashMap<String, String>,
}
74
75impl FeStartupMessage {
76    pub fn build_with_payload(payload: &[u8]) -> Result<Self> {
77        let config = match std::str::from_utf8(payload) {
78            Ok(v) => Ok(v.strip_suffix('\0').unwrap_or(v)),
79            Err(err) => Err(Error::new(
80                ErrorKind::InvalidInput,
81                anyhow!(err).context("Input end error"),
82            )),
83        }?;
84        let mut map = HashMap::new();
85        let config: Vec<&str> = config.split_terminator('\0').collect();
86        if config.len() % 2 == 1 {
87            return Err(Error::new(
88                ErrorKind::InvalidInput,
89                "Invalid input config: odd number of config pairs",
90            ));
91        }
92        config.chunks(2).for_each(|chunk| {
93            map.insert(chunk[0].to_owned(), chunk[1].to_owned());
94        });
95        Ok(FeStartupMessage { config: map })
96    }
97}
98
/// Simple Query (`'Q'`) message containing the SQL string.
#[derive(Debug)]
pub struct FeQueryMessage {
    /// The whole message body as received. Per the protocol this ends with a NUL terminator,
    /// which `get_sql` requires.
    pub sql_bytes: Bytes,
}
104
/// Bind (`'B'`) message: binds parameter values to a prepared statement, producing a portal.
#[derive(Debug)]
pub struct FeBindMessage {
    /// Parameter format codes exactly as sent by the client (not validated here).
    pub param_format_codes: Vec<i16>,
    /// Result-column format codes exactly as sent by the client (not validated here).
    pub result_format_codes: Vec<i16>,

    /// Parameter values. `None` represents NULL (wire length -1).
    pub params: Vec<Option<Bytes>>,
    /// Destination portal name, without its NUL terminator.
    pub portal_name: Bytes,
    /// Source prepared-statement name, without its NUL terminator.
    pub statement_name: Bytes,
}
114
/// Execute (`'E'`) message: runs a previously bound portal.
#[derive(Debug)]
pub struct FeExecuteMessage {
    /// Portal to execute, without its NUL terminator.
    pub portal_name: Bytes,
    /// Row limit sent by the client, stored as received.
    pub max_rows: i32,
}
120
/// Parse (`'P'`) message: creates a prepared statement from SQL text.
#[derive(Debug)]
pub struct FeParseMessage {
    /// Prepared-statement name (without its NUL when produced by `FeParseMessage::parse`).
    pub statement_name: Bytes,
    /// SQL text (without its NUL when produced by `FeParseMessage::parse`).
    pub sql_bytes: Bytes,
    /// Parameter type OIDs sent by the client.
    pub type_ids: Vec<i32>,
}
127
/// Password (`'p'`) message sent in reply to an authentication request.
#[derive(Debug)]
pub struct FePasswordMessage {
    /// Password bytes as sent (cleartext or MD5 form depending on the auth method), without
    /// the NUL terminator.
    pub password: Bytes,
}
132
/// Describe (`'D'`) message: asks for a description of a prepared statement or portal.
#[derive(Debug)]
pub struct FeDescribeMessage {
    /// `b'S'` to describe a prepared statement, or `b'P'` to describe a portal.
    pub kind: u8,
    /// Name of the statement or portal, without its NUL terminator.
    pub name: Bytes,
}
139
/// Close (`'C'`) message: closes a prepared statement or portal.
#[derive(Debug)]
pub struct FeCloseMessage {
    // NOTE(review): per the protocol this is b'S' (prepared statement) or b'P' (portal), as in
    // `FeDescribeMessage`; the value is not validated here.
    pub kind: u8,
    /// Name of the statement or portal, without its NUL terminator.
    pub name: Bytes,
}
145
/// Cancel request (startup code 80877102): asks to cancel the query running on another
/// session.
#[derive(Debug)]
pub struct FeCancelMessage {
    /// Process ID of the session whose query should be cancelled.
    pub target_process_id: i32,
    /// Secret key accompanying the process ID. Presumably checked against the key sent in
    /// `BeMessage::BackendKeyData`; confirm at the caller.
    pub target_secret_key: i32,
}
151
152impl FeCancelMessage {
153    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
154        let target_process_id = buf.get_i32();
155        let target_secret_key = buf.get_i32();
156        Ok(FeMessage::CancelQuery(Self {
157            target_process_id,
158            target_secret_key,
159        }))
160    }
161}
162
163impl FeDescribeMessage {
164    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
165        let kind = buf.get_u8();
166        let name = read_null_terminated(&mut buf)?;
167
168        Ok(FeMessage::Describe(FeDescribeMessage { kind, name }))
169    }
170}
171
172impl FeBindMessage {
173    // Bind Message Header
174    // +-----+-----------+
175    // | 'B' | int32 len |
176    // +-----+-----------+
177    // Bind Message Body
178    // +----------------+---------------+
179    // | str portalname | str statement |
180    // +----------------+---------------+
181    // +---------------------+------------------+-------+
182    // | int16 numFormatCode | int16 FormatCode |  ...  |
183    // +---------------------+------------------+-------+
184    // +-----------------+-------------------+---------------+
185    // | int16 numParams | int32 valueLength |  byte value.. |
186    // +-----------------+-------------------+---------------+
187    // +----------------------------------+------------------+-------+
188    // | int16 numResultColumnFormatCodes | int16 FormatCode |  ...  |
189    // +----------------------------------+------------------+-------+
190    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
191        let portal_name = read_null_terminated(&mut buf)?;
192        let statement_name = read_null_terminated(&mut buf)?;
193
194        let len = buf.get_i16();
195        let param_format_codes = (0..len).map(|_| buf.get_i16()).collect();
196
197        // Read Params
198        let len = buf.get_i16();
199        let params = (0..len)
200            .map(|_| {
201                let val_len = buf.get_i32();
202                if val_len == -1 {
203                    None
204                } else {
205                    Some(buf.copy_to_bytes(val_len as usize))
206                }
207            })
208            .collect();
209
210        let len = buf.get_i16();
211        let result_format_codes = (0..len).map(|_| buf.get_i16()).collect();
212
213        Ok(FeMessage::Bind(FeBindMessage {
214            param_format_codes,
215            result_format_codes,
216            params,
217            portal_name,
218            statement_name,
219        }))
220    }
221}
222
223impl FeExecuteMessage {
224    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
225        let portal_name = read_null_terminated(&mut buf)?;
226        let max_rows = buf.get_i32();
227
228        Ok(FeMessage::Execute(FeExecuteMessage {
229            portal_name,
230            max_rows,
231        }))
232    }
233}
234
235impl FeParseMessage {
236    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
237        let statement_name = read_null_terminated(&mut buf)?;
238        let sql_bytes = read_null_terminated(&mut buf)?;
239        let nparams = buf.get_i16();
240
241        let type_ids: Vec<i32> = (0..nparams).map(|_| buf.get_i32()).collect();
242
243        Ok(FeMessage::Parse(FeParseMessage {
244            statement_name,
245            sql_bytes,
246            type_ids,
247        }))
248    }
249}
250
251impl FePasswordMessage {
252    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
253        let password = read_null_terminated(&mut buf)?;
254
255        Ok(FeMessage::Password(FePasswordMessage { password }))
256    }
257}
258
259impl FeQueryMessage {
260    pub fn get_sql(&self) -> Result<&str> {
261        get_sql_from_bytes(&self.sql_bytes)
262    }
263}
264
265impl FeParseMessage {
266    pub fn get_sql(&self) -> Result<&str> {
267        get_sql_from_bytes(&self.sql_bytes)
268    }
269}
270
271fn get_sql_from_bytes(sql_bytes: &[u8]) -> Result<&str> {
272    match CStr::from_bytes_with_nul(sql_bytes) {
273        Ok(cstr) => cstr.to_str().map_err(|err| {
274            Error::new(
275                ErrorKind::InvalidInput,
276                anyhow!(err).context("Invalid UTF-8 sequence"),
277            )
278        }),
279        Err(err) => Err(Error::new(
280            ErrorKind::InvalidInput,
281            anyhow!(err).context("Input end error"),
282        )),
283    }
284}
285
286impl FeCloseMessage {
287    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
288        let kind = buf.get_u8();
289        let name = read_null_terminated(&mut buf)?;
290        Ok(FeMessage::Close(FeCloseMessage { kind, name }))
291    }
292}
293
/// Header of a regular (post-startup) message, as returned by `FeMessage::read_header`.
#[derive(Clone)]
pub struct FeMessageHeader {
    /// One-byte message type code, e.g. `b'Q'` for Query.
    pub tag: u8,
    /// Body length in bytes: the wire length minus the 4 bytes of the length field itself.
    pub payload_len: i32,
}
299
300impl FeMessage {
301    /// Read one message from the stream.
302    pub async fn read_header(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessageHeader> {
303        let tag = stream.read_u8().await?;
304        let len = stream.read_i32().await?;
305
306        let payload_len = len - 4;
307        Ok(FeMessageHeader { tag, payload_len })
308    }
309
310    /// Read one message from the stream.
311    pub async fn read_body(
312        stream: &mut (impl AsyncRead + Unpin),
313        header: FeMessageHeader,
314    ) -> Result<FeMessage> {
315        let FeMessageHeader { tag, payload_len } = header;
316        let mut payload: Vec<u8> = vec![0; payload_len as usize];
317        if payload_len > 0 {
318            stream.read_exact(&mut payload).await?;
319        }
320        let sql_bytes = Bytes::from(payload);
321        match tag {
322            b'Q' => Ok(FeMessage::Query(FeQueryMessage { sql_bytes })),
323            b'P' => FeParseMessage::parse(sql_bytes),
324            b'D' => FeDescribeMessage::parse(sql_bytes),
325            b'B' => FeBindMessage::parse(sql_bytes),
326            b'E' => FeExecuteMessage::parse(sql_bytes),
327            b'S' => Ok(FeMessage::Sync),
328            b'X' => Ok(FeMessage::Terminate),
329            b'C' => FeCloseMessage::parse(sql_bytes),
330            b'p' => FePasswordMessage::parse(sql_bytes),
331            b'H' => Ok(FeMessage::Flush),
332            _ => Err(std::io::Error::new(
333                ErrorKind::InvalidInput,
334                format!("Unsupported tag of regular message: {}", tag),
335            )),
336        }
337    }
338
339    pub async fn skip_body(
340        stream: &mut (impl AsyncRead + Unpin),
341        header: FeMessageHeader,
342    ) -> Result<()> {
343        let FeMessageHeader {
344            tag: _,
345            payload_len,
346        } = header;
347
348        if payload_len > 0 {
349            // Use smaller batches to process the payload instead of handling it all at once to minimize memory usage.
350            const BUF_SIZE: usize = 1024;
351            let mut buf: Vec<u8> = vec![0; BUF_SIZE];
352            for _ in 0..(payload_len as usize) / BUF_SIZE {
353                stream.read_exact(&mut buf).await?;
354            }
355            let remain = (payload_len as usize) % BUF_SIZE;
356            if remain > 0 {
357                buf.truncate(remain);
358                stream.read_exact(&mut buf).await?;
359            }
360        }
361        Ok(())
362    }
363}
364
365impl FeStartupMessage {
366    /// Read startup message from the stream.
367    pub async fn read(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessage> {
368        let mut stream = AsyncPeekable::new(stream);
369
370        if let Err(err) = stream.peek_exact(&mut [0; 1]).await {
371            // If the stream is empty, it can be a health check. Do not return error.
372            if err.kind() == ErrorKind::UnexpectedEof {
373                return Ok(FeMessage::HealthCheck);
374            } else {
375                return Err(err);
376            }
377        }
378
379        let len = stream.read_i32().await?;
380        let protocol_num = stream.read_i32().await?;
381        let payload_len = (len - 8) as usize;
382        if payload_len >= isize::MAX as usize {
383            return Err(std::io::Error::new(
384                ErrorKind::InvalidInput,
385                format!("Payload length has exceed usize::MAX {:?}", payload_len),
386            ));
387        }
388        let mut payload = vec![0; payload_len];
389        if payload_len > 0 {
390            stream.read_exact(&mut payload).await?;
391        }
392        match protocol_num {
393            // code from: https://www.postgresql.org/docs/current/protocol-message-formats.html
394            196608 => Ok(FeMessage::Startup(FeStartupMessage::build_with_payload(
395                &payload,
396            )?)),
397            80877104 => Ok(FeMessage::Gss),
398            80877103 => Ok(FeMessage::Ssl),
399            // Cancel request code.
400            80877102 => FeCancelMessage::parse(Bytes::from(payload)),
401            _ => Err(std::io::Error::new(
402                ErrorKind::InvalidInput,
403                format!(
404                    "Unsupported protocol number in start up msg {:?}",
405                    protocol_num
406                ),
407            )),
408        }
409    }
410}
411
412/// Continue read until reached a \0. Used in reading string from Bytes.
413fn read_null_terminated(buf: &mut Bytes) -> Result<Bytes> {
414    let mut result = BytesMut::new();
415
416    loop {
417        if !buf.has_remaining() {
418            panic!("no null-terminator in string");
419        }
420
421        let byte = buf.get_u8();
422
423        if byte == 0 {
424            break;
425        }
426        result.put_u8(byte);
427    }
428    Ok(result.freeze())
429}
430
/// Message sent from the server to a pg client. Serialize it into an output buffer with
/// [`BeMessage::write`].
/// Ref: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Debug, Clone, Copy)]
pub enum BeMessage<'a> {
    /// Authentication succeeded (`'R'`, code 0).
    AuthenticationOk,
    /// Request a cleartext password (`'R'`, code 3).
    AuthenticationCleartextPassword,
    /// Request an MD5 password (`'R'`, code 5) using the given 4-byte salt.
    AuthenticationMd5Password(&'a [u8; 4]),
    CommandComplete(BeCommandCompleteMessage),
    /// Notice text, wrapped via `ErrorOrNoticeMessage::notice` when written.
    NoticeResponse(&'a str),
    // Single byte - used in response to SSLRequest/GSSENCRequest.
    EncryptionResponseSsl,
    EncryptionResponseGss,
    EncryptionResponseNo,
    EmptyQueryResponse,
    ParseComplete,
    BindComplete,
    PortalSuspended,
    // array of parameter oid(i32)
    ParameterDescription(&'a [i32]),
    NoData,
    DataRow(&'a Row),
    ParameterStatus(BeParameterStatusMessage<'a>),
    ReadyForQuery(TransactionStatus),
    RowDescription(&'a [PgFieldDescriptor]),
    /// Error report; `severity` overrides the default severity when `Some`.
    ErrorResponse {
        error: &'a (dyn std::error::Error + Send + Sync + 'static),
        pretty: bool,
        severity: Option<Severity>,
    },
    CloseComplete,

    // Copy
    /// CopyOutResponse announcing the given number of text-format columns.
    CopyOutResponse(usize),
    /// One row of COPY output, encoded as a tab-separated line.
    CopyData(&'a Row),
    CopyDone,

    // 0: process ID, 1: secret key
    BackendKeyData((i32, i32)),
}
471
/// A run-time parameter reported to the client via ParameterStatus (`'S'`).
///
/// Each variant carries the parameter's value. `BeMessage::write` maps the variants to the
/// names `client_encoding`, `standard_conforming_strings`, `server_version`,
/// `application_name` and `TimeZone`, respectively.
#[derive(Debug, Copy, Clone)]
pub enum BeParameterStatusMessage<'a> {
    ClientEncoding(&'a str),
    StandardConformingString(&'a str),
    ServerVersion(&'a str),
    ApplicationName(&'a str),
    TimeZone(&'a str),
}
480
/// Payload of a CommandComplete (`'C'`) message.
#[derive(Debug, Copy, Clone)]
pub struct BeCommandCompleteMessage {
    /// Statement kind used to build the command tag (`*_RETURNING` kinds are reported as
    /// their base kind).
    pub stmt_type: StatementType,
    /// Affected/returned row count, appended to the tag when `stmt_type.is_command()`.
    pub rows_cnt: i32,
}
486
/// Transaction state reported in ReadyForQuery (`'Z'`).
#[derive(Debug, Clone, Copy)]
pub enum TransactionStatus {
    /// Not in a transaction block; written as `'I'`.
    Idle,
    /// Inside a transaction block; written as `'T'`.
    InTransaction,
    /// Inside a failed transaction block; written as `'E'`.
    InFailedTransaction,
}
493
impl BeMessage<'_> {
    /// Serializes `message` in PostgreSQL wire format and appends it to `buf`.
    ///
    /// Every message starts with its one-byte type code. Most are followed by an int32
    /// length that counts itself (computed by `write_body` or written literally). The
    /// encryption responses are a single byte.
    ///
    /// # Errors
    /// Propagates errors from `write_body`, `write_cstr` (e.g. a string with an embedded NUL)
    /// and `write_err_or_notice`.
    pub fn write(buf: &mut BytesMut, message: BeMessage<'_>) -> Result<()> {
        match message {
            // AuthenticationOk
            // +-----+----------+-----------+
            // | 'R' | int32(8) | int32(0)  |
            // +-----+----------+-----------+
            BeMessage::AuthenticationOk => {
                buf.put_u8(b'R');
                buf.put_i32(8);
                buf.put_i32(0);
            }

            // AuthenticationCleartextPassword
            // +-----+----------+-----------+
            // | 'R' | int32(8) | int32(3)  |
            // +-----+----------+-----------+
            BeMessage::AuthenticationCleartextPassword => {
                buf.put_u8(b'R');
                buf.put_i32(8);
                buf.put_i32(3);
            }

            // AuthenticationMD5Password
            // +-----+-----------+-----------+----------------+
            // | 'R' | int32(12) | int32(5)  |  Byte4(salt)   |
            // +-----+-----------+-----------+----------------+
            //
            // The 4-byte random salt will be used by client to send encrypted password as
            // concat('md5', md5(concat(md5(concat(password, username)), random-salt))).
            BeMessage::AuthenticationMd5Password(salt) => {
                buf.put_u8(b'R');
                buf.put_i32(12);
                buf.put_i32(5);
                buf.put_slice(&salt[..]);
            }

            // ParameterStatus
            // +-----+-----------+----------+------+-----------+------+
            // | 'S' | int32 len | str name | '\0' | str value | '\0' |
            // +-----+-----------+----------+------+-----------+------+
            BeMessage::ParameterStatus(param) => {
                use BeParameterStatusMessage::*;
                let [name, value] = match param {
                    ClientEncoding(val) => [b"client_encoding", val.as_bytes()],
                    StandardConformingString(val) => {
                        [b"standard_conforming_strings", val.as_bytes()]
                    }
                    ServerVersion(val) => [b"server_version", val.as_bytes()],
                    ApplicationName(val) => [b"application_name", val.as_bytes()],
                    // psycopg3 is case-sensitive, so we use "TimeZone" instead of "timezone" #18079
                    TimeZone(val) => [b"TimeZone", val.as_bytes()],
                };

                // Parameter names and values are passed as null-terminated strings.
                // `Vec`'s `write_vectored` appends every slice, so `cnt` is the total length
                // and the `unwrap` cannot fail.
                let iov = &mut [name, b"\0", value, b"\0"].map(IoSlice::new);
                let mut buffer = vec![];
                let cnt = buffer.write_vectored(iov).unwrap();

                buf.put_u8(b'S');
                write_body(buf, |stream| {
                    stream.put_slice(&buffer[..cnt]);
                    Ok(())
                })
                .unwrap();
            }

            // CommandComplete
            // +-----+-----------+-----------------+
            // | 'C' | int32 len | str commandTag  |
            // +-----+-----------+-----------------+
            BeMessage::CommandComplete(cmd) => {
                let rows_cnt = cmd.rows_cnt;
                let mut stmt_type = cmd.stmt_type;
                let mut tag = "".to_owned();
                // `*_RETURNING` statements report the same tag as their plain counterparts.
                stmt_type = match stmt_type {
                    StatementType::INSERT_RETURNING => StatementType::INSERT,
                    StatementType::DELETE_RETURNING => StatementType::DELETE,
                    StatementType::UPDATE_RETURNING => StatementType::UPDATE,
                    s => s,
                };
                tag.push_str(&stmt_type.to_string());
                // The INSERT tag has an extra field (the OID slot in PostgreSQL), always 0 here.
                if stmt_type == StatementType::INSERT {
                    tag.push_str(" 0");
                }
                if stmt_type.is_command() {
                    tag.push(' ');
                    tag.push_str(&rows_cnt.to_string());
                }
                buf.put_u8(b'C');
                write_body(buf, |buf| {
                    write_cstr(buf, tag.as_bytes())?;
                    Ok(())
                })?;
            }

            // NoticeResponse
            // +-----+-----------+------------------+------------------+
            // | 'N' | int32 len | byte1 field type | str field value  |
            // +-----+-----------+------------------+------------------+
            // description of the fields can be found here:
            // https://www.postgresql.org/docs/current/protocol-error-fields.html
            BeMessage::NoticeResponse(notice) => {
                buf.put_u8(b'N');
                write_err_or_notice(buf, &ErrorOrNoticeMessage::notice(notice))?;
            }

            // DataRow
            // +-----+-----------+--------------+--------+-----+--------+
            // | 'D' | int32 len | int16 colNum | column | ... | column |
            // +-----+-----------+--------------+----+---+-----+--------+
            //                                       |
            //                          +-----------+v------+
            //                          | int32 len | bytes |
            //                          +-----------+-------+
            // A NULL column is written as length -1 with no bytes.
            BeMessage::DataRow(vals) => {
                buf.put_u8(b'D');
                write_body(buf, |buf| {
                    buf.put_u16(vals.len() as u16); // num of cols
                    for val_opt in vals.values() {
                        if let Some(val) = val_opt {
                            buf.put_u32(val.len() as u32);
                            buf.put_slice(val);
                        } else {
                            buf.put_i32(-1);
                        }
                    }
                    Ok(())
                })
                .unwrap();
            }

            // RowDescription
            // +-----+-----------+--------------+-------+-----+-------+
            // | 'T' | int32 len | int16 colNum | field | ... | field |
            // +-----+-----------+--------------+----+--+-----+-------+
            //                                       |
            // +---------------+-------+-------+-----v-+-------+-------+-------+
            // | str fieldName | int32 | int16 | int32 | int16 | int32 | int16 |
            // +---------------+---+---+---+---+---+---+----+--+---+---+---+---+
            //                     |       |       |        |      |       |
            //                     v       |       v        v      |       v
            //                tableOID     |    typeOID  typeLen   |   formatCode
            //                             v                       v
            //                        colAttrNum               typeModifier
            BeMessage::RowDescription(row_descs) => {
                buf.put_u8(b'T');
                write_body(buf, |buf| {
                    buf.put_i16(row_descs.len() as i16); // # of fields
                    for pg_field in row_descs {
                        write_cstr(buf, pg_field.get_name().as_bytes())?;
                        buf.put_i32(pg_field.get_table_oid()); // table oid
                        buf.put_i16(pg_field.get_col_attr_num()); // attnum
                        buf.put_i32(pg_field.get_type_oid());
                        buf.put_i16(pg_field.get_type_len());
                        buf.put_i32(pg_field.get_type_modifier()); // typmod
                        buf.put_i16(pg_field.get_format_code()); // format code
                    }
                    Ok(())
                })?;
            }
            // ReadyForQuery
            // +-----+----------+---------------------------+
            // | 'Z' | int32(5) | byte1(transaction status) |
            // +-----+----------+---------------------------+
            BeMessage::ReadyForQuery(txn_status) => {
                buf.put_u8(b'Z');
                buf.put_i32(5);
                // Transaction status indicator: 'I' idle, 'T' in a transaction block,
                // 'E' in a failed transaction block.
                buf.put_u8(match txn_status {
                    TransactionStatus::Idle => b'I',
                    TransactionStatus::InTransaction => b'T',
                    TransactionStatus::InFailedTransaction => b'E',
                });
            }

            // The following four messages have an empty body (length field only).
            BeMessage::ParseComplete => {
                buf.put_u8(b'1');
                write_body(buf, |_| Ok(()))?;
            }

            BeMessage::BindComplete => {
                buf.put_u8(b'2');
                write_body(buf, |_| Ok(()))?;
            }

            BeMessage::CloseComplete => {
                buf.put_u8(b'3');
                write_body(buf, |_| Ok(()))?;
            }

            BeMessage::PortalSuspended => {
                buf.put_u8(b's');
                write_body(buf, |_| Ok(()))?;
            }
            // ParameterDescription
            // +-----+-----------+--------------------+---------------+-----+---------------+
            // | 't' | int32 len | int16 ParameterNum | int32 typeOID | ... | int32 typeOID |
            // +-----+-----------+-----------------+--+---------------+-----+---------------+
            BeMessage::ParameterDescription(para_descs) => {
                buf.put_u8(b't');
                write_body(buf, |buf| {
                    buf.put_i16(para_descs.len() as i16);
                    for oid in para_descs {
                        buf.put_i32(*oid);
                    }
                    Ok(())
                })?;
            }

            BeMessage::NoData => {
                buf.put_u8(b'n');
                write_body(buf, |_| Ok(())).unwrap();
            }

            // Encryption negotiation replies are a single byte with no length field.
            BeMessage::EncryptionResponseSsl => {
                buf.put_u8(b'S');
            }

            BeMessage::EncryptionResponseGss => {
                buf.put_u8(b'G');
            }

            BeMessage::EncryptionResponseNo => {
                buf.put_u8(b'N');
            }

            // EmptyQueryResponse
            // +-----+----------+
            // | 'I' | int32(4) |
            // +-----+----------+
            BeMessage::EmptyQueryResponse => {
                buf.put_u8(b'I');
                buf.put_i32(4);
            }

            BeMessage::ErrorResponse {
                error,
                pretty,
                severity,
            } => {
                // 'E' signalizes ErrorResponse messages
                buf.put_u8(b'E');
                // Format the error as a pretty report.
                let error_message = match severity {
                    Some(severity) => {
                        ErrorOrNoticeMessage::error_with_severity(error, pretty, severity)
                    }
                    None => ErrorOrNoticeMessage::error(error, pretty),
                };
                write_err_or_notice(buf, &error_message)?;
            }

            // BackendKeyData: int32 process ID followed by int32 secret key.
            BeMessage::BackendKeyData((process_id, secret_key)) => {
                buf.put_u8(b'K');
                write_body(buf, |buf| {
                    buf.put_i32(process_id);
                    buf.put_i32(secret_key);
                    Ok(())
                })?;
            }
            // CopyOutResponse: int8 overall format, int16 column count, then one int16 format
            // code per column. Text format is used throughout.
            BeMessage::CopyOutResponse(col_num) => {
                buf.put_u8(b'H');
                write_body(buf, |buf| {
                    buf.put_i8(Format::Text.to_i8());
                    buf.put_i16(col_num as _);
                    for _ in 0..col_num {
                        buf.put_i16(Format::Text.to_i8() as _);
                    }
                    Ok(())
                })?;
            }
            BeMessage::CopyData(row) => {
                buf.put_u8(b'd');
                // As in https://www.postgresql.org/docs/current/sql-copy.html, the default format is TSV format
                // NOTE(review): a NULL value is written as an empty field here, whereas
                // PostgreSQL's text COPY format writes `\N` by default — confirm clients expect
                // this.
                write_body(buf, |buf| {
                    fn write_str_bytes(
                        buf: &mut BytesMut,
                        str_bytes: &Option<Bytes>,
                    ) -> Result<()> {
                        let Some(str_bytes) = str_bytes else {
                            return Ok(());
                        };
                        let s = String::from_utf8_lossy(str_bytes);
                        for c in s.as_str().chars() {
                            // As suggested in https://en.wikipedia.org/wiki/Tab-separated_values
                            // we only escape "\t\n\r\\"
                            match c {
                                '\t' => {
                                    buf.put_slice(b"\\t");
                                }
                                '\n' => {
                                    buf.put_slice(b"\\n");
                                }
                                '\r' => {
                                    buf.put_slice(b"\\r");
                                }
                                '\\' => {
                                    buf.put_slice(b"\\\\");
                                }
                                _ => {
                                    std::fmt::Write::write_char(buf, c).map_err(|_| {
                                        Error::other(anyhow!("failed to write_char [{c}]"))
                                    })?;
                                }
                            }
                        }
                        Ok(())
                    }
                    // Fields are separated by tabs; the row ends with a newline.
                    match row.values() {
                        [] => {}
                        [first, rest @ ..] => {
                            write_str_bytes(buf, first)?;

                            for rest in rest {
                                buf.put_u8(b'\t');
                                write_str_bytes(buf, rest)?;
                            }
                        }
                    }
                    buf.put_u8(b'\n');
                    Ok(())
                })?;
            }
            BeMessage::CopyDone => {
                buf.put_u8(b'c');
                write_body(buf, |_| Ok(()))?;
            }
        }

        Ok(())
    }
}
828
// Safe usize -> i32 conversion (pattern borrowed from rust-postgres); only `i32` is
// instantiated below.
/// Checked narrowing from `usize`, used when encoding length fields.
trait FromUsize: Sized {
    /// Converts `x`. Fails with `InvalidInput` ("value too large to transmit") when `x`
    /// exceeds `Self::MAX`.
    fn from_usize(x: usize) -> Result<Self>;
}

/// Implements [`FromUsize`] for an integer type by comparing against the type's `MAX`.
macro_rules! from_usize {
    ($target:ty) => {
        impl FromUsize for $target {
            #[inline]
            fn from_usize(x: usize) -> Result<$target> {
                if x <= <$target>::MAX as usize {
                    Ok(x as $target)
                } else {
                    Err(Error::new(
                        ErrorKind::InvalidInput,
                        "value too large to transmit",
                    ))
                }
            }
        }
    };
}

from_usize!(i32);
850
851/// Call f() to write body of the message and prepend it with 4-byte len as
852/// prescribed by the protocol. First write out body value and fill length value as i32 in front of
853/// it.
854fn write_body<F>(buf: &mut BytesMut, f: F) -> Result<()>
855where
856    F: FnOnce(&mut BytesMut) -> Result<()>,
857{
858    let base = buf.len();
859    buf.extend_from_slice(&[0; 4]);
860
861    f(buf)?;
862
863    let size = i32::from_usize(buf.len() - base)?;
864    BigEndian::write_i32(&mut buf[base..], size);
865    Ok(())
866}
867
868/// Safe write of s into buf as cstring (String in the protocol).
869fn write_cstr(buf: &mut BytesMut, s: &[u8]) -> Result<()> {
870    if s.contains(&0) {
871        return Err(Error::new(
872            ErrorKind::InvalidInput,
873            "string contains embedded null",
874        ));
875    }
876    buf.put_slice(s);
877    buf.put_u8(0);
878    Ok(())
879}
880
881/// Safe write error or notice message.
882fn write_err_or_notice(buf: &mut BytesMut, msg: &ErrorOrNoticeMessage<'_>) -> Result<()> {
883    write_body(buf, |buf| {
884        buf.put_u8(b'S'); // severity
885        write_cstr(buf, msg.severity.as_str().as_bytes())?;
886
887        buf.put_u8(b'C'); // SQLSTATE error code
888        write_cstr(buf, msg.error_code.sqlstate().as_bytes())?;
889
890        buf.put_u8(b'M'); // the message
891        write_cstr(buf, msg.message.as_bytes())?;
892
893        buf.put_u8(0); // terminator
894        Ok(())
895    })
896}
897
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use crate::pg_message::{FeParseMessage, FeQueryMessage, FeStartupMessage};

    #[test]
    fn test_get_sql() {
        // NUL-terminated, but not valid UTF-8.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![255, 255, 255, 255, 255, 255, 0]),
        };
        assert!(fe.get_sql().is_err(), "invalid UTF-8 must be rejected");
        // Missing the NUL terminator.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![1, 2, 3, 4, 5, 6, 7, 8]),
        };
        assert!(
            fe.get_sql().is_err(),
            "query without a NUL terminator must be rejected"
        );

        let fe = FeParseMessage {
            statement_name: Bytes::from_static(b"stmt\0"),
            sql_bytes: Bytes::from_static(b"select 1\0"),
            type_ids: vec![],
        };
        assert_eq!(fe.get_sql().unwrap(), "select 1");
    }

    #[test]
    fn test_startup_build() {
        let payload = b"user\0dev\0options\0\0\0";
        let msg = FeStartupMessage::build_with_payload(payload).unwrap();
        assert_eq!(msg.config.get("user").unwrap(), "dev");
        assert_eq!(msg.config.get("options").unwrap(), "");
        assert_eq!(msg.config.len(), 2);
    }

    #[test]
    fn test_startup_build_rejects_odd_pairs() {
        // "options" has no value before the list terminator.
        let payload = b"user\0dev\0options\0\0";
        assert!(FeStartupMessage::build_with_payload(payload).is_err());
    }

    #[test]
    fn test_startup_build_rejects_invalid_utf8() {
        let payload = [0xff, 0, b'a', 0, 0];
        assert!(FeStartupMessage::build_with_payload(&payload).is_err());
    }
}