pgwire/
pg_message.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::ffi::CStr;
17use std::io::{Error, ErrorKind, IoSlice, Result, Write};
18
19use anyhow::anyhow;
20use byteorder::{BigEndian, ByteOrder};
21/// Part of code learned from <https://github.com/zenithdb/zenith/blob/main/zenith_utils/src/pq_proto.rs>.
22use bytes::{Buf, BufMut, Bytes, BytesMut};
23use peekable::tokio::AsyncPeekable;
24use tokio::io::{AsyncRead, AsyncReadExt};
25
26use crate::error_or_notice::{ErrorOrNoticeMessage, Severity};
27use crate::pg_field_descriptor::PgFieldDescriptor;
28use crate::pg_response::StatementType;
29use crate::types::{Format, Row};
30
/// Messages that can be sent from pg client to server.
///
/// Startup-phase messages are produced by [`FeStartupMessage::read`]; regular messages are
/// produced by [`FeMessage::read_header`] followed by [`FeMessage::read_body`].
#[derive(Debug)]
pub enum FeMessage {
    /// SSLRequest (startup request code 80877103).
    Ssl,
    /// GSSENCRequest (startup request code 80877104).
    Gss,
    /// StartupMessage (protocol code 196608) carrying the session parameters.
    Startup(FeStartupMessage),
    /// Simple query (tag `'Q'`).
    Query(FeQueryMessage),
    /// Parse (tag `'P'`).
    Parse(FeParseMessage),
    /// Password (tag `'p'`).
    Password(FePasswordMessage),
    /// Describe (tag `'D'`).
    Describe(FeDescribeMessage),
    /// Bind (tag `'B'`).
    Bind(FeBindMessage),
    /// Execute (tag `'E'`).
    Execute(FeExecuteMessage),
    /// Close (tag `'C'`).
    Close(FeCloseMessage),
    /// Sync (tag `'S'`); carries no data.
    Sync,
    /// CancelRequest (startup request code 80877102).
    CancelQuery(FeCancelMessage),
    /// Terminate (tag `'X'`); carries no data.
    Terminate,
    /// Flush (tag `'H'`); carries no data.
    Flush,
    /// Special msg to detect health check, which represents the client immediately closes the
    /// connection cleanly without sending any data.
    HealthCheck,
    /// The original message has been rejected due to server throttling. This is a placeholder
    /// message generated by server.
    ServerThrottle(ServerThrottleReason),
}
53
54impl FeMessage {
55    pub fn get_sql(&self) -> Result<Option<&str>> {
56        match self {
57            FeMessage::Query(q) => Ok(Some(q.get_sql()?)),
58            _ => Ok(None),
59        }
60    }
61}
62
/// Why the server rejected a client message; carried by [`FeMessage::ServerThrottle`].
#[derive(Debug)]
pub enum ServerThrottleReason {
    // NOTE(review): presumably the message exceeded a size limit — the check is done by the
    // caller, not in this file.
    TooLargeMessage,
    // NOTE(review): presumably server memory usage was too high — confirm against the caller.
    TooManyMemoryUsage,
}
68
/// StartupMessage: the session parameters a client sends when opening a connection.
#[derive(Debug)]
pub struct FeStartupMessage {
    /// Parameter name -> value pairs decoded by [`FeStartupMessage::build_with_payload`].
    pub config: HashMap<String, String>,
}
73
74impl FeStartupMessage {
75    pub fn build_with_payload(payload: &[u8]) -> Result<Self> {
76        let config = match std::str::from_utf8(payload) {
77            Ok(v) => Ok(v.trim_end_matches('\0')),
78            Err(err) => Err(Error::new(
79                ErrorKind::InvalidInput,
80                anyhow!(err).context("Input end error"),
81            )),
82        }?;
83        let mut map = HashMap::new();
84        let config: Vec<&str> = config.split('\0').collect();
85        if config.len() % 2 == 1 {
86            return Err(Error::new(
87                ErrorKind::InvalidInput,
88                "Invalid input config: odd number of config pairs",
89            ));
90        }
91        config.chunks(2).for_each(|chunk| {
92            map.insert(chunk[0].to_owned(), chunk[1].to_owned());
93        });
94        Ok(FeStartupMessage { config: map })
95    }
96}
97
/// Simple-query (`'Q'`) message: contains the string sql.
#[derive(Debug)]
pub struct FeQueryMessage {
    /// The raw message body, including its trailing NUL. [`FeQueryMessage::get_sql`] checks
    /// that it is exactly one NUL-terminated UTF-8 string.
    pub sql_bytes: Bytes,
}
103
/// Bind (`'B'`) message, decoded by [`FeBindMessage::parse`].
#[derive(Debug)]
pub struct FeBindMessage {
    /// Parameter format codes, exactly as listed by the client.
    pub param_format_codes: Vec<i16>,
    /// Result-column format codes, exactly as listed by the client.
    pub result_format_codes: Vec<i16>,

    /// Parameter values; `None` when the encoded value length was -1.
    pub params: Vec<Option<Bytes>>,
    /// Destination portal name, without its NUL terminator.
    pub portal_name: Bytes,
    /// Source prepared-statement name, without its NUL terminator.
    pub statement_name: Bytes,
}
113
/// Execute (`'E'`) message, decoded by [`FeExecuteMessage::parse`].
#[derive(Debug)]
pub struct FeExecuteMessage {
    /// Name of the portal to execute, without its NUL terminator.
    pub portal_name: Bytes,
    /// Row limit as sent by the client.
    // NOTE(review): per the protocol, 0 means "no limit"; interpretation is up to the caller.
    pub max_rows: i32,
}
119
/// Parse (`'P'`) message, decoded by [`FeParseMessage::parse`].
#[derive(Debug)]
pub struct FeParseMessage {
    /// Name of the prepared statement, without its NUL terminator.
    pub statement_name: Bytes,
    /// Query text, without its NUL terminator (unlike [`FeQueryMessage::sql_bytes`]).
    pub sql_bytes: Bytes,
    /// Parameter type ids (OIDs) specified by the client, in order.
    pub type_ids: Vec<i32>,
}
126
/// Password (`'p'`) message sent by the client during authentication.
#[derive(Debug)]
pub struct FePasswordMessage {
    /// The password as sent (cleartext, or the `md5`-prefixed hash for MD5 auth), without its
    /// NUL terminator.
    pub password: Bytes,
}
131
/// Describe (`'D'`) message, decoded by [`FeDescribeMessage::parse`].
#[derive(Debug)]
pub struct FeDescribeMessage {
    /// 'S' to describe a prepared statement; or 'P' to describe a portal.
    pub kind: u8,
    /// Name of the statement or portal, without its NUL terminator.
    pub name: Bytes,
}
138
/// Close (`'C'`) message, decoded by [`FeCloseMessage::parse`].
#[derive(Debug)]
pub struct FeCloseMessage {
    /// Kind byte as sent by the client.
    // NOTE(review): presumably 'S' (prepared statement) or 'P' (portal) as for Describe; it is
    // not validated here.
    pub kind: u8,
    /// Name of the statement or portal, without its NUL terminator.
    pub name: Bytes,
}
144
/// CancelRequest: identifies the backend whose query should be cancelled.
///
/// The two fields correspond to the values the server hands out in
/// [`BeMessage::BackendKeyData`].
#[derive(Debug)]
pub struct FeCancelMessage {
    /// Process id of the target backend.
    pub target_process_id: i32,
    /// Secret key of the target backend.
    pub target_secret_key: i32,
}
150
151impl FeCancelMessage {
152    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
153        let target_process_id = buf.get_i32();
154        let target_secret_key = buf.get_i32();
155        Ok(FeMessage::CancelQuery(Self {
156            target_process_id,
157            target_secret_key,
158        }))
159    }
160}
161
162impl FeDescribeMessage {
163    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
164        let kind = buf.get_u8();
165        let name = read_null_terminated(&mut buf)?;
166
167        Ok(FeMessage::Describe(FeDescribeMessage { kind, name }))
168    }
169}
170
171impl FeBindMessage {
172    // Bind Message Header
173    // +-----+-----------+
174    // | 'B' | int32 len |
175    // +-----+-----------+
176    // Bind Message Body
177    // +----------------+---------------+
178    // | str portalname | str statement |
179    // +----------------+---------------+
180    // +---------------------+------------------+-------+
181    // | int16 numFormatCode | int16 FormatCode |  ...  |
182    // +---------------------+------------------+-------+
183    // +-----------------+-------------------+---------------+
184    // | int16 numParams | int32 valueLength |  byte value.. |
185    // +-----------------+-------------------+---------------+
186    // +----------------------------------+------------------+-------+
187    // | int16 numResultColumnFormatCodes | int16 FormatCode |  ...  |
188    // +----------------------------------+------------------+-------+
189    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
190        let portal_name = read_null_terminated(&mut buf)?;
191        let statement_name = read_null_terminated(&mut buf)?;
192
193        let len = buf.get_i16();
194        let param_format_codes = (0..len).map(|_| buf.get_i16()).collect();
195
196        // Read Params
197        let len = buf.get_i16();
198        let params = (0..len)
199            .map(|_| {
200                let val_len = buf.get_i32();
201                if val_len == -1 {
202                    None
203                } else {
204                    Some(buf.copy_to_bytes(val_len as usize))
205                }
206            })
207            .collect();
208
209        let len = buf.get_i16();
210        let result_format_codes = (0..len).map(|_| buf.get_i16()).collect();
211
212        Ok(FeMessage::Bind(FeBindMessage {
213            param_format_codes,
214            result_format_codes,
215            params,
216            portal_name,
217            statement_name,
218        }))
219    }
220}
221
222impl FeExecuteMessage {
223    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
224        let portal_name = read_null_terminated(&mut buf)?;
225        let max_rows = buf.get_i32();
226
227        Ok(FeMessage::Execute(FeExecuteMessage {
228            portal_name,
229            max_rows,
230        }))
231    }
232}
233
234impl FeParseMessage {
235    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
236        let statement_name = read_null_terminated(&mut buf)?;
237        let sql_bytes = read_null_terminated(&mut buf)?;
238        let nparams = buf.get_i16();
239
240        let type_ids: Vec<i32> = (0..nparams).map(|_| buf.get_i32()).collect();
241
242        Ok(FeMessage::Parse(FeParseMessage {
243            statement_name,
244            sql_bytes,
245            type_ids,
246        }))
247    }
248}
249
250impl FePasswordMessage {
251    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
252        let password = read_null_terminated(&mut buf)?;
253
254        Ok(FeMessage::Password(FePasswordMessage { password }))
255    }
256}
257
258impl FeQueryMessage {
259    pub fn get_sql(&self) -> Result<&str> {
260        match CStr::from_bytes_with_nul(&self.sql_bytes) {
261            Ok(cstr) => cstr.to_str().map_err(|err| {
262                Error::new(
263                    ErrorKind::InvalidInput,
264                    anyhow!(err).context("Invalid UTF-8 sequence"),
265                )
266            }),
267            Err(err) => Err(Error::new(
268                ErrorKind::InvalidInput,
269                anyhow!(err).context("Input end error"),
270            )),
271        }
272    }
273}
274
275impl FeCloseMessage {
276    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
277        let kind = buf.get_u8();
278        let name = read_null_terminated(&mut buf)?;
279        Ok(FeMessage::Close(FeCloseMessage { kind, name }))
280    }
281}
282
/// Header of a regular (post-startup) frontend message, as returned by
/// [`FeMessage::read_header`].
#[derive(Clone)]
pub struct FeMessageHeader {
    /// Message type byte, e.g. `b'Q'` for a simple query.
    pub tag: u8,
    /// Body length in bytes: the wire length field minus the 4 bytes of the field itself.
    pub payload_len: i32,
}
288
289impl FeMessage {
290    /// Read one message from the stream.
291    pub async fn read_header(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessageHeader> {
292        let tag = stream.read_u8().await?;
293        let len = stream.read_i32().await?;
294
295        let payload_len = len - 4;
296        Ok(FeMessageHeader { tag, payload_len })
297    }
298
299    /// Read one message from the stream.
300    pub async fn read_body(
301        stream: &mut (impl AsyncRead + Unpin),
302        header: FeMessageHeader,
303    ) -> Result<FeMessage> {
304        let FeMessageHeader { tag, payload_len } = header;
305        let mut payload: Vec<u8> = vec![0; payload_len as usize];
306        if payload_len > 0 {
307            stream.read_exact(&mut payload).await?;
308        }
309        let sql_bytes = Bytes::from(payload);
310        match tag {
311            b'Q' => Ok(FeMessage::Query(FeQueryMessage { sql_bytes })),
312            b'P' => FeParseMessage::parse(sql_bytes),
313            b'D' => FeDescribeMessage::parse(sql_bytes),
314            b'B' => FeBindMessage::parse(sql_bytes),
315            b'E' => FeExecuteMessage::parse(sql_bytes),
316            b'S' => Ok(FeMessage::Sync),
317            b'X' => Ok(FeMessage::Terminate),
318            b'C' => FeCloseMessage::parse(sql_bytes),
319            b'p' => FePasswordMessage::parse(sql_bytes),
320            b'H' => Ok(FeMessage::Flush),
321            _ => Err(std::io::Error::new(
322                ErrorKind::InvalidInput,
323                format!("Unsupported tag of regular message: {}", tag),
324            )),
325        }
326    }
327
328    pub async fn skip_body(
329        stream: &mut (impl AsyncRead + Unpin),
330        header: FeMessageHeader,
331    ) -> Result<()> {
332        let FeMessageHeader {
333            tag: _,
334            payload_len,
335        } = header;
336
337        if payload_len > 0 {
338            // Use smaller batches to process the payload instead of handling it all at once to minimize memory usage.
339            const BUF_SIZE: usize = 1024;
340            let mut buf: Vec<u8> = vec![0; BUF_SIZE];
341            for _ in 0..(payload_len as usize) / BUF_SIZE {
342                stream.read_exact(&mut buf).await?;
343            }
344            let remain = (payload_len as usize) % BUF_SIZE;
345            if remain > 0 {
346                buf.truncate(remain);
347                stream.read_exact(&mut buf).await?;
348            }
349        }
350        Ok(())
351    }
352}
353
354impl FeStartupMessage {
355    /// Read startup message from the stream.
356    pub async fn read(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessage> {
357        let mut stream = AsyncPeekable::new(stream);
358
359        if let Err(err) = stream.peek_exact(&mut [0; 1]).await {
360            // If the stream is empty, it can be a health check. Do not return error.
361            if err.kind() == ErrorKind::UnexpectedEof {
362                return Ok(FeMessage::HealthCheck);
363            } else {
364                return Err(err);
365            }
366        }
367
368        let len = stream.read_i32().await?;
369        let protocol_num = stream.read_i32().await?;
370        let payload_len = (len - 8) as usize;
371        if payload_len >= isize::MAX as usize {
372            return Err(std::io::Error::new(
373                ErrorKind::InvalidInput,
374                format!("Payload length has exceed usize::MAX {:?}", payload_len),
375            ));
376        }
377        let mut payload = vec![0; payload_len];
378        if payload_len > 0 {
379            stream.read_exact(&mut payload).await?;
380        }
381        match protocol_num {
382            // code from: https://www.postgresql.org/docs/current/protocol-message-formats.html
383            196608 => Ok(FeMessage::Startup(FeStartupMessage::build_with_payload(
384                &payload,
385            )?)),
386            80877104 => Ok(FeMessage::Gss),
387            80877103 => Ok(FeMessage::Ssl),
388            // Cancel request code.
389            80877102 => FeCancelMessage::parse(Bytes::from(payload)),
390            _ => Err(std::io::Error::new(
391                ErrorKind::InvalidInput,
392                format!(
393                    "Unsupported protocol number in start up msg {:?}",
394                    protocol_num
395                ),
396            )),
397        }
398    }
399}
400
401/// Continue read until reached a \0. Used in reading string from Bytes.
402fn read_null_terminated(buf: &mut Bytes) -> Result<Bytes> {
403    let mut result = BytesMut::new();
404
405    loop {
406        if !buf.has_remaining() {
407            panic!("no null-terminator in string");
408        }
409
410        let byte = buf.get_u8();
411
412        if byte == 0 {
413            break;
414        }
415        result.put_u8(byte);
416    }
417    Ok(result.freeze())
418}
419
/// Messages sent from the server to a pg client. [`BeMessage::write`] serializes them into a
/// buffer.
///
/// Ref: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Debug, Clone, Copy)]
pub enum BeMessage<'a> {
    /// Authentication succeeded (`'R'`, code 0).
    AuthenticationOk,
    /// Request a cleartext password (`'R'`, code 3).
    AuthenticationCleartextPassword,
    /// Request an MD5-hashed password salted with these 4 bytes (`'R'`, code 5).
    AuthenticationMd5Password(&'a [u8; 4]),
    /// Statement completed (`'C'`); the command tag is built from the statement type and row
    /// count.
    CommandComplete(BeCommandCompleteMessage),
    /// Notice with the given message text (`'N'`).
    NoticeResponse(&'a str),
    // Single byte - used in response to SSLRequest/GSSENCRequest.
    /// Single byte `'S'`: accept the SSL upgrade.
    EncryptionResponseSsl,
    /// Single byte `'G'`: accept GSS encryption.
    EncryptionResponseGss,
    /// Single byte `'N'`: refuse the requested encryption.
    EncryptionResponseNo,
    /// Empty query string received (`'I'`).
    EmptyQueryResponse,
    /// Parse finished (`'1'`).
    ParseComplete,
    /// Bind finished (`'2'`).
    BindComplete,
    /// Execute stopped at its row limit (`'s'`).
    PortalSuspended,
    /// Parameter type OIDs, written as `i32`s (`'t'`).
    ParameterDescription(&'a [i32]),
    /// Nothing to describe (`'n'`).
    NoData,
    /// One result row (`'D'`); `None` columns are written with length -1.
    DataRow(&'a Row),
    /// Report a run-time parameter value (`'S'`).
    ParameterStatus(BeParameterStatusMessage<'a>),
    /// Ready for a new query, with the current transaction status (`'Z'`).
    ReadyForQuery(TransactionStatus),
    /// Column descriptions of the upcoming rows (`'T'`).
    RowDescription(&'a [PgFieldDescriptor]),
    /// Error report (`'E'`); `pretty` and `severity` control how the error is formatted.
    ErrorResponse {
        error: &'a (dyn std::error::Error + Send + Sync + 'static),
        pretty: bool,
        severity: Option<Severity>,
    },
    /// Close finished (`'3'`).
    CloseComplete,

    // Copy
    /// Start of COPY OUT (`'H'`) for this many columns, all in text format.
    CopyOutResponse(usize),
    /// One row of COPY data as a tab-separated line (`'d'`).
    CopyData(&'a Row),
    /// End of COPY data (`'c'`).
    CopyDone,

    /// Cancellation key for this connection (`'K'`). 0: process ID, 1: secret key.
    BackendKeyData((i32, i32)),
}
460
/// A run-time parameter reported through [`BeMessage::ParameterStatus`].
///
/// Each variant carries the value; the variant decides the parameter name sent on the wire.
#[derive(Debug, Copy, Clone)]
pub enum BeParameterStatusMessage<'a> {
    /// Sent as `client_encoding`.
    ClientEncoding(&'a str),
    /// Sent as `standard_conforming_strings`.
    StandardConformingString(&'a str),
    /// Sent as `server_version`.
    ServerVersion(&'a str),
    /// Sent as `application_name`.
    ApplicationName(&'a str),
    /// Sent as `TimeZone`; the capitalization matters because psycopg3 matches names
    /// case-sensitively.
    TimeZone(&'a str),
}
469
/// Payload of [`BeMessage::CommandComplete`].
#[derive(Debug, Copy, Clone)]
pub struct BeCommandCompleteMessage {
    /// Statement kind used for the command tag. `*_RETURNING` kinds are reported with the
    /// plain INSERT/DELETE/UPDATE tag.
    pub stmt_type: StatementType,
    /// Row count appended to the tag when `stmt_type.is_command()` is true.
    pub rows_cnt: i32,
}
475
/// Transaction status reported in [`BeMessage::ReadyForQuery`].
#[derive(Debug, Clone, Copy)]
pub enum TransactionStatus {
    /// Sent as `'I'`.
    Idle,
    /// Sent as `'T'`.
    InTransaction,
    /// Sent as `'E'`.
    InFailedTransaction,
}
482
483impl BeMessage<'_> {
484    /// Write message to the given buf.
485    pub fn write(buf: &mut BytesMut, message: BeMessage<'_>) -> Result<()> {
486        match message {
487            // AuthenticationOk
488            // +-----+----------+-----------+
489            // | 'R' | int32(8) | int32(0)  |
490            // +-----+----------+-----------+
491            BeMessage::AuthenticationOk => {
492                buf.put_u8(b'R');
493                buf.put_i32(8);
494                buf.put_i32(0);
495            }
496
497            // AuthenticationCleartextPassword
498            // +-----+----------+-----------+
499            // | 'R' | int32(8) | int32(3)  |
500            // +-----+----------+-----------+
501            BeMessage::AuthenticationCleartextPassword => {
502                buf.put_u8(b'R');
503                buf.put_i32(8);
504                buf.put_i32(3);
505            }
506
507            // AuthenticationMD5Password
508            // +-----+----------+-----------+----------------+
509            // | 'R' | int32(12) | int32(5)  |  Byte4(salt)  |
510            // +-----+----------+-----------+----------------+
511            //
512            // The 4-byte random salt will be used by client to send encrypted password as
513            // concat('md5', md5(concat(md5(concat(password, username)), random-salt))).
514            BeMessage::AuthenticationMd5Password(salt) => {
515                buf.put_u8(b'R');
516                buf.put_i32(12);
517                buf.put_i32(5);
518                buf.put_slice(&salt[..]);
519            }
520
521            // ParameterStatus
522            // +-----+-----------+----------+------+-----------+------+
523            // | 'S' | int32 len | str name | '\0' | str value | '\0' |
524            // +-----+-----------+----------+------+-----------+------+
525            BeMessage::ParameterStatus(param) => {
526                use BeParameterStatusMessage::*;
527                let [name, value] = match param {
528                    ClientEncoding(val) => [b"client_encoding", val.as_bytes()],
529                    StandardConformingString(val) => {
530                        [b"standard_conforming_strings", val.as_bytes()]
531                    }
532                    ServerVersion(val) => [b"server_version", val.as_bytes()],
533                    ApplicationName(val) => [b"application_name", val.as_bytes()],
534                    // psycopg3 is case-sensitive, so we use "TimeZone" instead of "timezone" #18079
535                    TimeZone(val) => [b"TimeZone", val.as_bytes()],
536                };
537
538                // Parameter names and values are passed as null-terminated strings
539                let iov = &mut [name, b"\0", value, b"\0"].map(IoSlice::new);
540                let mut buffer = vec![];
541                let cnt = buffer.write_vectored(iov).unwrap();
542
543                buf.put_u8(b'S');
544                write_body(buf, |stream| {
545                    stream.put_slice(&buffer[..cnt]);
546                    Ok(())
547                })
548                .unwrap();
549            }
550
551            // CommandComplete
552            // +-----+-----------+-----------------+
553            // | 'C' | int32 len | str commandTag  |
554            // +-----+-----------+-----------------+
555            BeMessage::CommandComplete(cmd) => {
556                let rows_cnt = cmd.rows_cnt;
557                let mut stmt_type = cmd.stmt_type;
558                let mut tag = "".to_owned();
559                stmt_type = match stmt_type {
560                    StatementType::INSERT_RETURNING => StatementType::INSERT,
561                    StatementType::DELETE_RETURNING => StatementType::DELETE,
562                    StatementType::UPDATE_RETURNING => StatementType::UPDATE,
563                    s => s,
564                };
565                tag.push_str(&stmt_type.to_string());
566                if stmt_type == StatementType::INSERT {
567                    tag.push_str(" 0");
568                }
569                if stmt_type.is_command() {
570                    tag.push(' ');
571                    tag.push_str(&rows_cnt.to_string());
572                }
573                buf.put_u8(b'C');
574                write_body(buf, |buf| {
575                    write_cstr(buf, tag.as_bytes())?;
576                    Ok(())
577                })?;
578            }
579
580            // NoticeResponse
581            // +-----+-----------+------------------+------------------+
582            // | 'N' | int32 len | byte1 field type | str field value  |
583            // +-----+-----------+------------------+-+----------------+
584            // description of the fields can be found here:
585            // https://www.postgresql.org/docs/current/protocol-error-fields.html
586            BeMessage::NoticeResponse(notice) => {
587                buf.put_u8(b'N');
588                write_err_or_notice(buf, &ErrorOrNoticeMessage::notice(notice))?;
589            }
590
591            // DataRow
592            // +-----+-----------+--------------+--------+-----+--------+
593            // | 'D' | int32 len | int16 colNum | column | ... | column |
594            // +-----+-----------+--------------+----+---+-----+--------+
595            //                                       |
596            //                          +-----------+v------+
597            //                          | int32 len | bytes |
598            //                          +-----------+-------+
599            BeMessage::DataRow(vals) => {
600                buf.put_u8(b'D');
601                write_body(buf, |buf| {
602                    buf.put_u16(vals.len() as u16); // num of cols
603                    for val_opt in vals.values() {
604                        if let Some(val) = val_opt {
605                            buf.put_u32(val.len() as u32);
606                            buf.put_slice(val);
607                        } else {
608                            buf.put_i32(-1);
609                        }
610                    }
611                    Ok(())
612                })
613                .unwrap();
614            }
615
616            // RowDescription
617            // +-----+-----------+--------------+-------+-----+-------+
618            // | 'T' | int32 len | int16 colNum | field | ... | field |
619            // +-----+-----------+--------------+----+--+-----+-------+
620            //                                       |
621            // +---------------+-------+-------+-----v-+-------+-------+-------+
622            // | str fieldName | int32 | int16 | int32 | int16 | int32 | int16 |
623            // +---------------+---+---+---+---+---+---+----+--+---+---+---+---+
624            //                     |       |       |        |      |       |
625            //                     v       |       v        v      |       v
626            //                tableOID     |    typeOID  typeLen   |   formatCode
627            //                             v                       v
628            //                        colAttrNum               typeModifier
629            BeMessage::RowDescription(row_descs) => {
630                buf.put_u8(b'T');
631                write_body(buf, |buf| {
632                    buf.put_i16(row_descs.len() as i16); // # of fields
633                    for pg_field in row_descs {
634                        write_cstr(buf, pg_field.get_name().as_bytes())?;
635                        buf.put_i32(pg_field.get_table_oid()); // table oid
636                        buf.put_i16(pg_field.get_col_attr_num()); // attnum
637                        buf.put_i32(pg_field.get_type_oid());
638                        buf.put_i16(pg_field.get_type_len());
639                        buf.put_i32(pg_field.get_type_modifier()); // typmod
640                        buf.put_i16(pg_field.get_format_code()); // format code
641                    }
642                    Ok(())
643                })?;
644            }
645            // ReadyForQuery
646            // +-----+----------+---------------------------+
647            // | 'Z' | int32(5) | byte1(transaction status) |
648            // +-----+----------+---------------------------+
649            BeMessage::ReadyForQuery(txn_status) => {
650                buf.put_u8(b'Z');
651                buf.put_i32(5);
652                // TODO: add transaction status
653                buf.put_u8(match txn_status {
654                    TransactionStatus::Idle => b'I',
655                    TransactionStatus::InTransaction => b'T',
656                    TransactionStatus::InFailedTransaction => b'E',
657                });
658            }
659
660            BeMessage::ParseComplete => {
661                buf.put_u8(b'1');
662                write_body(buf, |_| Ok(()))?;
663            }
664
665            BeMessage::BindComplete => {
666                buf.put_u8(b'2');
667                write_body(buf, |_| Ok(()))?;
668            }
669
670            BeMessage::CloseComplete => {
671                buf.put_u8(b'3');
672                write_body(buf, |_| Ok(()))?;
673            }
674
675            BeMessage::PortalSuspended => {
676                buf.put_u8(b's');
677                write_body(buf, |_| Ok(()))?;
678            }
679            // ParameterDescription
680            // +-----+-----------+--------------------+---------------+-----+---------------+
681            // | 't' | int32 len | int16 ParameterNum | int32 typeOID | ... | int32 typeOID |
682            // +-----+-----------+-----------------+--+---------------+-----+---------------+
683            BeMessage::ParameterDescription(para_descs) => {
684                buf.put_u8(b't');
685                write_body(buf, |buf| {
686                    buf.put_i16(para_descs.len() as i16);
687                    for oid in para_descs {
688                        buf.put_i32(*oid);
689                    }
690                    Ok(())
691                })?;
692            }
693
694            BeMessage::NoData => {
695                buf.put_u8(b'n');
696                write_body(buf, |_| Ok(())).unwrap();
697            }
698
699            BeMessage::EncryptionResponseSsl => {
700                buf.put_u8(b'S');
701            }
702
703            BeMessage::EncryptionResponseGss => {
704                buf.put_u8(b'G');
705            }
706
707            BeMessage::EncryptionResponseNo => {
708                buf.put_u8(b'N');
709            }
710
711            // EmptyQueryResponse
712            // +-----+----------+
713            // | 'I' | int32(4) |
714            // +-----+----------+
715            BeMessage::EmptyQueryResponse => {
716                buf.put_u8(b'I');
717                buf.put_i32(4);
718            }
719
720            BeMessage::ErrorResponse {
721                error,
722                pretty,
723                severity,
724            } => {
725                // 'E' signalizes ErrorResponse messages
726                buf.put_u8(b'E');
727                // Format the error as a pretty report.
728                let error_message = match severity {
729                    Some(severity) => {
730                        ErrorOrNoticeMessage::error_with_severity(error, pretty, severity)
731                    }
732                    None => ErrorOrNoticeMessage::error(error, pretty),
733                };
734                write_err_or_notice(buf, &error_message)?;
735            }
736
737            BeMessage::BackendKeyData((process_id, secret_key)) => {
738                buf.put_u8(b'K');
739                write_body(buf, |buf| {
740                    buf.put_i32(process_id);
741                    buf.put_i32(secret_key);
742                    Ok(())
743                })?;
744            }
745            BeMessage::CopyOutResponse(col_num) => {
746                buf.put_u8(b'H');
747                write_body(buf, |buf| {
748                    buf.put_i8(Format::Text.to_i8());
749                    buf.put_i16(col_num as _);
750                    for _ in 0..col_num {
751                        buf.put_i16(Format::Text.to_i8() as _);
752                    }
753                    Ok(())
754                })?;
755            }
756            BeMessage::CopyData(row) => {
757                buf.put_u8(b'd');
758                // As in https://www.postgresql.org/docs/current/sql-copy.html, the default format is TSV format
759                write_body(buf, |buf| {
760                    fn write_str_bytes(
761                        buf: &mut BytesMut,
762                        str_bytes: &Option<Bytes>,
763                    ) -> Result<()> {
764                        let Some(str_bytes) = str_bytes else {
765                            return Ok(());
766                        };
767                        let s = String::from_utf8_lossy(str_bytes);
768                        for c in s.as_str().chars() {
769                            // As suggested in https://en.wikipedia.org/wiki/Tab-separated_values
770                            // we only escape "\t\b\r\\"
771                            match c {
772                                '\t' => {
773                                    buf.put_slice(b"\\t");
774                                }
775                                '\n' => {
776                                    buf.put_slice(b"\\n");
777                                }
778                                '\r' => {
779                                    buf.put_slice(b"\\r");
780                                }
781                                '\\' => {
782                                    buf.put_slice(b"\\\\");
783                                }
784                                _ => {
785                                    std::fmt::Write::write_char(buf, c).map_err(|_| {
786                                        Error::other(anyhow!("failed to write_char [{c}]"))
787                                    })?;
788                                }
789                            }
790                        }
791                        Ok(())
792                    }
793                    match row.values() {
794                        [] => {}
795                        [first, rest @ ..] => {
796                            write_str_bytes(buf, first)?;
797
798                            for rest in rest {
799                                buf.put_u8(b'\t');
800                                write_str_bytes(buf, rest)?;
801                            }
802                        }
803                    }
804                    buf.put_u8(b'\n');
805                    Ok(())
806                })?;
807            }
808            BeMessage::CopyDone => {
809                buf.put_u8(b'c');
810                write_body(buf, |_| Ok(()))?;
811            }
812        }
813
814        Ok(())
815    }
816}
817
// Safe usize -> fixed-width integer conversion for protocol length fields, adapted from
// rust-postgres. Only `i32` is instantiated.
trait FromUsize: Sized {
    /// Converts `x`, failing with `InvalidInput` if it does not fit in `Self`.
    fn from_usize(x: usize) -> Result<Self>;
}

macro_rules! from_usize {
    ($t:ty) => {
        impl FromUsize for $t {
            #[inline]
            fn from_usize(x: usize) -> Result<$t> {
                <$t>::try_from(x)
                    .map_err(|_| Error::new(ErrorKind::InvalidInput, "value too large to transmit"))
            }
        }
    };
}

from_usize!(i32);
839
840/// Call f() to write body of the message and prepend it with 4-byte len as
841/// prescribed by the protocol. First write out body value and fill length value as i32 in front of
842/// it.
843fn write_body<F>(buf: &mut BytesMut, f: F) -> Result<()>
844where
845    F: FnOnce(&mut BytesMut) -> Result<()>,
846{
847    let base = buf.len();
848    buf.extend_from_slice(&[0; 4]);
849
850    f(buf)?;
851
852    let size = i32::from_usize(buf.len() - base)?;
853    BigEndian::write_i32(&mut buf[base..], size);
854    Ok(())
855}
856
857/// Safe write of s into buf as cstring (String in the protocol).
858fn write_cstr(buf: &mut BytesMut, s: &[u8]) -> Result<()> {
859    if s.contains(&0) {
860        return Err(Error::new(
861            ErrorKind::InvalidInput,
862            "string contains embedded null",
863        ));
864    }
865    buf.put_slice(s);
866    buf.put_u8(0);
867    Ok(())
868}
869
870/// Safe write error or notice message.
871fn write_err_or_notice(buf: &mut BytesMut, msg: &ErrorOrNoticeMessage<'_>) -> Result<()> {
872    write_body(buf, |buf| {
873        buf.put_u8(b'S'); // severity
874        write_cstr(buf, msg.severity.as_str().as_bytes())?;
875
876        buf.put_u8(b'C'); // SQLSTATE error code
877        write_cstr(buf, msg.error_code.sqlstate().as_bytes())?;
878
879        buf.put_u8(b'M'); // the message
880        write_cstr(buf, msg.message.as_bytes())?;
881
882        buf.put_u8(0); // terminator
883        Ok(())
884    })
885}
886
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use crate::pg_message::{FeQueryMessage, FeStartupMessage};

    #[test]
    fn test_get_sql() {
        // Well-formed: UTF-8 text with exactly one trailing NUL.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from_static(b"SELECT 1\0"),
        };
        assert_eq!(fe.get_sql().unwrap(), "SELECT 1");

        // Invalid UTF-8 before the terminator.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![255, 255, 255, 255, 255, 255, 0]),
        };
        assert!(fe.get_sql().is_err());

        // No NUL terminator at all.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![1, 2, 3, 4, 5, 6, 7, 8]),
        };
        assert!(fe.get_sql().is_err());

        // Interior NUL before the terminator.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from_static(b"SELECT\0 1\0"),
        };
        assert!(fe.get_sql().is_err());
    }

    #[test]
    fn test_startup_message_config() {
        let msg =
            FeStartupMessage::build_with_payload(b"user\0alice\0database\0dev\0\0").unwrap();
        assert_eq!(msg.config.len(), 2);
        assert_eq!(msg.config["user"], "alice");
        assert_eq!(msg.config["database"], "dev");
    }
}