pgwire/
pg_message.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::ffi::CStr;
17use std::io::{Error, ErrorKind, IoSlice, Result, Write};
18
19use anyhow::anyhow;
20use byteorder::{BigEndian, ByteOrder};
21/// Part of code learned from <https://github.com/zenithdb/zenith/blob/main/zenith_utils/src/pq_proto.rs>.
22use bytes::{Buf, BufMut, Bytes, BytesMut};
23use peekable::tokio::AsyncPeekable;
24use tokio::io::{AsyncRead, AsyncReadExt};
25
26use crate::error_or_notice::ErrorOrNoticeMessage;
27use crate::pg_field_descriptor::PgFieldDescriptor;
28use crate::pg_response::StatementType;
29use crate::pg_server::BoxedError;
30use crate::types::Row;
31
32/// Messages that can be sent from pg client to server. Implement `read`.
33#[derive(Debug)]
34pub enum FeMessage {
35    Ssl,
36    Gss,
37    Startup(FeStartupMessage),
38    Query(FeQueryMessage),
39    Parse(FeParseMessage),
40    Password(FePasswordMessage),
41    Describe(FeDescribeMessage),
42    Bind(FeBindMessage),
43    Execute(FeExecuteMessage),
44    Close(FeCloseMessage),
45    Sync,
46    CancelQuery(FeCancelMessage),
47    Terminate,
48    Flush,
49    // special msg to detect health check, which represents the client immediately closes the connection cleanly without sending any data.
50    HealthCheck,
51    // The original message has been rejected due to server throttling. This is a placeholder message generated by server.
52    ServerThrottle(ServerThrottleReason),
53}
54
55#[derive(Debug)]
56pub enum ServerThrottleReason {
57    TooLargeMessage,
58    TooManyMemoryUsage,
59}
60
61#[derive(Debug)]
62pub struct FeStartupMessage {
63    pub config: HashMap<String, String>,
64}
65
66impl FeStartupMessage {
67    pub fn build_with_payload(payload: &[u8]) -> Result<Self> {
68        let config = match std::str::from_utf8(payload) {
69            Ok(v) => Ok(v.trim_end_matches('\0')),
70            Err(err) => Err(Error::new(
71                ErrorKind::InvalidInput,
72                anyhow!(err).context("Input end error"),
73            )),
74        }?;
75        let mut map = HashMap::new();
76        let config: Vec<&str> = config.split('\0').collect();
77        if config.len() % 2 == 1 {
78            return Err(Error::new(
79                ErrorKind::InvalidInput,
80                "Invalid input config: odd number of config pairs",
81            ));
82        }
83        config.chunks(2).for_each(|chunk| {
84            map.insert(chunk[0].to_owned(), chunk[1].to_owned());
85        });
86        Ok(FeStartupMessage { config: map })
87    }
88}
89
90/// Query message contains the string sql.
91#[derive(Debug)]
92pub struct FeQueryMessage {
93    pub sql_bytes: Bytes,
94}
95
96#[derive(Debug)]
97pub struct FeBindMessage {
98    pub param_format_codes: Vec<i16>,
99    pub result_format_codes: Vec<i16>,
100
101    pub params: Vec<Option<Bytes>>,
102    pub portal_name: Bytes,
103    pub statement_name: Bytes,
104}
105
106#[derive(Debug)]
107pub struct FeExecuteMessage {
108    pub portal_name: Bytes,
109    pub max_rows: i32,
110}
111
112#[derive(Debug)]
113pub struct FeParseMessage {
114    pub statement_name: Bytes,
115    pub sql_bytes: Bytes,
116    pub type_ids: Vec<i32>,
117}
118
119#[derive(Debug)]
120pub struct FePasswordMessage {
121    pub password: Bytes,
122}
123
124#[derive(Debug)]
125pub struct FeDescribeMessage {
126    // 'S' to describe a prepared statement; or 'P' to describe a portal.
127    pub kind: u8,
128    pub name: Bytes,
129}
130
131#[derive(Debug)]
132pub struct FeCloseMessage {
133    pub kind: u8,
134    pub name: Bytes,
135}
136
137#[derive(Debug)]
138pub struct FeCancelMessage {
139    pub target_process_id: i32,
140    pub target_secret_key: i32,
141}
142
143impl FeCancelMessage {
144    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
145        let target_process_id = buf.get_i32();
146        let target_secret_key = buf.get_i32();
147        Ok(FeMessage::CancelQuery(Self {
148            target_process_id,
149            target_secret_key,
150        }))
151    }
152}
153
154impl FeDescribeMessage {
155    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
156        let kind = buf.get_u8();
157        let name = read_null_terminated(&mut buf)?;
158
159        Ok(FeMessage::Describe(FeDescribeMessage { kind, name }))
160    }
161}
162
163impl FeBindMessage {
164    // Bind Message Header
165    // +-----+-----------+
166    // | 'B' | int32 len |
167    // +-----+-----------+
168    // Bind Message Body
169    // +----------------+---------------+
170    // | str portalname | str statement |
171    // +----------------+---------------+
172    // +---------------------+------------------+-------+
173    // | int16 numFormatCode | int16 FormatCode |  ...  |
174    // +---------------------+------------------+-------+
175    // +-----------------+-------------------+---------------+
176    // | int16 numParams | int32 valueLength |  byte value.. |
177    // +-----------------+-------------------+---------------+
178    // +----------------------------------+------------------+-------+
179    // | int16 numResultColumnFormatCodes | int16 FormatCode |  ...  |
180    // +----------------------------------+------------------+-------+
181    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
182        let portal_name = read_null_terminated(&mut buf)?;
183        let statement_name = read_null_terminated(&mut buf)?;
184
185        let len = buf.get_i16();
186        let param_format_codes = (0..len).map(|_| buf.get_i16()).collect();
187
188        // Read Params
189        let len = buf.get_i16();
190        let params = (0..len)
191            .map(|_| {
192                let val_len = buf.get_i32();
193                if val_len == -1 {
194                    None
195                } else {
196                    Some(buf.copy_to_bytes(val_len as usize))
197                }
198            })
199            .collect();
200
201        let len = buf.get_i16();
202        let result_format_codes = (0..len).map(|_| buf.get_i16()).collect();
203
204        Ok(FeMessage::Bind(FeBindMessage {
205            param_format_codes,
206            result_format_codes,
207            params,
208            portal_name,
209            statement_name,
210        }))
211    }
212}
213
214impl FeExecuteMessage {
215    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
216        let portal_name = read_null_terminated(&mut buf)?;
217        let max_rows = buf.get_i32();
218
219        Ok(FeMessage::Execute(FeExecuteMessage {
220            portal_name,
221            max_rows,
222        }))
223    }
224}
225
226impl FeParseMessage {
227    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
228        let statement_name = read_null_terminated(&mut buf)?;
229        let sql_bytes = read_null_terminated(&mut buf)?;
230        let nparams = buf.get_i16();
231
232        let type_ids: Vec<i32> = (0..nparams).map(|_| buf.get_i32()).collect();
233
234        Ok(FeMessage::Parse(FeParseMessage {
235            statement_name,
236            sql_bytes,
237            type_ids,
238        }))
239    }
240}
241
242impl FePasswordMessage {
243    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
244        let password = read_null_terminated(&mut buf)?;
245
246        Ok(FeMessage::Password(FePasswordMessage { password }))
247    }
248}
249
250impl FeQueryMessage {
251    pub fn get_sql(&self) -> Result<&str> {
252        match CStr::from_bytes_with_nul(&self.sql_bytes) {
253            Ok(cstr) => cstr.to_str().map_err(|err| {
254                Error::new(
255                    ErrorKind::InvalidInput,
256                    anyhow!(err).context("Invalid UTF-8 sequence"),
257                )
258            }),
259            Err(err) => Err(Error::new(
260                ErrorKind::InvalidInput,
261                anyhow!(err).context("Input end error"),
262            )),
263        }
264    }
265}
266
267impl FeCloseMessage {
268    pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
269        let kind = buf.get_u8();
270        let name = read_null_terminated(&mut buf)?;
271        Ok(FeMessage::Close(FeCloseMessage { kind, name }))
272    }
273}
274
/// Header of a regular (post-startup) frontend message: its one-byte tag and body length.
#[derive(Clone)]
pub struct FeMessageHeader {
    /// Message type byte, e.g. `b'Q'` for a simple query; dispatched on by
    /// `FeMessage::read_body`.
    pub tag: u8,
    /// Length of the body in bytes: the wire length field minus the 4 bytes of the length
    /// field itself.
    pub payload_len: i32,
}
280
281impl FeMessage {
282    /// Read one message from the stream.
283    pub async fn read_header(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessageHeader> {
284        let tag = stream.read_u8().await?;
285        let len = stream.read_i32().await?;
286
287        let payload_len = len - 4;
288        Ok(FeMessageHeader { tag, payload_len })
289    }
290
291    /// Read one message from the stream.
292    pub async fn read_body(
293        stream: &mut (impl AsyncRead + Unpin),
294        header: FeMessageHeader,
295    ) -> Result<FeMessage> {
296        let FeMessageHeader { tag, payload_len } = header;
297        let mut payload: Vec<u8> = vec![0; payload_len as usize];
298        if payload_len > 0 {
299            stream.read_exact(&mut payload).await?;
300        }
301        let sql_bytes = Bytes::from(payload);
302        match tag {
303            b'Q' => Ok(FeMessage::Query(FeQueryMessage { sql_bytes })),
304            b'P' => FeParseMessage::parse(sql_bytes),
305            b'D' => FeDescribeMessage::parse(sql_bytes),
306            b'B' => FeBindMessage::parse(sql_bytes),
307            b'E' => FeExecuteMessage::parse(sql_bytes),
308            b'S' => Ok(FeMessage::Sync),
309            b'X' => Ok(FeMessage::Terminate),
310            b'C' => FeCloseMessage::parse(sql_bytes),
311            b'p' => FePasswordMessage::parse(sql_bytes),
312            b'H' => Ok(FeMessage::Flush),
313            _ => Err(std::io::Error::new(
314                ErrorKind::InvalidInput,
315                format!("Unsupported tag of regular message: {}", tag),
316            )),
317        }
318    }
319
320    pub async fn skip_body(
321        stream: &mut (impl AsyncRead + Unpin),
322        header: FeMessageHeader,
323    ) -> Result<()> {
324        let FeMessageHeader {
325            tag: _,
326            payload_len,
327        } = header;
328
329        if payload_len > 0 {
330            // Use smaller batches to process the payload instead of handling it all at once to minimize memory usage.
331            const BUF_SIZE: usize = 1024;
332            let mut buf: Vec<u8> = vec![0; BUF_SIZE];
333            for _ in 0..(payload_len as usize) / BUF_SIZE {
334                stream.read_exact(&mut buf).await?;
335            }
336            let remain = (payload_len as usize) % BUF_SIZE;
337            if remain > 0 {
338                buf.truncate(remain);
339                stream.read_exact(&mut buf).await?;
340            }
341        }
342        Ok(())
343    }
344}
345
346impl FeStartupMessage {
347    /// Read startup message from the stream.
348    pub async fn read(stream: &mut (impl AsyncRead + Unpin)) -> Result<FeMessage> {
349        let mut stream = AsyncPeekable::new(stream);
350
351        if let Err(err) = stream.peek_exact(&mut [0; 1]).await {
352            // If the stream is empty, it can be a health check. Do not return error.
353            if err.kind() == ErrorKind::UnexpectedEof {
354                return Ok(FeMessage::HealthCheck);
355            } else {
356                return Err(err);
357            }
358        }
359
360        let len = stream.read_i32().await?;
361        let protocol_num = stream.read_i32().await?;
362        let payload_len = (len - 8) as usize;
363        if payload_len >= isize::MAX as usize {
364            return Err(std::io::Error::new(
365                ErrorKind::InvalidInput,
366                format!("Payload length has exceed usize::MAX {:?}", payload_len),
367            ));
368        }
369        let mut payload = vec![0; payload_len];
370        if payload_len > 0 {
371            stream.read_exact(&mut payload).await?;
372        }
373        match protocol_num {
374            // code from: https://www.postgresql.org/docs/current/protocol-message-formats.html
375            196608 => Ok(FeMessage::Startup(FeStartupMessage::build_with_payload(
376                &payload,
377            )?)),
378            80877104 => Ok(FeMessage::Gss),
379            80877103 => Ok(FeMessage::Ssl),
380            // Cancel request code.
381            80877102 => FeCancelMessage::parse(Bytes::from(payload)),
382            _ => Err(std::io::Error::new(
383                ErrorKind::InvalidInput,
384                format!(
385                    "Unsupported protocol number in start up msg {:?}",
386                    protocol_num
387                ),
388            )),
389        }
390    }
391}
392
393/// Continue read until reached a \0. Used in reading string from Bytes.
394fn read_null_terminated(buf: &mut Bytes) -> Result<Bytes> {
395    let mut result = BytesMut::new();
396
397    loop {
398        if !buf.has_remaining() {
399            panic!("no null-terminator in string");
400        }
401
402        let byte = buf.get_u8();
403
404        if byte == 0 {
405            break;
406        }
407        result.put_u8(byte);
408    }
409    Ok(result.freeze())
410}
411
412/// Message sent from server to psql client. Implement `write` (how to serialize it into psql
413/// buffer).
414/// Ref: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
415#[derive(Debug)]
416pub enum BeMessage<'a> {
417    AuthenticationOk,
418    AuthenticationCleartextPassword,
419    AuthenticationMd5Password(&'a [u8; 4]),
420    CommandComplete(BeCommandCompleteMessage),
421    NoticeResponse(&'a str),
422    // Single byte - used in response to SSLRequest/GSSENCRequest.
423    EncryptionResponseSsl,
424    EncryptionResponseGss,
425    EncryptionResponseNo,
426    EmptyQueryResponse,
427    ParseComplete,
428    BindComplete,
429    PortalSuspended,
430    // array of parameter oid(i32)
431    ParameterDescription(&'a [i32]),
432    NoData,
433    DataRow(&'a Row),
434    ParameterStatus(BeParameterStatusMessage<'a>),
435    ReadyForQuery(TransactionStatus),
436    RowDescription(&'a [PgFieldDescriptor]),
437    ErrorResponse(BoxedError),
438    CloseComplete,
439
440    // 0: process ID, 1: secret key
441    BackendKeyData((i32, i32)),
442}
443
/// A server parameter reported via `ParameterStatus` (`'S'`), carrying its value.
///
/// The parameter name written on the wire is fixed per variant by `BeMessage::write`.
#[derive(Debug)]
pub enum BeParameterStatusMessage<'a> {
    /// Written as `client_encoding`.
    ClientEncoding(&'a str),
    /// Written as `standard_conforming_strings`.
    StandardConformingString(&'a str),
    /// Written as `server_version`.
    ServerVersion(&'a str),
    /// Written as `application_name`.
    ApplicationName(&'a str),
}

/// Data for a `CommandComplete` (`'C'`) message.
#[derive(Debug)]
pub struct BeCommandCompleteMessage {
    /// Statement kind used to build the command tag. `BeMessage::write` reports the
    /// `*_RETURNING` kinds as their plain DML kind.
    pub stmt_type: StatementType,
    /// Row count; appended to the tag when `stmt_type.is_command()` is true.
    pub rows_cnt: i32,
}

/// Transaction status reported in `ReadyForQuery` (`'Z'`).
#[derive(Debug, Clone, Copy)]
pub enum TransactionStatus {
    /// Written as `b'I'`.
    Idle,
    /// Written as `b'T'`.
    InTransaction,
    /// Written as `b'E'`.
    InFailedTransaction,
}
464
465impl BeMessage<'_> {
466    /// Write message to the given buf.
467    pub fn write(buf: &mut BytesMut, message: &BeMessage<'_>) -> Result<()> {
468        match message {
469            // AuthenticationOk
470            // +-----+----------+-----------+
471            // | 'R' | int32(8) | int32(0)  |
472            // +-----+----------+-----------+
473            BeMessage::AuthenticationOk => {
474                buf.put_u8(b'R');
475                buf.put_i32(8);
476                buf.put_i32(0);
477            }
478
479            // AuthenticationCleartextPassword
480            // +-----+----------+-----------+
481            // | 'R' | int32(8) | int32(3)  |
482            // +-----+----------+-----------+
483            BeMessage::AuthenticationCleartextPassword => {
484                buf.put_u8(b'R');
485                buf.put_i32(8);
486                buf.put_i32(3);
487            }
488
489            // AuthenticationMD5Password
490            // +-----+----------+-----------+----------------+
491            // | 'R' | int32(12) | int32(5)  |  Byte4(salt)  |
492            // +-----+----------+-----------+----------------+
493            //
494            // The 4-byte random salt will be used by client to send encrypted password as
495            // concat('md5', md5(concat(md5(concat(password, username)), random-salt))).
496            BeMessage::AuthenticationMd5Password(salt) => {
497                buf.put_u8(b'R');
498                buf.put_i32(12);
499                buf.put_i32(5);
500                buf.put_slice(&salt[..]);
501            }
502
503            // ParameterStatus
504            // +-----+-----------+----------+------+-----------+------+
505            // | 'S' | int32 len | str name | '\0' | str value | '\0' |
506            // +-----+-----------+----------+------+-----------+------+
507            BeMessage::ParameterStatus(param) => {
508                use BeParameterStatusMessage::*;
509                let [name, value] = match param {
510                    ClientEncoding(val) => [b"client_encoding", val.as_bytes()],
511                    StandardConformingString(val) => {
512                        [b"standard_conforming_strings", val.as_bytes()]
513                    }
514                    ServerVersion(val) => [b"server_version", val.as_bytes()],
515                    ApplicationName(val) => [b"application_name", val.as_bytes()],
516                };
517
518                // Parameter names and values are passed as null-terminated strings
519                let iov = &mut [name, b"\0", value, b"\0"].map(IoSlice::new);
520                let mut buffer = vec![];
521                let cnt = buffer.write_vectored(iov).unwrap();
522
523                buf.put_u8(b'S');
524                write_body(buf, |stream| {
525                    stream.put_slice(&buffer[..cnt]);
526                    Ok(())
527                })
528                .unwrap();
529            }
530
531            // CommandComplete
532            // +-----+-----------+-----------------+
533            // | 'C' | int32 len | str commandTag  |
534            // +-----+-----------+-----------------+
535            BeMessage::CommandComplete(cmd) => {
536                let rows_cnt = cmd.rows_cnt;
537                let mut stmt_type = cmd.stmt_type;
538                let mut tag = "".to_owned();
539                stmt_type = match stmt_type {
540                    StatementType::INSERT_RETURNING => StatementType::INSERT,
541                    StatementType::DELETE_RETURNING => StatementType::DELETE,
542                    StatementType::UPDATE_RETURNING => StatementType::UPDATE,
543                    s => s,
544                };
545                tag.push_str(&stmt_type.to_string());
546                if stmt_type == StatementType::INSERT {
547                    tag.push_str(" 0");
548                }
549                if stmt_type.is_command() {
550                    tag.push(' ');
551                    tag.push_str(&rows_cnt.to_string());
552                }
553                buf.put_u8(b'C');
554                write_body(buf, |buf| {
555                    write_cstr(buf, tag.as_bytes())?;
556                    Ok(())
557                })?;
558            }
559
560            // NoticeResponse
561            // +-----+-----------+------------------+------------------+
562            // | 'N' | int32 len | byte1 field type | str field value  |
563            // +-----+-----------+------------------+-+----------------+
564            // description of the fields can be found here:
565            // https://www.postgresql.org/docs/current/protocol-error-fields.html
566            BeMessage::NoticeResponse(notice) => {
567                buf.put_u8(b'N');
568                write_err_or_notice(buf, &ErrorOrNoticeMessage::notice(notice))?;
569            }
570
571            // DataRow
572            // +-----+-----------+--------------+--------+-----+--------+
573            // | 'D' | int32 len | int16 colNum | column | ... | column |
574            // +-----+-----------+--------------+----+---+-----+--------+
575            //                                       |
576            //                          +-----------+v------+
577            //                          | int32 len | bytes |
578            //                          +-----------+-------+
579            BeMessage::DataRow(vals) => {
580                buf.put_u8(b'D');
581                write_body(buf, |buf| {
582                    buf.put_u16(vals.len() as u16); // num of cols
583                    for val_opt in vals.values() {
584                        if let Some(val) = val_opt {
585                            buf.put_u32(val.len() as u32);
586                            buf.put_slice(val);
587                        } else {
588                            buf.put_i32(-1);
589                        }
590                    }
591                    Ok(())
592                })
593                .unwrap();
594            }
595
596            // RowDescription
597            // +-----+-----------+--------------+-------+-----+-------+
598            // | 'T' | int32 len | int16 colNum | field | ... | field |
599            // +-----+-----------+--------------+----+--+-----+-------+
600            //                                       |
601            // +---------------+-------+-------+-----v-+-------+-------+-------+
602            // | str fieldName | int32 | int16 | int32 | int16 | int32 | int16 |
603            // +---------------+---+---+---+---+---+---+----+--+---+---+---+---+
604            //                     |       |       |        |      |       |
605            //                     v       |       v        v      |       v
606            //                tableOID     |    typeOID  typeLen   |   formatCode
607            //                             v                       v
608            //                        colAttrNum               typeModifier
609            BeMessage::RowDescription(row_descs) => {
610                buf.put_u8(b'T');
611                write_body(buf, |buf| {
612                    buf.put_i16(row_descs.len() as i16); // # of fields
613                    for pg_field in *row_descs {
614                        write_cstr(buf, pg_field.get_name().as_bytes())?;
615                        buf.put_i32(pg_field.get_table_oid()); // table oid
616                        buf.put_i16(pg_field.get_col_attr_num()); // attnum
617                        buf.put_i32(pg_field.get_type_oid());
618                        buf.put_i16(pg_field.get_type_len());
619                        buf.put_i32(pg_field.get_type_modifier()); // typmod
620                        buf.put_i16(pg_field.get_format_code()); // format code
621                    }
622                    Ok(())
623                })?;
624            }
625            // ReadyForQuery
626            // +-----+----------+---------------------------+
627            // | 'Z' | int32(5) | byte1(transaction status) |
628            // +-----+----------+---------------------------+
629            BeMessage::ReadyForQuery(txn_status) => {
630                buf.put_u8(b'Z');
631                buf.put_i32(5);
632                // TODO: add transaction status
633                buf.put_u8(match txn_status {
634                    TransactionStatus::Idle => b'I',
635                    TransactionStatus::InTransaction => b'T',
636                    TransactionStatus::InFailedTransaction => b'E',
637                });
638            }
639
640            BeMessage::ParseComplete => {
641                buf.put_u8(b'1');
642                write_body(buf, |_| Ok(()))?;
643            }
644
645            BeMessage::BindComplete => {
646                buf.put_u8(b'2');
647                write_body(buf, |_| Ok(()))?;
648            }
649
650            BeMessage::CloseComplete => {
651                buf.put_u8(b'3');
652                write_body(buf, |_| Ok(()))?;
653            }
654
655            BeMessage::PortalSuspended => {
656                buf.put_u8(b's');
657                write_body(buf, |_| Ok(()))?;
658            }
659            // ParameterDescription
660            // +-----+-----------+--------------------+---------------+-----+---------------+
661            // | 't' | int32 len | int16 ParameterNum | int32 typeOID | ... | int32 typeOID |
662            // +-----+-----------+-----------------+--+---------------+-----+---------------+
663            BeMessage::ParameterDescription(para_descs) => {
664                buf.put_u8(b't');
665                write_body(buf, |buf| {
666                    buf.put_i16(para_descs.len() as i16);
667                    for oid in *para_descs {
668                        buf.put_i32(*oid);
669                    }
670                    Ok(())
671                })?;
672            }
673
674            BeMessage::NoData => {
675                buf.put_u8(b'n');
676                write_body(buf, |_| Ok(())).unwrap();
677            }
678
679            BeMessage::EncryptionResponseSsl => {
680                buf.put_u8(b'S');
681            }
682
683            BeMessage::EncryptionResponseGss => {
684                buf.put_u8(b'G');
685            }
686
687            BeMessage::EncryptionResponseNo => {
688                buf.put_u8(b'N');
689            }
690
691            // EmptyQueryResponse
692            // +-----+----------+
693            // | 'I' | int32(4) |
694            // +-----+----------+
695            BeMessage::EmptyQueryResponse => {
696                buf.put_u8(b'I');
697                buf.put_i32(4);
698            }
699
700            BeMessage::ErrorResponse(error) => {
701                use thiserror_ext::AsReport;
702                // For all the errors set Severity to Error and error code to
703                // 'internal error'.
704
705                // 'E' signalizes ErrorResponse messages
706                buf.put_u8(b'E');
707                // Format the error as a pretty report.
708                let msg = error.to_report_string_pretty();
709                write_err_or_notice(buf, &ErrorOrNoticeMessage::internal_error(&msg))?;
710            }
711
712            BeMessage::BackendKeyData((process_id, secret_key)) => {
713                buf.put_u8(b'K');
714                write_body(buf, |buf| {
715                    buf.put_i32(*process_id);
716                    buf.put_i32(*secret_key);
717                    Ok(())
718                })?;
719            }
720        }
721
722        Ok(())
723    }
724}
725
/// Checked `usize` -> wire-integer conversion (adapted from rust-postgres).
///
/// Unlike an `as` cast, an out-of-range value is reported as an error instead of being
/// silently truncated.
trait FromUsize: Sized {
    fn from_usize(x: usize) -> Result<Self>;
}

/// Implements [`FromUsize`] for an integer type on top of its `TryFrom<usize>` impl.
macro_rules! from_usize {
    ($t:ty) => {
        impl FromUsize for $t {
            #[inline]
            fn from_usize(x: usize) -> Result<$t> {
                <$t as ::std::convert::TryFrom<usize>>::try_from(x).map_err(|_| {
                    Error::new(ErrorKind::InvalidInput, "value too large to transmit")
                })
            }
        }
    };
}

from_usize!(i32);
747
748/// Call f() to write body of the message and prepend it with 4-byte len as
749/// prescribed by the protocol. First write out body value and fill length value as i32 in front of
750/// it.
751fn write_body<F>(buf: &mut BytesMut, f: F) -> Result<()>
752where
753    F: FnOnce(&mut BytesMut) -> Result<()>,
754{
755    let base = buf.len();
756    buf.extend_from_slice(&[0; 4]);
757
758    f(buf)?;
759
760    let size = i32::from_usize(buf.len() - base)?;
761    BigEndian::write_i32(&mut buf[base..], size);
762    Ok(())
763}
764
765/// Safe write of s into buf as cstring (String in the protocol).
766fn write_cstr(buf: &mut BytesMut, s: &[u8]) -> Result<()> {
767    if s.contains(&0) {
768        return Err(Error::new(
769            ErrorKind::InvalidInput,
770            "string contains embedded null",
771        ));
772    }
773    buf.put_slice(s);
774    buf.put_u8(0);
775    Ok(())
776}
777
778/// Safe write error or notice message.
779fn write_err_or_notice(buf: &mut BytesMut, msg: &ErrorOrNoticeMessage<'_>) -> Result<()> {
780    write_body(buf, |buf| {
781        buf.put_u8(b'S'); // severity
782        write_cstr(buf, msg.severity.as_str().as_bytes())?;
783
784        buf.put_u8(b'C'); // SQLSTATE error code
785        write_cstr(buf, msg.state.code().as_bytes())?;
786
787        buf.put_u8(b'M'); // the message
788        write_cstr(buf, msg.message.as_bytes())?;
789
790        buf.put_u8(0); // terminator
791        Ok(())
792    })
793}
794
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use crate::pg_message::{read_null_terminated, FeQueryMessage, FeStartupMessage};

    #[test]
    fn test_get_sql() {
        // Well-formed: UTF-8 text followed by exactly one terminator.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from_static(b"SELECT 1\0"),
        };
        assert_eq!(fe.get_sql().unwrap(), "SELECT 1");

        // Null-terminated, but not valid UTF-8.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![255, 255, 255, 255, 255, 255, 0]),
        };
        assert!(fe.get_sql().is_err());

        // Missing terminator.
        let fe = FeQueryMessage {
            sql_bytes: Bytes::from(vec![1, 2, 3, 4, 5, 6, 7, 8]),
        };
        assert!(fe.get_sql().is_err());
    }

    #[test]
    fn test_build_startup_payload() {
        let msg =
            FeStartupMessage::build_with_payload(b"user\0root\0database\0dev\0\0").unwrap();
        assert_eq!(msg.config.len(), 2);
        assert_eq!(msg.config.get("user").map(String::as_str), Some("root"));
        assert_eq!(msg.config.get("database").map(String::as_str), Some("dev"));
    }

    #[test]
    fn test_read_null_terminated() {
        let mut buf = Bytes::from_static(b"abc\0def\0");
        assert_eq!(read_null_terminated(&mut buf).unwrap(), &b"abc"[..]);
        assert_eq!(read_null_terminated(&mut buf).unwrap(), &b"def"[..]);
        assert!(buf.is_empty());
    }
}