pgwire/
types.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::iter::TrustedLen;
16use std::ops::Index;
17use std::slice::Iter;
18
19use bytes::Bytes;
20
21use crate::error::{PsqlError, PsqlResult};
22
23/// A row of data returned from the database by a query.
24#[derive(Debug, Clone)]
25// NOTE: Since we only support simple query protocol, the values are represented as strings.
26pub struct Row(Vec<Option<Bytes>>);
27
28impl Row {
29    /// Create a row from values.
30    pub fn new(row: Vec<Option<Bytes>>) -> Self {
31        Self(row)
32    }
33
34    /// Returns the number of values in the row.
35    pub fn len(&self) -> usize {
36        self.0.len()
37    }
38
39    /// Returns `true` if the row contains no values. Required by clippy.
40    pub fn is_empty(&self) -> bool {
41        self.0.is_empty()
42    }
43
44    /// Returns the values.
45    pub fn values(&self) -> &[Option<Bytes>] {
46        &self.0
47    }
48
49    pub fn take(self) -> Vec<Option<Bytes>> {
50        self.0
51    }
52
53    pub fn project(&mut self, indices: &[usize]) -> Row {
54        let mut new_row = Vec::with_capacity(indices.len());
55        for i in indices {
56            new_row.push(self.0[*i].take());
57        }
58        Row(new_row)
59    }
60}
61
62impl Index<usize> for Row {
63    type Output = Option<Bytes>;
64
65    fn index(&self, index: usize) -> &Self::Output {
66        &self.0[index]
67    }
68}
69
/// Transfer format of a value in the PostgreSQL wire protocol.
///
/// <https://www.postgresql.org/docs/current/protocol-overview.html#PROTOCOL-FORMAT-CODES>
///
/// On the wire, format code `0` means text and `1` means binary; use [`Format::from_i16`]
/// to decode a code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Format {
    /// Binary representation (wire format code `1`).
    Binary,
    /// Textual representation (wire format code `0`).
    Text,
}
76
77impl Format {
78    pub fn from_i16(format_code: i16) -> PsqlResult<Self> {
79        match format_code {
80            0 => Ok(Format::Text),
81            1 => Ok(Format::Binary),
82            _ => Err(PsqlError::Uncategorized(
83                format!("Unknown format code: {}", format_code).into(),
84            )),
85        }
86    }
87}
88
/// Iterator that expands the format codes provided by a client into exactly `actual_len`
/// [`Format`]s, one per value.
///
/// According to the PostgreSQL documentation
/// (<https://www.postgresql.org/docs/current/protocol-message-formats.html#:~:text=The%20number%20of,number%20of%20parameters>):
/// - If no format codes are provided, every value uses the default format (TEXT).
/// - If exactly one format code is provided, it applies to every value.
/// - Otherwise, one format code must be provided per value, and they are yielded in order.
#[derive(Debug, Clone)]
pub struct FormatIterator<'a, 'b>
where
    'a: 'b,
{
    // The provided format codes. Never read; storing it keeps the lifetime `'a` in use.
    _formats: &'a [Format],
    // Provided format codes that have not been yielded yet.
    format_iter: Iter<'b, Format>,
    // Number of formats still to be yielded.
    actual_len: usize,
    // Yielded once `format_iter` is exhausted: the first provided code, or TEXT when none
    // was provided. Only reached when at most one code was provided.
    default_format: Format,
}
104
105impl<'a> FormatIterator<'a, '_> {
106    pub fn new(provided_formats: &'a [Format], actual_len: usize) -> Result<Self, String> {
107        if !provided_formats.is_empty()
108            && provided_formats.len() != 1
109            && provided_formats.len() != actual_len
110        {
111            return Err(format!(
112                "format codes length {} is not 0, 1 or equal to actual length {}",
113                provided_formats.len(),
114                actual_len
115            ));
116        }
117
118        let default_format = provided_formats.first().copied().unwrap_or(Format::Text);
119
120        Ok(Self {
121            _formats: provided_formats,
122            default_format,
123            format_iter: provided_formats.iter(),
124            actual_len,
125        })
126    }
127}
128
129impl Iterator for FormatIterator<'_, '_> {
130    type Item = Format;
131
132    fn next(&mut self) -> Option<Self::Item> {
133        if self.actual_len == 0 {
134            return None;
135        }
136
137        self.actual_len -= 1;
138
139        Some(
140            self.format_iter
141                .next()
142                .copied()
143                .unwrap_or(self.default_format),
144        )
145    }
146
147    fn size_hint(&self) -> (usize, Option<usize>) {
148        (self.actual_len, Some(self.actual_len))
149    }
150}
151
// `len()` is derived from `size_hint`, which is exact for this iterator.
impl ExactSizeIterator for FormatIterator<'_, '_> {}
// SAFETY: `size_hint` always reports `(actual_len, Some(actual_len))`, and `next` returns
// `Some` exactly while `actual_len > 0`, decrementing it by one each time. The reported
// length is therefore always the exact number of remaining items, as `TrustedLen` requires.
unsafe impl TrustedLen for FormatIterator<'_, '_> {}