pgwire/
pg_message.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::ffi::CStr;
17use std::io::{Error, ErrorKind, IoSlice, Result, Write};
18
19use anyhow::anyhow;
20use byteorder::{BigEndian, ByteOrder};
21/// Part of code learned from <https://github.com/zenithdb/zenith/blob/main/zenith_utils/src/pq_proto.rs>.
22use bytes::{Buf, BufMut, Bytes, BytesMut};
23use peekable::tokio::AsyncPeekable;
24use tokio::io::{AsyncRead, AsyncReadExt};
25
26use crate::error_or_notice::ErrorOrNoticeMessage;
27use crate::pg_field_descriptor::PgFieldDescriptor;
28use crate::pg_response::StatementType;
29use crate::types::{Format, Row};
30
/// Messages that can be sent from a pg client to the server.
///
/// Regular (tagged) messages are decoded by [`FeMessage::read_header`] followed by
/// [`FeMessage::read_body`]; the untagged first message of a connection is decoded by
/// [`FeStartupMessage::read`].
#[derive(Debug)]
pub enum FeMessage {
    /// SSLRequest (startup request code 80877103).
    Ssl,
    /// GSSENCRequest (startup request code 80877104).
    Gss,
    /// StartupMessage (protocol 196608) carrying the client's connection parameters.
    Startup(FeStartupMessage),
    /// Simple query ('Q').
    Query(FeQueryMessage),
    /// Parse ('P').
    Parse(FeParseMessage),
    /// PasswordMessage ('p').
    Password(FePasswordMessage),
    /// Describe ('D').
    Describe(FeDescribeMessage),
    /// Bind ('B').
    Bind(FeBindMessage),
    /// Execute ('E').
    Execute(FeExecuteMessage),
    /// Close ('C').
    Close(FeCloseMessage),
    /// Sync ('S').
    Sync,
    /// CancelRequest (startup request code 80877102).
    CancelQuery(FeCancelMessage),
    /// Terminate ('X').
    Terminate,
    /// Flush ('H').
    Flush,
    /// Special message used to detect a health check: the client closed the connection
    /// cleanly without sending any data.
    HealthCheck,
    /// Placeholder generated by the server when the original message was rejected due to
    /// server throttling.
    ServerThrottle(ServerThrottleReason),
}
53
54impl FeMessage {
55    pub fn get_sql(&self) -> Result<Option<&str>> {
56        match self {
57            FeMessage::Query(q) => Ok(Some(q.get_sql()?)),
58            _ => Ok(None),
59        }
60    }
61}
62
/// Why the server replaced a client message with [`FeMessage::ServerThrottle`].
#[derive(Debug)]
pub enum ServerThrottleReason {
    /// The message was rejected for being too large.
    TooLargeMessage,
    /// The message was rejected because of too much memory usage.
    TooManyMemoryUsage,
}
68
/// Parsed StartupMessage: the connection parameters sent by the client.
#[derive(Debug)]
pub struct FeStartupMessage {
    /// Parameter name -> value, as decoded by [`FeStartupMessage::build_with_payload`].
    pub config: HashMap<String, String>,
}
73
74impl FeStartupMessage {
75    pub fn build_with_payload(payload: &[u8]) -> Result<Self> {
76        let config = match std::str::from_utf8(payload) {
77            Ok(v) => Ok(v.trim_end_matches('\0')),
78            Err(err) => Err(Error::new(
79                ErrorKind::InvalidInput,
80                anyhow!(err).context("Input end error"),
81            )),
82        }?;
83        let mut map = HashMap::new();
84        let config: Vec<&str> = config.split('\0').collect();
85        if config.len() % 2 == 1 {
86            return Err(Error::new(
87                ErrorKind::InvalidInput,
88                "Invalid input config: odd number of config pairs",
89            ));
90        }
91        config.chunks(2).for_each(|chunk| {
92            map.insert(chunk[0].to_owned(), chunk[1].to_owned());
93        });
94        Ok(FeStartupMessage { config: map })
95    }
96}
97
/// Simple-query ('Q') message containing the SQL string.
#[derive(Debug)]
pub struct FeQueryMessage {
    /// Raw message body, *including* the trailing NUL terminator. Use
    /// [`FeQueryMessage::get_sql`] to obtain it as `&str`.
    pub sql_bytes: Bytes,
}
103
/// Bind ('B') message of the extended-query protocol.
#[derive(Debug)]
pub struct FeBindMessage {
    /// Parameter format codes, as sent by the client.
    pub param_format_codes: Vec<i16>,
    /// Format codes requested for the result columns.
    pub result_format_codes: Vec<i16>,

    /// Parameter values; `None` means SQL NULL (wire length -1).
    pub params: Vec<Option<Bytes>>,
    /// Destination portal name, without its NUL terminator.
    pub portal_name: Bytes,
    /// Source prepared-statement name, without its NUL terminator.
    pub statement_name: Bytes,
}
113
/// Execute ('E') message.
#[derive(Debug)]
pub struct FeExecuteMessage {
    /// Name of the portal to execute, without its NUL terminator.
    pub portal_name: Bytes,
    /// Row limit as sent by the client.
    /// NOTE(review): in the PG protocol 0 means "no limit" — confirm the executor agrees.
    pub max_rows: i32,
}
119
/// Parse ('P') message.
#[derive(Debug)]
pub struct FeParseMessage {
    /// Prepared-statement name, without its NUL terminator.
    pub statement_name: Bytes,
    /// Query text, without its NUL terminator (unlike `FeQueryMessage::sql_bytes`).
    pub sql_bytes: Bytes,
    /// Parameter type OIDs supplied by the client, in order.
    pub type_ids: Vec<i32>,
}
126
/// PasswordMessage ('p').
#[derive(Debug)]
pub struct FePasswordMessage {
    /// Password bytes up to, but not including, the first NUL.
    pub password: Bytes,
}
131
/// Describe ('D') message.
#[derive(Debug)]
pub struct FeDescribeMessage {
    /// `b'S'` to describe a prepared statement, or `b'P'` to describe a portal.
    pub kind: u8,
    /// Name of the statement or portal, without its NUL terminator.
    pub name: Bytes,
}
138
/// Close ('C') message.
#[derive(Debug)]
pub struct FeCloseMessage {
    /// Target kind byte; presumably `b'S'` (statement) or `b'P'` (portal) as for Describe —
    /// confirm against the handler.
    pub kind: u8,
    /// Name of the object to close, without its NUL terminator.
    pub name: Bytes,
}
144
/// CancelRequest, sent as the first message of a new connection (request code 80877102).
#[derive(Debug)]
pub struct FeCancelMessage {
    /// Process ID of the connection whose query should be cancelled.
    pub target_process_id: i32,
    /// Secret key; presumably compared against the key sent in `BackendKeyData`.
    pub target_secret_key: i32,
}
150
151impl FeCancelMessage {
152    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
153        let target_process_id = buf.get_i32();
154        let target_secret_key = buf.get_i32();
155        Ok(FeMessage::CancelQuery(Self {
156            target_process_id,
157            target_secret_key,
158        }))
159    }
160}
161
162impl FeDescribeMessage {
163    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
164        let kind = buf.get_u8();
165        let name = read_null_terminated(&mut buf)?;
166
167        Ok(FeMessage::Describe(FeDescribeMessage { kind, name }))
168    }
169}
170
171impl FeBindMessage {
172    // Bind Message Header
173    // +-----+-----------+
174    // | 'B' | int32 len |
175    // +-----+-----------+
176    // Bind Message Body
177    // +----------------+---------------+
178    // | str portalname | str statement |
179    // +----------------+---------------+
180    // +---------------------+------------------+-------+
181    // | int16 numFormatCode | int16 FormatCode |  ...  |
182    // +---------------------+------------------+-------+
183    // +-----------------+-------------------+---------------+
184    // | int16 numParams | int32 valueLength |  byte value.. |
185    // +-----------------+-------------------+---------------+
186    // +----------------------------------+------------------+-------+
187    // | int16 numResultColumnFormatCodes | int16 FormatCode |  ...  |
188    // +----------------------------------+------------------+-------+
189    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
190        let portal_name = read_null_terminated(&mut buf)?;
191        let statement_name = read_null_terminated(&mut buf)?;
192
193        let len = buf.get_i16();
194        let param_format_codes = (0..len).map(|_| buf.get_i16()).collect();
195
196        // Read Params
197        let len = buf.get_i16();
198        let params = (0..len)
199            .map(|_| {
200                let val_len = buf.get_i32();
201                if val_len == -1 {
202                    None
203                } else {
204                    Some(buf.copy_to_bytes(val_len as usize))
205                }
206            })
207            .collect();
208
209        let len = buf.get_i16();
210        let result_format_codes = (0..len).map(|_| buf.get_i16()).collect();
211
212        Ok(FeMessage::Bind(FeBindMessage {
213            param_format_codes,
214            result_format_codes,
215            params,
216            portal_name,
217            statement_name,
218        }))
219    }
220}
221
222impl FeExecuteMessage {
223    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
224        let portal_name = read_null_terminated(&mut buf)?;
225        let max_rows = buf.get_i32();
226
227        Ok(FeMessage::Execute(FeExecuteMessage {
228            portal_name,
229            max_rows,
230        }))
231    }
232}
233
234impl FeParseMessage {
235    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
236        let statement_name = read_null_terminated(&mut buf)?;
237        let sql_bytes = read_null_terminated(&mut buf)?;
238        let nparams = buf.get_i16();
239
240        let type_ids: Vec<i32> = (0..nparams).map(|_| buf.get_i32()).collect();
241
242        Ok(FeMessage::Parse(FeParseMessage {
243            statement_name,
244            sql_bytes,
245            type_ids,
246        }))
247    }
248}
249
250impl FePasswordMessage {
251    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
252        let password = read_null_terminated(&mut buf)?;
253
254        Ok(FeMessage::Password(FePasswordMessage { password }))
255    }
256}
257
258impl FeQueryMessage {
259    pub fn get_sql(&self) -> Result<&str> {
260        match CStr::from_bytes_with_nul(&self.sql_bytes) {
261            Ok(cstr) => cstr.to_str().map_err(|err| {
262                Error::new(
263                    ErrorKind::InvalidInput,
264                    anyhow!(err).context("Invalid UTF-8 sequence"),
265                )
266            }),
267            Err(err) => Err(Error::new(
268                ErrorKind::InvalidInput,
269                anyhow!(err).context("Input end error"),
270            )),
271        }
272    }
273}
274
275impl FeCloseMessage {
276    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
277        let kind = buf.get_u8();
278        let name = read_null_terminated(&mut buf)?;
279        Ok(FeMessage::Close(FeCloseMessage { kind, name }))
280    }
281}
282
/// Tag and length of a regular (post-startup) frontend message, as produced by
/// [`FeMessage::read_header`].
#[derive(Clone)]
pub struct FeMessageHeader {
    /// Single-byte message type, e.g. `b'Q'` for a simple query.
    pub tag: u8,
    /// Body length in bytes: the wire length minus the 4 bytes of the length field itself.
    pub payload_len: i32,
}
288
289impl FeMessage {
290    /// Read one message from the stream.
291    pub async fn read_header(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessageHeader> {
292        let tag = stream.read_u8().await?;
293        let len = stream.read_i32().await?;
294
295        let payload_len = len - 4;
296        Ok(FeMessageHeader { tag, payload_len })
297    }
298
299    /// Read one message from the stream.
300    pub async fn read_body(
301        stream: &mut (impl AsyncRead + Unpin),
302        header: FeMessageHeader,
303    ) -> Result<FeMessage> {
304        let FeMessageHeader { tag, payload_len } = header;
305        let mut payload: Vec<u8> = vec![0; payload_len as usize];
306        if payload_len > 0 {
307            stream.read_exact(&mut payload).await?;
308        }
309        let sql_bytes = Bytes::from(payload);
310        match tag {
311            b'Q' => Ok(FeMessage::Query(FeQueryMessage { sql_bytes })),
312            b'P' => FeParseMessage::parse(sql_bytes),
313            b'D' => FeDescribeMessage::parse(sql_bytes),
314            b'B' => FeBindMessage::parse(sql_bytes),
315            b'E' => FeExecuteMessage::parse(sql_bytes),
316            b'S' => Ok(FeMessage::Sync),
317            b'X' => Ok(FeMessage::Terminate),
318            b'C' => FeCloseMessage::parse(sql_bytes),
319            b'p' => FePasswordMessage::parse(sql_bytes),
320            b'H' => Ok(FeMessage::Flush),
321            _ => Err(std::io::Error::new(
322                ErrorKind::InvalidInput,
323                format!("Unsupported tag of regular message: {}", tag),
324            )),
325        }
326    }
327
328    pub async fn skip_body(
329        stream: &mut (impl AsyncRead + Unpin),
330        header: FeMessageHeader,
331    ) -> Result<()> {
332        let FeMessageHeader {
333            tag: _,
334            payload_len,
335        } = header;
336
337        if payload_len > 0 {
338            // Use smaller batches to process the payload instead of handling it all at once to minimize memory usage.
339            const BUF_SIZE: usize = 1024;
340            let mut buf: Vec<u8> = vec![0; BUF_SIZE];
341            for _ in 0..(payload_len as usize) / BUF_SIZE {
342                stream.read_exact(&mut buf).await?;
343            }
344            let remain = (payload_len as usize) % BUF_SIZE;
345            if remain > 0 {
346                buf.truncate(remain);
347                stream.read_exact(&mut buf).await?;
348            }
349        }
350        Ok(())
351    }
352}
353
354impl FeStartupMessage {
355    /// Read startup message from the stream.
356    pub async fn read(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessage> {
357        let mut stream = AsyncPeekable::new(stream);
358
359        if let Err(err) = stream.peek_exact(&mut [0; 1]).await {
360            // If the stream is empty, it can be a health check. Do not return error.
361            if err.kind() == ErrorKind::UnexpectedEof {
362                return Ok(FeMessage::HealthCheck);
363            } else {
364                return Err(err);
365            }
366        }
367
368        let len = stream.read_i32().await?;
369        let protocol_num = stream.read_i32().await?;
370        let payload_len = (len - 8) as usize;
371        if payload_len >= isize::MAX as usize {
372            return Err(std::io::Error::new(
373                ErrorKind::InvalidInput,
374                format!("Payload length has exceed usize::MAX {:?}", payload_len),
375            ));
376        }
377        let mut payload = vec![0; payload_len];
378        if payload_len > 0 {
379            stream.read_exact(&mut payload).await?;
380        }
381        match protocol_num {
382            // code from: https://www.postgresql.org/docs/current/protocol-message-formats.html
383            196608 => Ok(FeMessage::Startup(FeStartupMessage::build_with_payload(
384                &payload,
385            )?)),
386            80877104 => Ok(FeMessage::Gss),
387            80877103 => Ok(FeMessage::Ssl),
388            // Cancel request code.
389            80877102 => FeCancelMessage::parse(Bytes::from(payload)),
390            _ => Err(std::io::Error::new(
391                ErrorKind::InvalidInput,
392                format!(
393                    "Unsupported protocol number in start up msg {:?}",
394                    protocol_num
395                ),
396            )),
397        }
398    }
399}
400
401/// Continue read until reached a \0. Used in reading string from Bytes.
402fn read_null_terminated(buf: &mut Bytes) -> Result<Bytes> {
403    let mut result = BytesMut::new();
404
405    loop {
406        if !buf.has_remaining() {
407            panic!("no null-terminator in string");
408        }
409
410        let byte = buf.get_u8();
411
412        if byte == 0 {
413            break;
414        }
415        result.put_u8(byte);
416    }
417    Ok(result.freeze())
418}
419
/// Message sent from the server to a psql client. [`BeMessage::write`] serializes it into the
/// wire format.
/// Ref: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Debug, Clone, Copy)]
pub enum BeMessage<'a> {
    /// AuthenticationOk ('R', code 0).
    AuthenticationOk,
    /// AuthenticationCleartextPassword ('R', code 3).
    AuthenticationCleartextPassword,
    /// AuthenticationMD5Password ('R', code 5) carrying the 4-byte salt.
    AuthenticationMd5Password(&'a [u8; 4]),
    /// CommandComplete ('C'); the tag is built from the statement type and row count.
    CommandComplete(BeCommandCompleteMessage),
    /// NoticeResponse ('N') with the given message text.
    NoticeResponse(&'a str),
    // Single byte - used in response to SSLRequest/GSSENCRequest.
    /// Single byte `'S'`.
    EncryptionResponseSsl,
    /// Single byte `'G'`.
    EncryptionResponseGss,
    /// Single byte `'N'`.
    EncryptionResponseNo,
    /// EmptyQueryResponse ('I').
    EmptyQueryResponse,
    /// ParseComplete ('1').
    ParseComplete,
    /// BindComplete ('2').
    BindComplete,
    /// PortalSuspended ('s').
    PortalSuspended,
    /// ParameterDescription ('t'): one parameter type OID (`i32`) per parameter.
    ParameterDescription(&'a [i32]),
    /// NoData ('n').
    NoData,
    /// DataRow ('D'); `None` values are sent as length -1 (NULL).
    DataRow(&'a Row),
    /// ParameterStatus ('S').
    ParameterStatus(BeParameterStatusMessage<'a>),
    /// ReadyForQuery ('Z') with the current transaction status.
    ReadyForQuery(TransactionStatus),
    /// RowDescription ('T'): one field descriptor per result column.
    RowDescription(&'a [PgFieldDescriptor]),
    /// ErrorResponse ('E'); `pretty` selects formatting the error as a pretty report.
    ErrorResponse {
        error: &'a (dyn std::error::Error + Send + Sync + 'static),
        pretty: bool,
    },
    /// CloseComplete ('3').
    CloseComplete,

    // Copy
    /// CopyOutResponse ('H') for the given number of columns, all in text format.
    CopyOutResponse(usize),
    /// CopyData ('d'): one row as tab-separated text terminated by `\n`.
    CopyData(&'a Row),
    /// CopyDone ('c').
    CopyDone,

    /// BackendKeyData ('K'). Tuple fields: 0 = process ID, 1 = secret key.
    BackendKeyData((i32, i32)),
}
459
/// A ParameterStatus report; each variant carries the parameter's value.
#[derive(Debug, Copy, Clone)]
pub enum BeParameterStatusMessage<'a> {
    /// Reported as `client_encoding`.
    ClientEncoding(&'a str),
    /// Reported as `standard_conforming_strings`.
    StandardConformingString(&'a str),
    /// Reported as `server_version`.
    ServerVersion(&'a str),
    /// Reported as `application_name`.
    ApplicationName(&'a str),
    /// Reported as `TimeZone`, with this exact casing because psycopg3 is case-sensitive.
    TimeZone(&'a str),
}
468
/// Payload of a CommandComplete message; the command tag is derived from these fields.
#[derive(Debug, Copy, Clone)]
pub struct BeCommandCompleteMessage {
    /// Statement kind. `*_RETURNING` variants are reported under their base command name.
    pub stmt_type: StatementType,
    /// Row count, appended to the tag when `stmt_type.is_command()` is true.
    pub rows_cnt: i32,
}
474
/// Transaction state reported in ReadyForQuery.
#[derive(Debug, Clone, Copy)]
pub enum TransactionStatus {
    /// Sent as `'I'`.
    Idle,
    /// Sent as `'T'`.
    InTransaction,
    /// Sent as `'E'`.
    InFailedTransaction,
}
481
482impl BeMessage<'_> {
483    /// Write message to the given buf.
484    pub fn write(buf: &mut BytesMut, message: BeMessage<'_>) -> Result<()> {
485        match message {
486            // AuthenticationOk
487            // +-----+----------+-----------+
488            // | 'R' | int32(8) | int32(0)  |
489            // +-----+----------+-----------+
490            BeMessage::AuthenticationOk => {
491                buf.put_u8(b'R');
492                buf.put_i32(8);
493                buf.put_i32(0);
494            }
495
496            // AuthenticationCleartextPassword
497            // +-----+----------+-----------+
498            // | 'R' | int32(8) | int32(3)  |
499            // +-----+----------+-----------+
500            BeMessage::AuthenticationCleartextPassword => {
501                buf.put_u8(b'R');
502                buf.put_i32(8);
503                buf.put_i32(3);
504            }
505
506            // AuthenticationMD5Password
507            // +-----+----------+-----------+----------------+
508            // | 'R' | int32(12) | int32(5)  |  Byte4(salt)  |
509            // +-----+----------+-----------+----------------+
510            //
511            // The 4-byte random salt will be used by client to send encrypted password as
512            // concat('md5', md5(concat(md5(concat(password, username)), random-salt))).
513            BeMessage::AuthenticationMd5Password(salt) => {
514                buf.put_u8(b'R');
515                buf.put_i32(12);
516                buf.put_i32(5);
517                buf.put_slice(&salt[..]);
518            }
519
520            // ParameterStatus
521            // +-----+-----------+----------+------+-----------+------+
522            // | 'S' | int32 len | str name | '\0' | str value | '\0' |
523            // +-----+-----------+----------+------+-----------+------+
524            BeMessage::ParameterStatus(param) => {
525                use BeParameterStatusMessage::*;
526                let [name, value] = match param {
527                    ClientEncoding(val) => [b"client_encoding", val.as_bytes()],
528                    StandardConformingString(val) => {
529                        [b"standard_conforming_strings", val.as_bytes()]
530                    }
531                    ServerVersion(val) => [b"server_version", val.as_bytes()],
532                    ApplicationName(val) => [b"application_name", val.as_bytes()],
533                    // psycopg3 is case-sensitive, so we use "TimeZone" instead of "timezone" #18079
534                    TimeZone(val) => [b"TimeZone", val.as_bytes()],
535                };
536
537                // Parameter names and values are passed as null-terminated strings
538                let iov = &mut [name, b"\0", value, b"\0"].map(IoSlice::new);
539                let mut buffer = vec![];
540                let cnt = buffer.write_vectored(iov).unwrap();
541
542                buf.put_u8(b'S');
543                write_body(buf, |stream| {
544                    stream.put_slice(&buffer[..cnt]);
545                    Ok(())
546                })
547                .unwrap();
548            }
549
550            // CommandComplete
551            // +-----+-----------+-----------------+
552            // | 'C' | int32 len | str commandTag  |
553            // +-----+-----------+-----------------+
554            BeMessage::CommandComplete(cmd) => {
555                let rows_cnt = cmd.rows_cnt;
556                let mut stmt_type = cmd.stmt_type;
557                let mut tag = "".to_owned();
558                stmt_type = match stmt_type {
559                    StatementType::INSERT_RETURNING => StatementType::INSERT,
560                    StatementType::DELETE_RETURNING => StatementType::DELETE,
561                    StatementType::UPDATE_RETURNING => StatementType::UPDATE,
562                    s => s,
563                };
564                tag.push_str(&stmt_type.to_string());
565                if stmt_type == StatementType::INSERT {
566                    tag.push_str(" 0");
567                }
568                if stmt_type.is_command() {
569                    tag.push(' ');
570                    tag.push_str(&rows_cnt.to_string());
571                }
572                buf.put_u8(b'C');
573                write_body(buf, |buf| {
574                    write_cstr(buf, tag.as_bytes())?;
575                    Ok(())
576                })?;
577            }
578
579            // NoticeResponse
580            // +-----+-----------+------------------+------------------+
581            // | 'N' | int32 len | byte1 field type | str field value  |
582            // +-----+-----------+------------------+-+----------------+
583            // description of the fields can be found here:
584            // https://www.postgresql.org/docs/current/protocol-error-fields.html
585            BeMessage::NoticeResponse(notice) => {
586                buf.put_u8(b'N');
587                write_err_or_notice(buf, &ErrorOrNoticeMessage::notice(notice))?;
588            }
589
590            // DataRow
591            // +-----+-----------+--------------+--------+-----+--------+
592            // | 'D' | int32 len | int16 colNum | column | ... | column |
593            // +-----+-----------+--------------+----+---+-----+--------+
594            //                                       |
595            //                          +-----------+v------+
596            //                          | int32 len | bytes |
597            //                          +-----------+-------+
598            BeMessage::DataRow(vals) => {
599                buf.put_u8(b'D');
600                write_body(buf, |buf| {
601                    buf.put_u16(vals.len() as u16); // num of cols
602                    for val_opt in vals.values() {
603                        if let Some(val) = val_opt {
604                            buf.put_u32(val.len() as u32);
605                            buf.put_slice(val);
606                        } else {
607                            buf.put_i32(-1);
608                        }
609                    }
610                    Ok(())
611                })
612                .unwrap();
613            }
614
615            // RowDescription
616            // +-----+-----------+--------------+-------+-----+-------+
617            // | 'T' | int32 len | int16 colNum | field | ... | field |
618            // +-----+-----------+--------------+----+--+-----+-------+
619            //                                       |
620            // +---------------+-------+-------+-----v-+-------+-------+-------+
621            // | str fieldName | int32 | int16 | int32 | int16 | int32 | int16 |
622            // +---------------+---+---+---+---+---+---+----+--+---+---+---+---+
623            //                     |       |       |        |      |       |
624            //                     v       |       v        v      |       v
625            //                tableOID     |    typeOID  typeLen   |   formatCode
626            //                             v                       v
627            //                        colAttrNum               typeModifier
628            BeMessage::RowDescription(row_descs) => {
629                buf.put_u8(b'T');
630                write_body(buf, |buf| {
631                    buf.put_i16(row_descs.len() as i16); // # of fields
632                    for pg_field in row_descs {
633                        write_cstr(buf, pg_field.get_name().as_bytes())?;
634                        buf.put_i32(pg_field.get_table_oid()); // table oid
635                        buf.put_i16(pg_field.get_col_attr_num()); // attnum
636                        buf.put_i32(pg_field.get_type_oid());
637                        buf.put_i16(pg_field.get_type_len());
638                        buf.put_i32(pg_field.get_type_modifier()); // typmod
639                        buf.put_i16(pg_field.get_format_code()); // format code
640                    }
641                    Ok(())
642                })?;
643            }
644            // ReadyForQuery
645            // +-----+----------+---------------------------+
646            // | 'Z' | int32(5) | byte1(transaction status) |
647            // +-----+----------+---------------------------+
648            BeMessage::ReadyForQuery(txn_status) => {
649                buf.put_u8(b'Z');
650                buf.put_i32(5);
651                // TODO: add transaction status
652                buf.put_u8(match txn_status {
653                    TransactionStatus::Idle => b'I',
654                    TransactionStatus::InTransaction => b'T',
655                    TransactionStatus::InFailedTransaction => b'E',
656                });
657            }
658
659            BeMessage::ParseComplete => {
660                buf.put_u8(b'1');
661                write_body(buf, |_| Ok(()))?;
662            }
663
664            BeMessage::BindComplete => {
665                buf.put_u8(b'2');
666                write_body(buf, |_| Ok(()))?;
667            }
668
669            BeMessage::CloseComplete => {
670                buf.put_u8(b'3');
671                write_body(buf, |_| Ok(()))?;
672            }
673
674            BeMessage::PortalSuspended => {
675                buf.put_u8(b's');
676                write_body(buf, |_| Ok(()))?;
677            }
678            // ParameterDescription
679            // +-----+-----------+--------------------+---------------+-----+---------------+
680            // | 't' | int32 len | int16 ParameterNum | int32 typeOID | ... | int32 typeOID |
681            // +-----+-----------+-----------------+--+---------------+-----+---------------+
682            BeMessage::ParameterDescription(para_descs) => {
683                buf.put_u8(b't');
684                write_body(buf, |buf| {
685                    buf.put_i16(para_descs.len() as i16);
686                    for oid in para_descs {
687                        buf.put_i32(*oid);
688                    }
689                    Ok(())
690                })?;
691            }
692
693            BeMessage::NoData => {
694                buf.put_u8(b'n');
695                write_body(buf, |_| Ok(())).unwrap();
696            }
697
698            BeMessage::EncryptionResponseSsl => {
699                buf.put_u8(b'S');
700            }
701
702            BeMessage::EncryptionResponseGss => {
703                buf.put_u8(b'G');
704            }
705
706            BeMessage::EncryptionResponseNo => {
707                buf.put_u8(b'N');
708            }
709
710            // EmptyQueryResponse
711            // +-----+----------+
712            // | 'I' | int32(4) |
713            // +-----+----------+
714            BeMessage::EmptyQueryResponse => {
715                buf.put_u8(b'I');
716                buf.put_i32(4);
717            }
718
719            BeMessage::ErrorResponse { error, pretty } => {
720                // 'E' signalizes ErrorResponse messages
721                buf.put_u8(b'E');
722                // Format the error as a pretty report.
723                write_err_or_notice(buf, &ErrorOrNoticeMessage::error(error, pretty))?;
724            }
725
726            BeMessage::BackendKeyData((process_id, secret_key)) => {
727                buf.put_u8(b'K');
728                write_body(buf, |buf| {
729                    buf.put_i32(process_id);
730                    buf.put_i32(secret_key);
731                    Ok(())
732                })?;
733            }
734            BeMessage::CopyOutResponse(col_num) => {
735                buf.put_u8(b'H');
736                write_body(buf, |buf| {
737                    buf.put_i8(Format::Text.to_i8());
738                    buf.put_i16(col_num as _);
739                    for _ in 0..col_num {
740                        buf.put_i16(Format::Text.to_i8() as _);
741                    }
742                    Ok(())
743                })?;
744            }
745            BeMessage::CopyData(row) => {
746                buf.put_u8(b'd');
747                // As in https://www.postgresql.org/docs/current/sql-copy.html, the default format is TSV format
748                write_body(buf, |buf| {
749                    fn write_str_bytes(
750                        buf: &mut BytesMut,
751                        str_bytes: &Option<Bytes>,
752                    ) -> Result<()> {
753                        let Some(str_bytes) = str_bytes else {
754                            return Ok(());
755                        };
756                        let s = String::from_utf8_lossy(str_bytes);
757                        for c in s.as_str().chars() {
758                            // As suggested in https://en.wikipedia.org/wiki/Tab-separated_values
759                            // we only escape "\t\b\r\\"
760                            match c {
761                                '\t' => {
762                                    buf.put_slice(b"\\t");
763                                }
764                                '\n' => {
765                                    buf.put_slice(b"\\n");
766                                }
767                                '\r' => {
768                                    buf.put_slice(b"\\r");
769                                }
770                                '\\' => {
771                                    buf.put_slice(b"\\\\");
772                                }
773                                _ => {
774                                    std::fmt::Write::write_char(buf, c).map_err(|_| {
775                                        Error::other(anyhow!("failed to write_char [{c}]"))
776                                    })?;
777                                }
778                            }
779                        }
780                        Ok(())
781                    }
782                    match row.values() {
783                        [] => {}
784                        [first, rest @ ..] => {
785                            write_str_bytes(buf, first)?;
786
787                            for rest in rest {
788                                buf.put_u8(b'\t');
789                                write_str_bytes(buf, rest)?;
790                            }
791                        }
792                    }
793                    buf.put_u8(b'\n');
794                    Ok(())
795                })?;
796            }
797            BeMessage::CopyDone => {
798                buf.put_u8(b'c');
799                write_body(buf, |_| Ok(()))?;
800            }
801        }
802
803        Ok(())
804    }
805}
806
/// Checked `usize` -> signed-integer conversion for wire-format length fields
/// (adapted from rust-postgres). Only `i32` is implemented, by the `from_usize!` call below.
trait FromUsize: Sized {
    /// Converts `x`, failing with `InvalidInput` when it exceeds `Self::MAX`.
    fn from_usize(x: usize) -> Result<Self>;
}

/// Implements [`FromUsize`] for an integer type by delegating to its `TryFrom<usize>` impl.
macro_rules! from_usize {
    ($t:ty) => {
        impl FromUsize for $t {
            #[inline]
            fn from_usize(x: usize) -> Result<$t> {
                <$t>::try_from(x).map_err(|_| {
                    Error::new(ErrorKind::InvalidInput, "value too large to transmit")
                })
            }
        }
    };
}

from_usize!(i32);
828
829/// Call f() to write body of the message and prepend it with 4-byte len as
830/// prescribed by the protocol. First write out body value and fill length value as i32 in front of
831/// it.
832fn write_body<F>(buf: &mut BytesMut, f: F) -> Result<()>
833where
834    F: FnOnce(&mut BytesMut) -> Result<()>,
835{
836    let base = buf.len();
837    buf.extend_from_slice(&[0; 4]);
838
839    f(buf)?;
840
841    let size = i32::from_usize(buf.len() - base)?;
842    BigEndian::write_i32(&mut buf[base..], size);
843    Ok(())
844}
845
846/// Safe write of s into buf as cstring (String in the protocol).
847fn write_cstr(buf: &mut BytesMut, s: &[u8]) -> Result<()> {
848    if s.contains(&0) {
849        return Err(Error::new(
850            ErrorKind::InvalidInput,
851            "string contains embedded null",
852        ));
853    }
854    buf.put_slice(s);
855    buf.put_u8(0);
856    Ok(())
857}
858
859/// Safe write error or notice message.
860fn write_err_or_notice(buf: &mut BytesMut, msg: &ErrorOrNoticeMessage<'_>) -> Result<()> {
861    write_body(buf, |buf| {
862        buf.put_u8(b'S'); // severity
863        write_cstr(buf, msg.severity.as_str().as_bytes())?;
864
865        buf.put_u8(b'C'); // SQLSTATE error code
866        write_cstr(buf, msg.error_code.sqlstate().as_bytes())?;
867
868        buf.put_u8(b'M'); // the message
869        write_cstr(buf, msg.message.as_bytes())?;
870
871        buf.put_u8(0); // terminator
872        Ok(())
873    })
874}
875
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use crate::pg_message::FeQueryMessage;

    /// `FeQueryMessage::get_sql` accepts exactly one NUL-terminated UTF-8 string.
    #[test]
    fn test_get_sql() {
        // Well-formed: UTF-8 text with a single trailing NUL.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from_static(b"SELECT 1\0"),
        };
        assert_eq!(fe.get_sql().unwrap(), "SELECT 1");

        // Terminated, but not valid UTF-8.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![255, 255, 255, 255, 255, 255, 0]),
        };
        assert!(fe.get_sql().is_err(), "invalid UTF-8 must be rejected");

        // No NUL terminator at all.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![1, 2, 3, 4, 5, 6, 7, 8]),
        };
        assert!(
            fe.get_sql().is_err(),
            "a missing NUL terminator must be rejected"
        );
    }
}