pgwire/
pg_message.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::ffi::CStr;
17use std::io::{Error, ErrorKind, IoSlice, Result, Write};
18
19use anyhow::anyhow;
20use byteorder::{BigEndian, ByteOrder};
21/// Part of code learned from <https://github.com/zenithdb/zenith/blob/main/zenith_utils/src/pq_proto.rs>.
22use bytes::{Buf, BufMut, Bytes, BytesMut};
23use peekable::tokio::AsyncPeekable;
24use tokio::io::{AsyncRead, AsyncReadExt};
25
26use crate::error_or_notice::ErrorOrNoticeMessage;
27use crate::pg_field_descriptor::PgFieldDescriptor;
28use crate::pg_response::StatementType;
29use crate::types::{Format, Row};
30
/// Messages that can be sent from pg client to server. Implement `read`.
///
/// Tagged (regular) messages are decoded by `FeMessage::read_body`; the startup-phase
/// variants (`Ssl`, `Gss`, `Startup`, `CancelQuery`, `HealthCheck`) are produced by
/// `FeStartupMessage::read`.
#[derive(Debug)]
pub enum FeMessage {
    /// SSL negotiation request (startup protocol code 80877103).
    Ssl,
    /// GSSAPI encryption negotiation request (startup protocol code 80877104).
    Gss,
    /// Regular startup message (protocol number 196608) carrying the session config.
    Startup(FeStartupMessage),
    /// Simple query, tag `'Q'`.
    Query(FeQueryMessage),
    /// Parse (prepare statement), tag `'P'`.
    Parse(FeParseMessage),
    /// Password response, tag `'p'`.
    Password(FePasswordMessage),
    /// Describe a statement or portal, tag `'D'`.
    Describe(FeDescribeMessage),
    /// Bind parameters to a statement, tag `'B'`.
    Bind(FeBindMessage),
    /// Execute a portal, tag `'E'`.
    Execute(FeExecuteMessage),
    /// Close a statement or portal, tag `'C'`.
    Close(FeCloseMessage),
    /// Sync, tag `'S'`.
    Sync,
    /// Cancel request (startup protocol code 80877102).
    CancelQuery(FeCancelMessage),
    /// Terminate, tag `'X'`.
    Terminate,
    /// Flush, tag `'H'`.
    Flush,
    // special msg to detect health check, which represents the client immediately closes the connection cleanly without sending any data.
    HealthCheck,
    // The original message has been rejected due to server throttling. This is a placeholder message generated by server.
    ServerThrottle(ServerThrottleReason),
}
53
/// Why the server replaced a client message with `FeMessage::ServerThrottle`.
#[derive(Debug)]
pub enum ServerThrottleReason {
    /// The message itself was too large.
    TooLargeMessage,
    /// Too much memory was in use.
    // NOTE(review): the exact accounting behind both reasons lives outside this file — confirm there.
    TooManyMemoryUsage,
}
59
/// Startup message sent by the client at the beginning of a session.
#[derive(Debug)]
pub struct FeStartupMessage {
    /// Name/value pairs decoded from the NUL-separated startup payload.
    pub config: HashMap<String, String>,
}
64
65impl FeStartupMessage {
66    pub fn build_with_payload(payload: &[u8]) -> Result<Self> {
67        let config = match std::str::from_utf8(payload) {
68            Ok(v) => Ok(v.trim_end_matches('\0')),
69            Err(err) => Err(Error::new(
70                ErrorKind::InvalidInput,
71                anyhow!(err).context("Input end error"),
72            )),
73        }?;
74        let mut map = HashMap::new();
75        let config: Vec<&str> = config.split('\0').collect();
76        if config.len() % 2 == 1 {
77            return Err(Error::new(
78                ErrorKind::InvalidInput,
79                "Invalid input config: odd number of config pairs",
80            ));
81        }
82        config.chunks(2).for_each(|chunk| {
83            map.insert(chunk[0].to_owned(), chunk[1].to_owned());
84        });
85        Ok(FeStartupMessage { config: map })
86    }
87}
88
/// Query message contains the string sql.
#[derive(Debug)]
pub struct FeQueryMessage {
    /// Raw message payload. `get_sql` expects it to be a NUL-terminated UTF-8 string
    /// (the terminator is still part of these bytes).
    pub sql_bytes: Bytes,
}
94
/// Bind message: binds parameter values to a prepared statement, creating a portal.
#[derive(Debug)]
pub struct FeBindMessage {
    /// Format codes for the parameters, in the order sent by the client.
    pub param_format_codes: Vec<i16>,
    /// Format codes requested for the result columns.
    pub result_format_codes: Vec<i16>,

    /// Parameter values; `None` represents a value sent with length `-1`.
    pub params: Vec<Option<Bytes>>,
    /// Destination portal name, without the NUL terminator.
    pub portal_name: Bytes,
    /// Source prepared statement name, without the NUL terminator.
    pub statement_name: Bytes,
}
104
/// Execute message: runs a previously bound portal.
#[derive(Debug)]
pub struct FeExecuteMessage {
    /// Portal name, without the NUL terminator.
    pub portal_name: Bytes,
    /// Row limit as sent by the client (the `int32` following the portal name).
    pub max_rows: i32,
}
110
/// Parse message: asks the server to prepare a statement.
#[derive(Debug)]
pub struct FeParseMessage {
    /// Prepared statement name, without the NUL terminator.
    pub statement_name: Bytes,
    /// SQL text, without the NUL terminator (unlike `FeQueryMessage::sql_bytes`).
    pub sql_bytes: Bytes,
    /// Parameter type ids declared by the client, one `int32` each.
    pub type_ids: Vec<i32>,
}
117
/// Password response sent by the client during authentication.
#[derive(Debug)]
pub struct FePasswordMessage {
    /// Password bytes as sent by the client, without the NUL terminator.
    pub password: Bytes,
}
122
/// Describe message: requests a description of a prepared statement or portal.
#[derive(Debug)]
pub struct FeDescribeMessage {
    // 'S' to describe a prepared statement; or 'P' to describe a portal.
    pub kind: u8,
    /// Name of the statement or portal, without the NUL terminator.
    pub name: Bytes,
}
129
/// Close message: requests that a prepared statement or portal be closed.
#[derive(Debug)]
pub struct FeCloseMessage {
    /// First byte of the payload. Per the PostgreSQL protocol this is 'S' for a prepared
    /// statement or 'P' for a portal; it is not validated when parsing.
    pub kind: u8,
    /// Name of the statement or portal, without the NUL terminator.
    pub name: Bytes,
}
135
/// Cancel request: identifies the backend whose current query should be cancelled.
#[derive(Debug)]
pub struct FeCancelMessage {
    /// Process id of the target backend (first `int32` of the payload).
    pub target_process_id: i32,
    /// Secret key of the target backend (second `int32` of the payload).
    pub target_secret_key: i32,
}
141
142impl FeCancelMessage {
143    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
144        let target_process_id = buf.get_i32();
145        let target_secret_key = buf.get_i32();
146        Ok(FeMessage::CancelQuery(Self {
147            target_process_id,
148            target_secret_key,
149        }))
150    }
151}
152
153impl FeDescribeMessage {
154    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
155        let kind = buf.get_u8();
156        let name = read_null_terminated(&mut buf)?;
157
158        Ok(FeMessage::Describe(FeDescribeMessage { kind, name }))
159    }
160}
161
162impl FeBindMessage {
163    // Bind Message Header
164    // +-----+-----------+
165    // | 'B' | int32 len |
166    // +-----+-----------+
167    // Bind Message Body
168    // +----------------+---------------+
169    // | str portalname | str statement |
170    // +----------------+---------------+
171    // +---------------------+------------------+-------+
172    // | int16 numFormatCode | int16 FormatCode |  ...  |
173    // +---------------------+------------------+-------+
174    // +-----------------+-------------------+---------------+
175    // | int16 numParams | int32 valueLength |  byte value.. |
176    // +-----------------+-------------------+---------------+
177    // +----------------------------------+------------------+-------+
178    // | int16 numResultColumnFormatCodes | int16 FormatCode |  ...  |
179    // +----------------------------------+------------------+-------+
180    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
181        let portal_name = read_null_terminated(&mut buf)?;
182        let statement_name = read_null_terminated(&mut buf)?;
183
184        let len = buf.get_i16();
185        let param_format_codes = (0..len).map(|_| buf.get_i16()).collect();
186
187        // Read Params
188        let len = buf.get_i16();
189        let params = (0..len)
190            .map(|_| {
191                let val_len = buf.get_i32();
192                if val_len == -1 {
193                    None
194                } else {
195                    Some(buf.copy_to_bytes(val_len as usize))
196                }
197            })
198            .collect();
199
200        let len = buf.get_i16();
201        let result_format_codes = (0..len).map(|_| buf.get_i16()).collect();
202
203        Ok(FeMessage::Bind(FeBindMessage {
204            param_format_codes,
205            result_format_codes,
206            params,
207            portal_name,
208            statement_name,
209        }))
210    }
211}
212
213impl FeExecuteMessage {
214    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
215        let portal_name = read_null_terminated(&mut buf)?;
216        let max_rows = buf.get_i32();
217
218        Ok(FeMessage::Execute(FeExecuteMessage {
219            portal_name,
220            max_rows,
221        }))
222    }
223}
224
225impl FeParseMessage {
226    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
227        let statement_name = read_null_terminated(&mut buf)?;
228        let sql_bytes = read_null_terminated(&mut buf)?;
229        let nparams = buf.get_i16();
230
231        let type_ids: Vec<i32> = (0..nparams).map(|_| buf.get_i32()).collect();
232
233        Ok(FeMessage::Parse(FeParseMessage {
234            statement_name,
235            sql_bytes,
236            type_ids,
237        }))
238    }
239}
240
241impl FePasswordMessage {
242    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
243        let password = read_null_terminated(&mut buf)?;
244
245        Ok(FeMessage::Password(FePasswordMessage { password }))
246    }
247}
248
249impl FeQueryMessage {
250    pub fn get_sql(&self) -> Result<&str> {
251        match CStr::from_bytes_with_nul(&self.sql_bytes) {
252            Ok(cstr) => cstr.to_str().map_err(|err| {
253                Error::new(
254                    ErrorKind::InvalidInput,
255                    anyhow!(err).context("Invalid UTF-8 sequence"),
256                )
257            }),
258            Err(err) => Err(Error::new(
259                ErrorKind::InvalidInput,
260                anyhow!(err).context("Input end error"),
261            )),
262        }
263    }
264}
265
266impl FeCloseMessage {
267    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
268        let kind = buf.get_u8();
269        let name = read_null_terminated(&mut buf)?;
270        Ok(FeMessage::Close(FeCloseMessage { kind, name }))
271    }
272}
273
/// Header of a regular (tagged) frontend message, as returned by `FeMessage::read_header`.
#[derive(Clone)]
pub struct FeMessageHeader {
    /// Message type tag (first byte of the message).
    pub tag: u8,
    /// Number of payload bytes following the header: the length word on the wire minus
    /// the 4 bytes of the length word itself.
    pub payload_len: i32,
}
279
280impl FeMessage {
281    /// Read one message from the stream.
282    pub async fn read_header(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessageHeader> {
283        let tag = stream.read_u8().await?;
284        let len = stream.read_i32().await?;
285
286        let payload_len = len - 4;
287        Ok(FeMessageHeader { tag, payload_len })
288    }
289
290    /// Read one message from the stream.
291    pub async fn read_body(
292        stream: &mut (impl AsyncRead + Unpin),
293        header: FeMessageHeader,
294    ) -> Result<FeMessage> {
295        let FeMessageHeader { tag, payload_len } = header;
296        let mut payload: Vec<u8> = vec![0; payload_len as usize];
297        if payload_len > 0 {
298            stream.read_exact(&mut payload).await?;
299        }
300        let sql_bytes = Bytes::from(payload);
301        match tag {
302            b'Q' => Ok(FeMessage::Query(FeQueryMessage { sql_bytes })),
303            b'P' => FeParseMessage::parse(sql_bytes),
304            b'D' => FeDescribeMessage::parse(sql_bytes),
305            b'B' => FeBindMessage::parse(sql_bytes),
306            b'E' => FeExecuteMessage::parse(sql_bytes),
307            b'S' => Ok(FeMessage::Sync),
308            b'X' => Ok(FeMessage::Terminate),
309            b'C' => FeCloseMessage::parse(sql_bytes),
310            b'p' => FePasswordMessage::parse(sql_bytes),
311            b'H' => Ok(FeMessage::Flush),
312            _ => Err(std::io::Error::new(
313                ErrorKind::InvalidInput,
314                format!("Unsupported tag of regular message: {}", tag),
315            )),
316        }
317    }
318
319    pub async fn skip_body(
320        stream: &mut (impl AsyncRead + Unpin),
321        header: FeMessageHeader,
322    ) -> Result<()> {
323        let FeMessageHeader {
324            tag: _,
325            payload_len,
326        } = header;
327
328        if payload_len > 0 {
329            // Use smaller batches to process the payload instead of handling it all at once to minimize memory usage.
330            const BUF_SIZE: usize = 1024;
331            let mut buf: Vec<u8> = vec![0; BUF_SIZE];
332            for _ in 0..(payload_len as usize) / BUF_SIZE {
333                stream.read_exact(&mut buf).await?;
334            }
335            let remain = (payload_len as usize) % BUF_SIZE;
336            if remain > 0 {
337                buf.truncate(remain);
338                stream.read_exact(&mut buf).await?;
339            }
340        }
341        Ok(())
342    }
343}
344
345impl FeStartupMessage {
346    /// Read startup message from the stream.
347    pub async fn read(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessage> {
348        let mut stream = AsyncPeekable::new(stream);
349
350        if let Err(err) = stream.peek_exact(&mut [0; 1]).await {
351            // If the stream is empty, it can be a health check. Do not return error.
352            if err.kind() == ErrorKind::UnexpectedEof {
353                return Ok(FeMessage::HealthCheck);
354            } else {
355                return Err(err);
356            }
357        }
358
359        let len = stream.read_i32().await?;
360        let protocol_num = stream.read_i32().await?;
361        let payload_len = (len - 8) as usize;
362        if payload_len >= isize::MAX as usize {
363            return Err(std::io::Error::new(
364                ErrorKind::InvalidInput,
365                format!("Payload length has exceed usize::MAX {:?}", payload_len),
366            ));
367        }
368        let mut payload = vec![0; payload_len];
369        if payload_len > 0 {
370            stream.read_exact(&mut payload).await?;
371        }
372        match protocol_num {
373            // code from: https://www.postgresql.org/docs/current/protocol-message-formats.html
374            196608 => Ok(FeMessage::Startup(FeStartupMessage::build_with_payload(
375                &payload,
376            )?)),
377            80877104 => Ok(FeMessage::Gss),
378            80877103 => Ok(FeMessage::Ssl),
379            // Cancel request code.
380            80877102 => FeCancelMessage::parse(Bytes::from(payload)),
381            _ => Err(std::io::Error::new(
382                ErrorKind::InvalidInput,
383                format!(
384                    "Unsupported protocol number in start up msg {:?}",
385                    protocol_num
386                ),
387            )),
388        }
389    }
390}
391
392/// Continue read until reached a \0. Used in reading string from Bytes.
393fn read_null_terminated(buf: &mut Bytes) -> Result<Bytes> {
394    let mut result = BytesMut::new();
395
396    loop {
397        if !buf.has_remaining() {
398            panic!("no null-terminator in string");
399        }
400
401        let byte = buf.get_u8();
402
403        if byte == 0 {
404            break;
405        }
406        result.put_u8(byte);
407    }
408    Ok(result.freeze())
409}
410
/// Message sent from server to psql client. Implement `write` (how to serialize it into psql
/// buffer).
/// Ref: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Debug, Clone, Copy)]
pub enum BeMessage<'a> {
    /// Authentication succeeded (`'R'`, code 0).
    AuthenticationOk,
    /// Request a clear-text password (`'R'`, code 3).
    AuthenticationCleartextPassword,
    /// Request an MD5-hashed password (`'R'`, code 5); carries the 4-byte salt.
    AuthenticationMd5Password(&'a [u8; 4]),
    /// Command finished (`'C'`); the tag is built from the statement type and row count.
    CommandComplete(BeCommandCompleteMessage),
    /// Notice (`'N'`) with the given message text.
    NoticeResponse(&'a str),
    // Single byte - used in response to SSLRequest/GSSENCRequest.
    EncryptionResponseSsl,
    EncryptionResponseGss,
    EncryptionResponseNo,
    /// Response to an empty query string (`'I'`).
    EmptyQueryResponse,
    /// Parse finished (`'1'`).
    ParseComplete,
    /// Bind finished (`'2'`).
    BindComplete,
    /// Portal suspended (`'s'`).
    PortalSuspended,
    // array of parameter oid(i32)
    ParameterDescription(&'a [i32]),
    /// No row description available (`'n'`).
    NoData,
    /// One result row (`'D'`).
    DataRow(&'a Row),
    /// Run-time parameter report (`'S'`).
    ParameterStatus(BeParameterStatusMessage<'a>),
    /// Server is ready for a new query (`'Z'`), with the current transaction status.
    ReadyForQuery(TransactionStatus),
    /// Description of the result columns (`'T'`).
    RowDescription(&'a [PgFieldDescriptor]),
    /// Error report (`'E'`) built from the given error.
    ErrorResponse(&'a (dyn std::error::Error + Send + Sync + 'static)),
    /// Close finished (`'3'`).
    CloseComplete,

    // Copy
    /// Start of a copy-out (`'H'`); carries the number of columns, all sent in text format.
    CopyOutResponse(usize),
    /// One copied row (`'d'`), written as a tab-separated line.
    CopyData(&'a Row),
    /// End of copy data (`'c'`).
    CopyDone,

    // 0: process ID, 1: secret key
    BackendKeyData((i32, i32)),
}
447
/// Run-time parameter reported to the client via `BeMessage::ParameterStatus`.
///
/// Each variant carries the value; the parameter name on the wire is fixed per variant.
#[derive(Debug, Copy, Clone)]
pub enum BeParameterStatusMessage<'a> {
    /// Sent as `client_encoding`.
    ClientEncoding(&'a str),
    /// Sent as `standard_conforming_strings`.
    StandardConformingString(&'a str),
    /// Sent as `server_version`.
    ServerVersion(&'a str),
    /// Sent as `application_name`.
    ApplicationName(&'a str),
    /// Sent as `TimeZone` (case-sensitive).
    TimeZone(&'a str),
}
456
/// Payload of `BeMessage::CommandComplete`, used to build the command tag.
#[derive(Debug, Copy, Clone)]
pub struct BeCommandCompleteMessage {
    /// Type of the statement that completed.
    pub stmt_type: StatementType,
    /// Row count appended to the tag when `stmt_type.is_command()` is true.
    pub rows_cnt: i32,
}
462
/// Transaction status indicator sent in `ReadyForQuery`.
#[derive(Debug, Clone, Copy)]
pub enum TransactionStatus {
    /// Not in a transaction block; sent as `'I'`.
    Idle,
    /// Inside a transaction block; sent as `'T'`.
    InTransaction,
    /// Inside a failed transaction block; sent as `'E'`.
    InFailedTransaction,
}
469
470impl BeMessage<'_> {
471    /// Write message to the given buf.
472    pub fn write(buf: &mut BytesMut, message: BeMessage<'_>) -> Result<()> {
473        match message {
474            // AuthenticationOk
475            // +-----+----------+-----------+
476            // | 'R' | int32(8) | int32(0)  |
477            // +-----+----------+-----------+
478            BeMessage::AuthenticationOk => {
479                buf.put_u8(b'R');
480                buf.put_i32(8);
481                buf.put_i32(0);
482            }
483
484            // AuthenticationCleartextPassword
485            // +-----+----------+-----------+
486            // | 'R' | int32(8) | int32(3)  |
487            // +-----+----------+-----------+
488            BeMessage::AuthenticationCleartextPassword => {
489                buf.put_u8(b'R');
490                buf.put_i32(8);
491                buf.put_i32(3);
492            }
493
494            // AuthenticationMD5Password
495            // +-----+----------+-----------+----------------+
496            // | 'R' | int32(12) | int32(5)  |  Byte4(salt)  |
497            // +-----+----------+-----------+----------------+
498            //
499            // The 4-byte random salt will be used by client to send encrypted password as
500            // concat('md5', md5(concat(md5(concat(password, username)), random-salt))).
501            BeMessage::AuthenticationMd5Password(salt) => {
502                buf.put_u8(b'R');
503                buf.put_i32(12);
504                buf.put_i32(5);
505                buf.put_slice(&salt[..]);
506            }
507
508            // ParameterStatus
509            // +-----+-----------+----------+------+-----------+------+
510            // | 'S' | int32 len | str name | '\0' | str value | '\0' |
511            // +-----+-----------+----------+------+-----------+------+
512            BeMessage::ParameterStatus(param) => {
513                use BeParameterStatusMessage::*;
514                let [name, value] = match param {
515                    ClientEncoding(val) => [b"client_encoding", val.as_bytes()],
516                    StandardConformingString(val) => {
517                        [b"standard_conforming_strings", val.as_bytes()]
518                    }
519                    ServerVersion(val) => [b"server_version", val.as_bytes()],
520                    ApplicationName(val) => [b"application_name", val.as_bytes()],
521                    // psycopg3 is case-sensitive, so we use "TimeZone" instead of "timezone" #18079
522                    TimeZone(val) => [b"TimeZone", val.as_bytes()],
523                };
524
525                // Parameter names and values are passed as null-terminated strings
526                let iov = &mut [name, b"\0", value, b"\0"].map(IoSlice::new);
527                let mut buffer = vec![];
528                let cnt = buffer.write_vectored(iov).unwrap();
529
530                buf.put_u8(b'S');
531                write_body(buf, |stream| {
532                    stream.put_slice(&buffer[..cnt]);
533                    Ok(())
534                })
535                .unwrap();
536            }
537
538            // CommandComplete
539            // +-----+-----------+-----------------+
540            // | 'C' | int32 len | str commandTag  |
541            // +-----+-----------+-----------------+
542            BeMessage::CommandComplete(cmd) => {
543                let rows_cnt = cmd.rows_cnt;
544                let mut stmt_type = cmd.stmt_type;
545                let mut tag = "".to_owned();
546                stmt_type = match stmt_type {
547                    StatementType::INSERT_RETURNING => StatementType::INSERT,
548                    StatementType::DELETE_RETURNING => StatementType::DELETE,
549                    StatementType::UPDATE_RETURNING => StatementType::UPDATE,
550                    s => s,
551                };
552                tag.push_str(&stmt_type.to_string());
553                if stmt_type == StatementType::INSERT {
554                    tag.push_str(" 0");
555                }
556                if stmt_type.is_command() {
557                    tag.push(' ');
558                    tag.push_str(&rows_cnt.to_string());
559                }
560                buf.put_u8(b'C');
561                write_body(buf, |buf| {
562                    write_cstr(buf, tag.as_bytes())?;
563                    Ok(())
564                })?;
565            }
566
567            // NoticeResponse
568            // +-----+-----------+------------------+------------------+
569            // | 'N' | int32 len | byte1 field type | str field value  |
570            // +-----+-----------+------------------+-+----------------+
571            // description of the fields can be found here:
572            // https://www.postgresql.org/docs/current/protocol-error-fields.html
573            BeMessage::NoticeResponse(notice) => {
574                buf.put_u8(b'N');
575                write_err_or_notice(buf, &ErrorOrNoticeMessage::notice(notice))?;
576            }
577
578            // DataRow
579            // +-----+-----------+--------------+--------+-----+--------+
580            // | 'D' | int32 len | int16 colNum | column | ... | column |
581            // +-----+-----------+--------------+----+---+-----+--------+
582            //                                       |
583            //                          +-----------+v------+
584            //                          | int32 len | bytes |
585            //                          +-----------+-------+
586            BeMessage::DataRow(vals) => {
587                buf.put_u8(b'D');
588                write_body(buf, |buf| {
589                    buf.put_u16(vals.len() as u16); // num of cols
590                    for val_opt in vals.values() {
591                        if let Some(val) = val_opt {
592                            buf.put_u32(val.len() as u32);
593                            buf.put_slice(val);
594                        } else {
595                            buf.put_i32(-1);
596                        }
597                    }
598                    Ok(())
599                })
600                .unwrap();
601            }
602
603            // RowDescription
604            // +-----+-----------+--------------+-------+-----+-------+
605            // | 'T' | int32 len | int16 colNum | field | ... | field |
606            // +-----+-----------+--------------+----+--+-----+-------+
607            //                                       |
608            // +---------------+-------+-------+-----v-+-------+-------+-------+
609            // | str fieldName | int32 | int16 | int32 | int16 | int32 | int16 |
610            // +---------------+---+---+---+---+---+---+----+--+---+---+---+---+
611            //                     |       |       |        |      |       |
612            //                     v       |       v        v      |       v
613            //                tableOID     |    typeOID  typeLen   |   formatCode
614            //                             v                       v
615            //                        colAttrNum               typeModifier
616            BeMessage::RowDescription(row_descs) => {
617                buf.put_u8(b'T');
618                write_body(buf, |buf| {
619                    buf.put_i16(row_descs.len() as i16); // # of fields
620                    for pg_field in row_descs {
621                        write_cstr(buf, pg_field.get_name().as_bytes())?;
622                        buf.put_i32(pg_field.get_table_oid()); // table oid
623                        buf.put_i16(pg_field.get_col_attr_num()); // attnum
624                        buf.put_i32(pg_field.get_type_oid());
625                        buf.put_i16(pg_field.get_type_len());
626                        buf.put_i32(pg_field.get_type_modifier()); // typmod
627                        buf.put_i16(pg_field.get_format_code()); // format code
628                    }
629                    Ok(())
630                })?;
631            }
632            // ReadyForQuery
633            // +-----+----------+---------------------------+
634            // | 'Z' | int32(5) | byte1(transaction status) |
635            // +-----+----------+---------------------------+
636            BeMessage::ReadyForQuery(txn_status) => {
637                buf.put_u8(b'Z');
638                buf.put_i32(5);
639                // TODO: add transaction status
640                buf.put_u8(match txn_status {
641                    TransactionStatus::Idle => b'I',
642                    TransactionStatus::InTransaction => b'T',
643                    TransactionStatus::InFailedTransaction => b'E',
644                });
645            }
646
647            BeMessage::ParseComplete => {
648                buf.put_u8(b'1');
649                write_body(buf, |_| Ok(()))?;
650            }
651
652            BeMessage::BindComplete => {
653                buf.put_u8(b'2');
654                write_body(buf, |_| Ok(()))?;
655            }
656
657            BeMessage::CloseComplete => {
658                buf.put_u8(b'3');
659                write_body(buf, |_| Ok(()))?;
660            }
661
662            BeMessage::PortalSuspended => {
663                buf.put_u8(b's');
664                write_body(buf, |_| Ok(()))?;
665            }
666            // ParameterDescription
667            // +-----+-----------+--------------------+---------------+-----+---------------+
668            // | 't' | int32 len | int16 ParameterNum | int32 typeOID | ... | int32 typeOID |
669            // +-----+-----------+-----------------+--+---------------+-----+---------------+
670            BeMessage::ParameterDescription(para_descs) => {
671                buf.put_u8(b't');
672                write_body(buf, |buf| {
673                    buf.put_i16(para_descs.len() as i16);
674                    for oid in para_descs {
675                        buf.put_i32(*oid);
676                    }
677                    Ok(())
678                })?;
679            }
680
681            BeMessage::NoData => {
682                buf.put_u8(b'n');
683                write_body(buf, |_| Ok(())).unwrap();
684            }
685
686            BeMessage::EncryptionResponseSsl => {
687                buf.put_u8(b'S');
688            }
689
690            BeMessage::EncryptionResponseGss => {
691                buf.put_u8(b'G');
692            }
693
694            BeMessage::EncryptionResponseNo => {
695                buf.put_u8(b'N');
696            }
697
698            // EmptyQueryResponse
699            // +-----+----------+
700            // | 'I' | int32(4) |
701            // +-----+----------+
702            BeMessage::EmptyQueryResponse => {
703                buf.put_u8(b'I');
704                buf.put_i32(4);
705            }
706
707            BeMessage::ErrorResponse(error) => {
708                // 'E' signalizes ErrorResponse messages
709                buf.put_u8(b'E');
710                // Format the error as a pretty report.
711                write_err_or_notice(buf, &ErrorOrNoticeMessage::error(error))?;
712            }
713
714            BeMessage::BackendKeyData((process_id, secret_key)) => {
715                buf.put_u8(b'K');
716                write_body(buf, |buf| {
717                    buf.put_i32(process_id);
718                    buf.put_i32(secret_key);
719                    Ok(())
720                })?;
721            }
722            BeMessage::CopyOutResponse(col_num) => {
723                buf.put_u8(b'H');
724                write_body(buf, |buf| {
725                    buf.put_i8(Format::Text.to_i8());
726                    buf.put_i16(col_num as _);
727                    for _ in 0..col_num {
728                        buf.put_i16(Format::Text.to_i8() as _);
729                    }
730                    Ok(())
731                })?;
732            }
733            BeMessage::CopyData(row) => {
734                buf.put_u8(b'd');
735                // As in https://www.postgresql.org/docs/current/sql-copy.html, the default format is TSV format
736                write_body(buf, |buf| {
737                    fn write_str_bytes(
738                        buf: &mut BytesMut,
739                        str_bytes: &Option<Bytes>,
740                    ) -> Result<()> {
741                        let Some(str_bytes) = str_bytes else {
742                            return Ok(());
743                        };
744                        let s = String::from_utf8_lossy(str_bytes);
745                        for c in s.as_str().chars() {
746                            // As suggested in https://en.wikipedia.org/wiki/Tab-separated_values
747                            // we only escape "\t\b\r\\"
748                            match c {
749                                '\t' => {
750                                    buf.put_slice(b"\\t");
751                                }
752                                '\n' => {
753                                    buf.put_slice(b"\\n");
754                                }
755                                '\r' => {
756                                    buf.put_slice(b"\\r");
757                                }
758                                '\\' => {
759                                    buf.put_slice(b"\\\\");
760                                }
761                                _ => {
762                                    std::fmt::Write::write_char(buf, c).map_err(|_| {
763                                        Error::other(anyhow!("failed to write_char [{c}]"))
764                                    })?;
765                                }
766                            }
767                        }
768                        Ok(())
769                    }
770                    match row.values() {
771                        [] => {}
772                        [first, rest @ ..] => {
773                            write_str_bytes(buf, first)?;
774
775                            for rest in rest {
776                                buf.put_u8(b'\t');
777                                write_str_bytes(buf, rest)?;
778                            }
779                        }
780                    }
781                    buf.put_u8(b'\n');
782                    Ok(())
783                })?;
784            }
785            BeMessage::CopyDone => {
786                buf.put_u8(b'c');
787                write_body(buf, |_| Ok(()))?;
788            }
789        }
790
791        Ok(())
792    }
793}
794
// Checked usize -> i32|i16 narrowing, from rust-postgres
trait FromUsize: Sized {
    /// Converts `x` into `Self`, failing with `InvalidInput` when it does not fit.
    fn from_usize(x: usize) -> Result<Self>;
}

// Implements `FromUsize` for one integer type.
macro_rules! from_usize {
    ($t:ty) => {
        impl FromUsize for $t {
            #[inline]
            fn from_usize(x: usize) -> Result<$t> {
                let limit = <$t>::MAX as usize;
                if x <= limit {
                    Ok(x as $t)
                } else {
                    Err(Error::new(
                        ErrorKind::InvalidInput,
                        "value too large to transmit",
                    ))
                }
            }
        }
    };
}

from_usize!(i32);
816
817/// Call f() to write body of the message and prepend it with 4-byte len as
818/// prescribed by the protocol. First write out body value and fill length value as i32 in front of
819/// it.
820fn write_body<F>(buf: &mut BytesMut, f: F) -> Result<()>
821where
822    F: FnOnce(&mut BytesMut) -> Result<()>,
823{
824    let base = buf.len();
825    buf.extend_from_slice(&[0; 4]);
826
827    f(buf)?;
828
829    let size = i32::from_usize(buf.len() - base)?;
830    BigEndian::write_i32(&mut buf[base..], size);
831    Ok(())
832}
833
834/// Safe write of s into buf as cstring (String in the protocol).
835fn write_cstr(buf: &mut BytesMut, s: &[u8]) -> Result<()> {
836    if s.contains(&0) {
837        return Err(Error::new(
838            ErrorKind::InvalidInput,
839            "string contains embedded null",
840        ));
841    }
842    buf.put_slice(s);
843    buf.put_u8(0);
844    Ok(())
845}
846
847/// Safe write error or notice message.
848fn write_err_or_notice(buf: &mut BytesMut, msg: &ErrorOrNoticeMessage<'_>) -> Result<()> {
849    write_body(buf, |buf| {
850        buf.put_u8(b'S'); // severity
851        write_cstr(buf, msg.severity.as_str().as_bytes())?;
852
853        buf.put_u8(b'C'); // SQLSTATE error code
854        write_cstr(buf, msg.error_code.sqlstate().as_bytes())?;
855
856        buf.put_u8(b'M'); // the message
857        write_cstr(buf, msg.message.as_bytes())?;
858
859        buf.put_u8(0); // terminator
860        Ok(())
861    })
862}
863
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use crate::pg_message::FeQueryMessage;

    /// `get_sql` must reject payloads that are not NUL-terminated valid UTF-8.
    #[test]
    fn test_get_sql() {
        let malformed_payloads: [Vec<u8>; 2] = [
            // NUL-terminated, but the bytes are not valid UTF-8.
            vec![255, 255, 255, 255, 255, 255, 0],
            // No NUL terminator at all.
            vec![1, 2, 3, 4, 5, 6, 7, 8],
        ];

        for payload in malformed_payloads {
            let fe = FeQueryMessage {
                sql_bytes: Bytes::from(payload),
            };
            assert!(fe.get_sql().is_err(), "{}", true);
        }
    }
}