pgwire/
pg_message.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::ffi::CStr;
17use std::io::{Error, ErrorKind, IoSlice, Result, Write};
18
19use anyhow::anyhow;
20use byteorder::{BigEndian, ByteOrder};
21/// Part of code learned from <https://github.com/zenithdb/zenith/blob/main/zenith_utils/src/pq_proto.rs>.
22use bytes::{Buf, BufMut, Bytes, BytesMut};
23use peekable::tokio::AsyncPeekable;
24use tokio::io::{AsyncRead, AsyncReadExt};
25
26use crate::error_or_notice::ErrorOrNoticeMessage;
27use crate::pg_field_descriptor::PgFieldDescriptor;
28use crate::pg_response::StatementType;
29use crate::pg_server::BoxedError;
30use crate::types::Row;
31
/// Messages that can be sent from pg client to server.
///
/// Produced by [`FeStartupMessage::read`] (first message of a connection) and
/// [`FeMessage::read_body`] (regular messages identified by their tag byte).
#[derive(Debug)]
pub enum FeMessage {
    /// SSLRequest (startup protocol number 80877103).
    Ssl,
    /// GSSENCRequest (startup protocol number 80877104).
    Gss,
    /// StartupMessage (protocol number 196608) with its name/value parameters.
    Startup(FeStartupMessage),
    /// Simple query, tag `'Q'`.
    Query(FeQueryMessage),
    /// Parse (prepare a statement), tag `'P'`.
    Parse(FeParseMessage),
    /// Password message, tag `'p'`.
    Password(FePasswordMessage),
    /// Describe a statement or portal, tag `'D'`.
    Describe(FeDescribeMessage),
    /// Bind, tag `'B'`.
    Bind(FeBindMessage),
    /// Execute a portal, tag `'E'`.
    Execute(FeExecuteMessage),
    /// Close a statement or portal, tag `'C'`.
    Close(FeCloseMessage),
    /// Sync, tag `'S'` (no body is parsed).
    Sync,
    /// CancelRequest (startup protocol number 80877102).
    CancelQuery(FeCancelMessage),
    /// Terminate, tag `'X'`.
    Terminate,
    /// Flush, tag `'H'`.
    Flush,
    /// Special message used to detect health checks: the client immediately closed the
    /// connection cleanly without sending any data.
    HealthCheck,
    /// Placeholder generated by the server when the original message has been rejected due to
    /// server throttling.
    ServerThrottle(ServerThrottleReason),
}
54
/// Why a message was replaced by [`FeMessage::ServerThrottle`].
#[derive(Debug)]
pub enum ServerThrottleReason {
    /// Presumably: the message exceeded a configured size limit (decided by the caller, not in
    /// this file).
    TooLargeMessage,
    /// Presumably: the server's memory usage was too high to accept the message (decided by
    /// the caller, not in this file).
    TooManyMemoryUsage,
}
60
61#[derive(Debug)]
62pub struct FeStartupMessage {
63    pub config: HashMap<String, String>,
64}
65
66impl FeStartupMessage {
67    pub fn build_with_payload(payload: &[u8]) -> Result<Self> {
68        let config = match std::str::from_utf8(payload) {
69            Ok(v) => Ok(v.trim_end_matches('\0')),
70            Err(err) => Err(Error::new(
71                ErrorKind::InvalidInput,
72                anyhow!(err).context("Input end error"),
73            )),
74        }?;
75        let mut map = HashMap::new();
76        let config: Vec<&str> = config.split('\0').collect();
77        if config.len() % 2 == 1 {
78            return Err(Error::new(
79                ErrorKind::InvalidInput,
80                "Invalid input config: odd number of config pairs",
81            ));
82        }
83        config.chunks(2).for_each(|chunk| {
84            map.insert(chunk[0].to_owned(), chunk[1].to_owned());
85        });
86        Ok(FeStartupMessage { config: map })
87    }
88}
89
/// Simple query (`'Q'`) message containing the SQL string.
#[derive(Debug)]
pub struct FeQueryMessage {
    /// Raw message body, still including its trailing `\0`; decode it with
    /// [`FeQueryMessage::get_sql`].
    pub sql_bytes: Bytes,
}
95
/// Parsed Bind (`'B'`) message; see [`FeBindMessage::parse`] for the wire layout.
#[derive(Debug)]
pub struct FeBindMessage {
    /// Parameter format codes, in the order sent by the client.
    pub param_format_codes: Vec<i16>,
    /// Format codes requested for the result columns, in the order sent by the client.
    pub result_format_codes: Vec<i16>,

    /// Parameter values; `None` when the client sent a length of -1 (NULL).
    pub params: Vec<Option<Bytes>>,
    /// Target portal name, without its trailing `\0`.
    pub portal_name: Bytes,
    /// Source prepared-statement name, without its trailing `\0`.
    pub statement_name: Bytes,
}
105
/// Parsed Execute (`'E'`) message.
#[derive(Debug)]
pub struct FeExecuteMessage {
    /// Portal to execute, without its trailing `\0`.
    pub portal_name: Bytes,
    /// Row limit sent by the client (interpretation is left to the caller).
    pub max_rows: i32,
}
111
/// Parsed Parse (`'P'`) message.
#[derive(Debug)]
pub struct FeParseMessage {
    /// Name of the prepared statement to create, without its trailing `\0`.
    pub statement_name: Bytes,
    /// SQL text, without its trailing `\0` (unlike [`FeQueryMessage::sql_bytes`]).
    pub sql_bytes: Bytes,
    /// Parameter type ids (OIDs) as sent by the client, one per declared parameter.
    pub type_ids: Vec<i32>,
}
118
/// Parsed password (`'p'`) message.
#[derive(Debug)]
pub struct FePasswordMessage {
    /// Password string as sent by the client, without its trailing `\0`.
    pub password: Bytes,
}
123
/// Parsed Describe (`'D'`) message.
#[derive(Debug)]
pub struct FeDescribeMessage {
    /// `'S'` to describe a prepared statement; or `'P'` to describe a portal.
    /// Not validated by [`FeDescribeMessage::parse`].
    pub kind: u8,
    /// Name of the statement/portal, without its trailing `\0`.
    pub name: Bytes,
}
130
/// Parsed Close (`'C'`) message.
#[derive(Debug)]
pub struct FeCloseMessage {
    /// Kind byte as sent by the client; presumably `'S'` (statement) or `'P'` (portal) as for
    /// Describe — not validated by [`FeCloseMessage::parse`].
    pub kind: u8,
    /// Name of the statement/portal to close, without its trailing `\0`.
    pub name: Bytes,
}
136
/// Parsed CancelRequest: identifies the backend whose running query should be cancelled.
#[derive(Debug)]
pub struct FeCancelMessage {
    /// Process id of the target backend.
    pub target_process_id: i32,
    /// Secret key of the target backend.
    pub target_secret_key: i32,
}
142
143impl FeCancelMessage {
144    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
145        let target_process_id = buf.get_i32();
146        let target_secret_key = buf.get_i32();
147        Ok(FeMessage::CancelQuery(Self {
148            target_process_id,
149            target_secret_key,
150        }))
151    }
152}
153
154impl FeDescribeMessage {
155    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
156        let kind = buf.get_u8();
157        let name = read_null_terminated(&mut buf)?;
158
159        Ok(FeMessage::Describe(FeDescribeMessage { kind, name }))
160    }
161}
162
163impl FeBindMessage {
164    // Bind Message Header
165    // +-----+-----------+
166    // | 'B' | int32 len |
167    // +-----+-----------+
168    // Bind Message Body
169    // +----------------+---------------+
170    // | str portalname | str statement |
171    // +----------------+---------------+
172    // +---------------------+------------------+-------+
173    // | int16 numFormatCode | int16 FormatCode |  ...  |
174    // +---------------------+------------------+-------+
175    // +-----------------+-------------------+---------------+
176    // | int16 numParams | int32 valueLength |  byte value.. |
177    // +-----------------+-------------------+---------------+
178    // +----------------------------------+------------------+-------+
179    // | int16 numResultColumnFormatCodes | int16 FormatCode |  ...  |
180    // +----------------------------------+------------------+-------+
181    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
182        let portal_name = read_null_terminated(&mut buf)?;
183        let statement_name = read_null_terminated(&mut buf)?;
184
185        let len = buf.get_i16();
186        let param_format_codes = (0..len).map(|_| buf.get_i16()).collect();
187
188        // Read Params
189        let len = buf.get_i16();
190        let params = (0..len)
191            .map(|_| {
192                let val_len = buf.get_i32();
193                if val_len == -1 {
194                    None
195                } else {
196                    Some(buf.copy_to_bytes(val_len as usize))
197                }
198            })
199            .collect();
200
201        let len = buf.get_i16();
202        let result_format_codes = (0..len).map(|_| buf.get_i16()).collect();
203
204        Ok(FeMessage::Bind(FeBindMessage {
205            param_format_codes,
206            result_format_codes,
207            params,
208            portal_name,
209            statement_name,
210        }))
211    }
212}
213
214impl FeExecuteMessage {
215    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
216        let portal_name = read_null_terminated(&mut buf)?;
217        let max_rows = buf.get_i32();
218
219        Ok(FeMessage::Execute(FeExecuteMessage {
220            portal_name,
221            max_rows,
222        }))
223    }
224}
225
226impl FeParseMessage {
227    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
228        let statement_name = read_null_terminated(&mut buf)?;
229        let sql_bytes = read_null_terminated(&mut buf)?;
230        let nparams = buf.get_i16();
231
232        let type_ids: Vec<i32> = (0..nparams).map(|_| buf.get_i32()).collect();
233
234        Ok(FeMessage::Parse(FeParseMessage {
235            statement_name,
236            sql_bytes,
237            type_ids,
238        }))
239    }
240}
241
242impl FePasswordMessage {
243    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
244        let password = read_null_terminated(&mut buf)?;
245
246        Ok(FeMessage::Password(FePasswordMessage { password }))
247    }
248}
249
250impl FeQueryMessage {
251    pub fn get_sql(&self) -> Result<&str> {
252        match CStr::from_bytes_with_nul(&self.sql_bytes) {
253            Ok(cstr) => cstr.to_str().map_err(|err| {
254                Error::new(
255                    ErrorKind::InvalidInput,
256                    anyhow!(err).context("Invalid UTF-8 sequence"),
257                )
258            }),
259            Err(err) => Err(Error::new(
260                ErrorKind::InvalidInput,
261                anyhow!(err).context("Input end error"),
262            )),
263        }
264    }
265}
266
267impl FeCloseMessage {
268    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
269        let kind = buf.get_u8();
270        let name = read_null_terminated(&mut buf)?;
271        Ok(FeMessage::Close(FeCloseMessage { kind, name }))
272    }
273}
274
/// Header of a regular (post-startup) frontend message, as returned by
/// [`FeMessage::read_header`].
#[derive(Clone)]
pub struct FeMessageHeader {
    /// Message type byte, e.g. `b'Q'` for a simple query.
    pub tag: u8,
    /// Number of body bytes that follow: the wire length field minus the 4 bytes of the
    /// length field itself.
    pub payload_len: i32,
}
280
281impl FeMessage {
282    /// Read one message from the stream.
283    pub async fn read_header(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessageHeader> {
284        let tag = stream.read_u8().await?;
285        let len = stream.read_i32().await?;
286
287        let payload_len = len - 4;
288        Ok(FeMessageHeader { tag, payload_len })
289    }
290
291    /// Read one message from the stream.
292    pub async fn read_body(
293        stream: &mut (impl AsyncRead + Unpin),
294        header: FeMessageHeader,
295    ) -> Result<FeMessage> {
296        let FeMessageHeader { tag, payload_len } = header;
297        let mut payload: Vec<u8> = vec![0; payload_len as usize];
298        if payload_len > 0 {
299            stream.read_exact(&mut payload).await?;
300        }
301        let sql_bytes = Bytes::from(payload);
302        match tag {
303            b'Q' => Ok(FeMessage::Query(FeQueryMessage { sql_bytes })),
304            b'P' => FeParseMessage::parse(sql_bytes),
305            b'D' => FeDescribeMessage::parse(sql_bytes),
306            b'B' => FeBindMessage::parse(sql_bytes),
307            b'E' => FeExecuteMessage::parse(sql_bytes),
308            b'S' => Ok(FeMessage::Sync),
309            b'X' => Ok(FeMessage::Terminate),
310            b'C' => FeCloseMessage::parse(sql_bytes),
311            b'p' => FePasswordMessage::parse(sql_bytes),
312            b'H' => Ok(FeMessage::Flush),
313            _ => Err(std::io::Error::new(
314                ErrorKind::InvalidInput,
315                format!("Unsupported tag of regular message: {}", tag),
316            )),
317        }
318    }
319
320    pub async fn skip_body(
321        stream: &mut (impl AsyncRead + Unpin),
322        header: FeMessageHeader,
323    ) -> Result<()> {
324        let FeMessageHeader {
325            tag: _,
326            payload_len,
327        } = header;
328
329        if payload_len > 0 {
330            // Use smaller batches to process the payload instead of handling it all at once to minimize memory usage.
331            const BUF_SIZE: usize = 1024;
332            let mut buf: Vec<u8> = vec![0; BUF_SIZE];
333            for _ in 0..(payload_len as usize) / BUF_SIZE {
334                stream.read_exact(&mut buf).await?;
335            }
336            let remain = (payload_len as usize) % BUF_SIZE;
337            if remain > 0 {
338                buf.truncate(remain);
339                stream.read_exact(&mut buf).await?;
340            }
341        }
342        Ok(())
343    }
344}
345
346impl FeStartupMessage {
347    /// Read startup message from the stream.
348    pub async fn read(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessage> {
349        let mut stream = AsyncPeekable::new(stream);
350
351        if let Err(err) = stream.peek_exact(&mut [0; 1]).await {
352            // If the stream is empty, it can be a health check. Do not return error.
353            if err.kind() == ErrorKind::UnexpectedEof {
354                return Ok(FeMessage::HealthCheck);
355            } else {
356                return Err(err);
357            }
358        }
359
360        let len = stream.read_i32().await?;
361        let protocol_num = stream.read_i32().await?;
362        let payload_len = (len - 8) as usize;
363        if payload_len >= isize::MAX as usize {
364            return Err(std::io::Error::new(
365                ErrorKind::InvalidInput,
366                format!("Payload length has exceed usize::MAX {:?}", payload_len),
367            ));
368        }
369        let mut payload = vec![0; payload_len];
370        if payload_len > 0 {
371            stream.read_exact(&mut payload).await?;
372        }
373        match protocol_num {
374            // code from: https://www.postgresql.org/docs/current/protocol-message-formats.html
375            196608 => Ok(FeMessage::Startup(FeStartupMessage::build_with_payload(
376                &payload,
377            )?)),
378            80877104 => Ok(FeMessage::Gss),
379            80877103 => Ok(FeMessage::Ssl),
380            // Cancel request code.
381            80877102 => FeCancelMessage::parse(Bytes::from(payload)),
382            _ => Err(std::io::Error::new(
383                ErrorKind::InvalidInput,
384                format!(
385                    "Unsupported protocol number in start up msg {:?}",
386                    protocol_num
387                ),
388            )),
389        }
390    }
391}
392
393/// Continue read until reached a \0. Used in reading string from Bytes.
394fn read_null_terminated(buf: &mut Bytes) -> Result<Bytes> {
395    let mut result = BytesMut::new();
396
397    loop {
398        if !buf.has_remaining() {
399            panic!("no null-terminator in string");
400        }
401
402        let byte = buf.get_u8();
403
404        if byte == 0 {
405            break;
406        }
407        result.put_u8(byte);
408    }
409    Ok(result.freeze())
410}
411
/// Message sent from the server to a pg client. [`BeMessage::write`] serializes it into the
/// wire format.
/// Ref: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Debug)]
pub enum BeMessage<'a> {
    /// `'R'` with authentication code 0.
    AuthenticationOk,
    /// `'R'` with authentication code 3 (request a cleartext password).
    AuthenticationCleartextPassword,
    /// `'R'` with authentication code 5 followed by the given 4-byte salt.
    AuthenticationMd5Password(&'a [u8; 4]),
    /// `'C'` with the command tag derived from the statement type and row count.
    CommandComplete(BeCommandCompleteMessage),
    /// `'N'` notice carrying the given message text.
    NoticeResponse(&'a str),
    // Single byte - used in response to SSLRequest/GSSENCRequest.
    /// Single byte `'S'`: SSL accepted.
    EncryptionResponseSsl,
    /// Single byte `'G'`: GSS encryption accepted.
    EncryptionResponseGss,
    /// Single byte `'N'`: encryption refused.
    EncryptionResponseNo,
    /// `'I'` with no body.
    EmptyQueryResponse,
    /// `'1'` with an empty body.
    ParseComplete,
    /// `'2'` with an empty body.
    BindComplete,
    /// `'s'` with an empty body.
    PortalSuspended,
    /// `'t'`: array of parameter oid(i32).
    ParameterDescription(&'a [i32]),
    /// `'n'` with an empty body.
    NoData,
    /// `'D'`: one row of column values (`None` is sent as NULL).
    DataRow(&'a Row),
    /// `'S'`: a name/value run-time parameter report.
    ParameterStatus(BeParameterStatusMessage<'a>),
    /// `'Z'` followed by the transaction status byte.
    ReadyForQuery(TransactionStatus),
    /// `'T'`: one field descriptor per result column.
    RowDescription(&'a [PgFieldDescriptor]),
    /// `'E'`: severity, SQLSTATE and message of the error.
    ErrorResponse(BoxedError),
    /// `'3'` with an empty body.
    CloseComplete,

    /// `'K'`: 0: process ID, 1: secret key
    BackendKeyData((i32, i32)),
}
443
/// Run-time parameter reported via [`BeMessage::ParameterStatus`]; each variant carries the
/// value and fixes the parameter name written on the wire.
#[derive(Debug)]
pub enum BeParameterStatusMessage<'a> {
    /// Written as `client_encoding`.
    ClientEncoding(&'a str),
    /// Written as `standard_conforming_strings`.
    StandardConformingString(&'a str),
    /// Written as `server_version`.
    ServerVersion(&'a str),
    /// Written as `application_name`.
    ApplicationName(&'a str),
    /// Written as `TimeZone` (case matters for some clients).
    TimeZone(&'a str),
}
452
/// Payload of [`BeMessage::CommandComplete`].
#[derive(Debug)]
pub struct BeCommandCompleteMessage {
    /// Statement kind used to build the command tag (`*_RETURNING` kinds are reported as
    /// their plain counterpart).
    pub stmt_type: StatementType,
    /// Row count appended to the tag when `stmt_type.is_command()` is true.
    pub rows_cnt: i32,
}
458
/// Transaction status reported in [`BeMessage::ReadyForQuery`].
#[derive(Debug, Clone, Copy)]
pub enum TransactionStatus {
    /// Sent as `'I'`.
    Idle,
    /// Sent as `'T'`.
    InTransaction,
    /// Sent as `'E'`.
    InFailedTransaction,
}
465
466impl BeMessage<'_> {
467    /// Write message to the given buf.
468    pub fn write(buf: &mut BytesMut, message: &BeMessage<'_>) -> Result<()> {
469        match message {
470            // AuthenticationOk
471            // +-----+----------+-----------+
472            // | 'R' | int32(8) | int32(0)  |
473            // +-----+----------+-----------+
474            BeMessage::AuthenticationOk => {
475                buf.put_u8(b'R');
476                buf.put_i32(8);
477                buf.put_i32(0);
478            }
479
480            // AuthenticationCleartextPassword
481            // +-----+----------+-----------+
482            // | 'R' | int32(8) | int32(3)  |
483            // +-----+----------+-----------+
484            BeMessage::AuthenticationCleartextPassword => {
485                buf.put_u8(b'R');
486                buf.put_i32(8);
487                buf.put_i32(3);
488            }
489
490            // AuthenticationMD5Password
491            // +-----+----------+-----------+----------------+
492            // | 'R' | int32(12) | int32(5)  |  Byte4(salt)  |
493            // +-----+----------+-----------+----------------+
494            //
495            // The 4-byte random salt will be used by client to send encrypted password as
496            // concat('md5', md5(concat(md5(concat(password, username)), random-salt))).
497            BeMessage::AuthenticationMd5Password(salt) => {
498                buf.put_u8(b'R');
499                buf.put_i32(12);
500                buf.put_i32(5);
501                buf.put_slice(&salt[..]);
502            }
503
504            // ParameterStatus
505            // +-----+-----------+----------+------+-----------+------+
506            // | 'S' | int32 len | str name | '\0' | str value | '\0' |
507            // +-----+-----------+----------+------+-----------+------+
508            BeMessage::ParameterStatus(param) => {
509                use BeParameterStatusMessage::*;
510                let [name, value] = match param {
511                    ClientEncoding(val) => [b"client_encoding", val.as_bytes()],
512                    StandardConformingString(val) => {
513                        [b"standard_conforming_strings", val.as_bytes()]
514                    }
515                    ServerVersion(val) => [b"server_version", val.as_bytes()],
516                    ApplicationName(val) => [b"application_name", val.as_bytes()],
517                    // psycopg3 is case-sensitive, so we use "TimeZone" instead of "timezone" #18079
518                    TimeZone(val) => [b"TimeZone", val.as_bytes()],
519                };
520
521                // Parameter names and values are passed as null-terminated strings
522                let iov = &mut [name, b"\0", value, b"\0"].map(IoSlice::new);
523                let mut buffer = vec![];
524                let cnt = buffer.write_vectored(iov).unwrap();
525
526                buf.put_u8(b'S');
527                write_body(buf, |stream| {
528                    stream.put_slice(&buffer[..cnt]);
529                    Ok(())
530                })
531                .unwrap();
532            }
533
534            // CommandComplete
535            // +-----+-----------+-----------------+
536            // | 'C' | int32 len | str commandTag  |
537            // +-----+-----------+-----------------+
538            BeMessage::CommandComplete(cmd) => {
539                let rows_cnt = cmd.rows_cnt;
540                let mut stmt_type = cmd.stmt_type;
541                let mut tag = "".to_owned();
542                stmt_type = match stmt_type {
543                    StatementType::INSERT_RETURNING => StatementType::INSERT,
544                    StatementType::DELETE_RETURNING => StatementType::DELETE,
545                    StatementType::UPDATE_RETURNING => StatementType::UPDATE,
546                    s => s,
547                };
548                tag.push_str(&stmt_type.to_string());
549                if stmt_type == StatementType::INSERT {
550                    tag.push_str(" 0");
551                }
552                if stmt_type.is_command() {
553                    tag.push(' ');
554                    tag.push_str(&rows_cnt.to_string());
555                }
556                buf.put_u8(b'C');
557                write_body(buf, |buf| {
558                    write_cstr(buf, tag.as_bytes())?;
559                    Ok(())
560                })?;
561            }
562
563            // NoticeResponse
564            // +-----+-----------+------------------+------------------+
565            // | 'N' | int32 len | byte1 field type | str field value  |
566            // +-----+-----------+------------------+-+----------------+
567            // description of the fields can be found here:
568            // https://www.postgresql.org/docs/current/protocol-error-fields.html
569            BeMessage::NoticeResponse(notice) => {
570                buf.put_u8(b'N');
571                write_err_or_notice(buf, &ErrorOrNoticeMessage::notice(notice))?;
572            }
573
574            // DataRow
575            // +-----+-----------+--------------+--------+-----+--------+
576            // | 'D' | int32 len | int16 colNum | column | ... | column |
577            // +-----+-----------+--------------+----+---+-----+--------+
578            //                                       |
579            //                          +-----------+v------+
580            //                          | int32 len | bytes |
581            //                          +-----------+-------+
582            BeMessage::DataRow(vals) => {
583                buf.put_u8(b'D');
584                write_body(buf, |buf| {
585                    buf.put_u16(vals.len() as u16); // num of cols
586                    for val_opt in vals.values() {
587                        if let Some(val) = val_opt {
588                            buf.put_u32(val.len() as u32);
589                            buf.put_slice(val);
590                        } else {
591                            buf.put_i32(-1);
592                        }
593                    }
594                    Ok(())
595                })
596                .unwrap();
597            }
598
599            // RowDescription
600            // +-----+-----------+--------------+-------+-----+-------+
601            // | 'T' | int32 len | int16 colNum | field | ... | field |
602            // +-----+-----------+--------------+----+--+-----+-------+
603            //                                       |
604            // +---------------+-------+-------+-----v-+-------+-------+-------+
605            // | str fieldName | int32 | int16 | int32 | int16 | int32 | int16 |
606            // +---------------+---+---+---+---+---+---+----+--+---+---+---+---+
607            //                     |       |       |        |      |       |
608            //                     v       |       v        v      |       v
609            //                tableOID     |    typeOID  typeLen   |   formatCode
610            //                             v                       v
611            //                        colAttrNum               typeModifier
612            BeMessage::RowDescription(row_descs) => {
613                buf.put_u8(b'T');
614                write_body(buf, |buf| {
615                    buf.put_i16(row_descs.len() as i16); // # of fields
616                    for pg_field in *row_descs {
617                        write_cstr(buf, pg_field.get_name().as_bytes())?;
618                        buf.put_i32(pg_field.get_table_oid()); // table oid
619                        buf.put_i16(pg_field.get_col_attr_num()); // attnum
620                        buf.put_i32(pg_field.get_type_oid());
621                        buf.put_i16(pg_field.get_type_len());
622                        buf.put_i32(pg_field.get_type_modifier()); // typmod
623                        buf.put_i16(pg_field.get_format_code()); // format code
624                    }
625                    Ok(())
626                })?;
627            }
628            // ReadyForQuery
629            // +-----+----------+---------------------------+
630            // | 'Z' | int32(5) | byte1(transaction status) |
631            // +-----+----------+---------------------------+
632            BeMessage::ReadyForQuery(txn_status) => {
633                buf.put_u8(b'Z');
634                buf.put_i32(5);
635                // TODO: add transaction status
636                buf.put_u8(match txn_status {
637                    TransactionStatus::Idle => b'I',
638                    TransactionStatus::InTransaction => b'T',
639                    TransactionStatus::InFailedTransaction => b'E',
640                });
641            }
642
643            BeMessage::ParseComplete => {
644                buf.put_u8(b'1');
645                write_body(buf, |_| Ok(()))?;
646            }
647
648            BeMessage::BindComplete => {
649                buf.put_u8(b'2');
650                write_body(buf, |_| Ok(()))?;
651            }
652
653            BeMessage::CloseComplete => {
654                buf.put_u8(b'3');
655                write_body(buf, |_| Ok(()))?;
656            }
657
658            BeMessage::PortalSuspended => {
659                buf.put_u8(b's');
660                write_body(buf, |_| Ok(()))?;
661            }
662            // ParameterDescription
663            // +-----+-----------+--------------------+---------------+-----+---------------+
664            // | 't' | int32 len | int16 ParameterNum | int32 typeOID | ... | int32 typeOID |
665            // +-----+-----------+-----------------+--+---------------+-----+---------------+
666            BeMessage::ParameterDescription(para_descs) => {
667                buf.put_u8(b't');
668                write_body(buf, |buf| {
669                    buf.put_i16(para_descs.len() as i16);
670                    for oid in *para_descs {
671                        buf.put_i32(*oid);
672                    }
673                    Ok(())
674                })?;
675            }
676
677            BeMessage::NoData => {
678                buf.put_u8(b'n');
679                write_body(buf, |_| Ok(())).unwrap();
680            }
681
682            BeMessage::EncryptionResponseSsl => {
683                buf.put_u8(b'S');
684            }
685
686            BeMessage::EncryptionResponseGss => {
687                buf.put_u8(b'G');
688            }
689
690            BeMessage::EncryptionResponseNo => {
691                buf.put_u8(b'N');
692            }
693
694            // EmptyQueryResponse
695            // +-----+----------+
696            // | 'I' | int32(4) |
697            // +-----+----------+
698            BeMessage::EmptyQueryResponse => {
699                buf.put_u8(b'I');
700                buf.put_i32(4);
701            }
702
703            BeMessage::ErrorResponse(error) => {
704                // 'E' signalizes ErrorResponse messages
705                buf.put_u8(b'E');
706                // Format the error as a pretty report.
707                write_err_or_notice(buf, &ErrorOrNoticeMessage::error(error))?;
708            }
709
710            BeMessage::BackendKeyData((process_id, secret_key)) => {
711                buf.put_u8(b'K');
712                write_body(buf, |buf| {
713                    buf.put_i32(*process_id);
714                    buf.put_i32(*secret_key);
715                    Ok(())
716                })?;
717            }
718        }
719
720        Ok(())
721    }
722}
723
/// Checked `usize` -> signed-integer conversion for protocol length fields, adapted from
/// rust-postgres. Only `i32` is instantiated below.
trait FromUsize: Sized {
    /// Converts `x`, failing with `ErrorKind::InvalidInput` ("value too large to transmit")
    /// if it exceeds `Self::MAX`.
    fn from_usize(x: usize) -> Result<Self>;
}

macro_rules! from_usize {
    ($t:ty) => {
        impl FromUsize for $t {
            #[inline]
            fn from_usize(x: usize) -> Result<$t> {
                if x > <$t>::MAX as usize {
                    Err(Error::new(
                        ErrorKind::InvalidInput,
                        "value too large to transmit",
                    ))
                } else {
                    Ok(x as $t)
                }
            }
        }
    };
}

from_usize!(i32);
745
746/// Call f() to write body of the message and prepend it with 4-byte len as
747/// prescribed by the protocol. First write out body value and fill length value as i32 in front of
748/// it.
749fn write_body<F>(buf: &mut BytesMut, f: F) -> Result<()>
750where
751    F: FnOnce(&mut BytesMut) -> Result<()>,
752{
753    let base = buf.len();
754    buf.extend_from_slice(&[0; 4]);
755
756    f(buf)?;
757
758    let size = i32::from_usize(buf.len() - base)?;
759    BigEndian::write_i32(&mut buf[base..], size);
760    Ok(())
761}
762
763/// Safe write of s into buf as cstring (String in the protocol).
764fn write_cstr(buf: &mut BytesMut, s: &[u8]) -> Result<()> {
765    if s.contains(&0) {
766        return Err(Error::new(
767            ErrorKind::InvalidInput,
768            "string contains embedded null",
769        ));
770    }
771    buf.put_slice(s);
772    buf.put_u8(0);
773    Ok(())
774}
775
776/// Safe write error or notice message.
777fn write_err_or_notice(buf: &mut BytesMut, msg: &ErrorOrNoticeMessage<'_>) -> Result<()> {
778    write_body(buf, |buf| {
779        buf.put_u8(b'S'); // severity
780        write_cstr(buf, msg.severity.as_str().as_bytes())?;
781
782        buf.put_u8(b'C'); // SQLSTATE error code
783        write_cstr(buf, msg.error_code.sqlstate().as_bytes())?;
784
785        buf.put_u8(b'M'); // the message
786        write_cstr(buf, msg.message.as_bytes())?;
787
788        buf.put_u8(0); // terminator
789        Ok(())
790    })
791}
792
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use crate::pg_message::FeQueryMessage;

    /// `get_sql` must reject non-UTF-8 text and missing terminators, and decode a well-formed
    /// C string without its trailing `\0`.
    #[test]
    fn test_get_sql() {
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![255, 255, 255, 255, 255, 255, 0]),
        };
        assert!(fe.get_sql().is_err(), "invalid UTF-8 must be rejected");

        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![1, 2, 3, 4, 5, 6, 7, 8]),
        };
        assert!(
            fe.get_sql().is_err(),
            "a body without a null terminator must be rejected"
        );

        let fe = FeQueryMessage {
            sql_bytes: Bytes::from_static(b"SELECT 1\0"),
        };
        assert_eq!(fe.get_sql().unwrap(), "SELECT 1");
    }
}