pgwire/
types.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::iter::TrustedLen;
16use std::ops::Index;
17use std::slice::Iter;
18
19use bytes::Bytes;
20
21use crate::error::{PsqlError, PsqlResult};
22
23/// A row of data returned from the database by a query.
24#[derive(Debug, Clone)]
25// NOTE: Since we only support simple query protocol, the values are represented as strings.
26pub struct Row(Vec<Option<Bytes>>);
27
28impl Row {
29    /// Create a row from values.
30    pub fn new(row: Vec<Option<Bytes>>) -> Self {
31        Self(row)
32    }
33
34    /// Returns the number of values in the row.
35    pub fn len(&self) -> usize {
36        self.0.len()
37    }
38
39    /// Returns `true` if the row contains no values. Required by clippy.
40    pub fn is_empty(&self) -> bool {
41        self.0.is_empty()
42    }
43
44    /// Returns the values.
45    pub fn values(&self) -> &[Option<Bytes>] {
46        &self.0
47    }
48
49    pub fn take(self) -> Vec<Option<Bytes>> {
50        self.0
51    }
52
53    pub fn project(&mut self, indices: &[usize]) -> Row {
54        let mut new_row = Vec::with_capacity(indices.len());
55        for i in indices {
56            new_row.push(self.0[*i].take());
57        }
58        Row(new_row)
59    }
60}
61
62impl Index<usize> for Row {
63    type Output = Option<Bytes>;
64
65    fn index(&self, index: usize) -> &Self::Output {
66        &self.0[index]
67    }
68}
69
70/// <https://www.postgresql.org/docs/current/protocol-overview.html#PROTOCOL-FORMAT-CODES>
71#[derive(Debug, Clone, Copy, PartialEq, Eq)]
72pub enum Format {
73    Binary,
74    Text,
75}
76
77impl Format {
78    pub fn from_i16(format_code: i16) -> PsqlResult<Self> {
79        match format_code {
80            0 => Ok(Format::Text),
81            1 => Ok(Format::Binary),
82            _ => Err(PsqlError::Uncategorized(
83                format!("Unknown format code: {}", format_code).into(),
84            )),
85        }
86    }
87
88    pub fn to_i8(self) -> i8 {
89        match self {
90            Format::Binary => 1,
91            Format::Text => 0,
92        }
93    }
94}
95
96/// FormatIterator used to generate formats of actual length given the provided format.
97/// According Postgres Document: <https://www.postgresql.org/docs/current/protocol-message-formats.html#:~:text=The%20number%20of,number%20of%20parameters>
98/// - If the length of provided format is 0, all format will be default format(TEXT).
99/// - If the length of provided format is 1, all format will be the same as this only format.
100/// - If the length of provided format > 1, provided format should be the actual format.
101#[derive(Debug, Clone)]
102pub struct FormatIterator<'a, 'b>
103where
104    'a: 'b,
105{
106    _formats: &'a [Format],
107    format_iter: Iter<'b, Format>,
108    actual_len: usize,
109    default_format: Format,
110}
111
112impl<'a> FormatIterator<'a, '_> {
113    pub fn new(provided_formats: &'a [Format], actual_len: usize) -> Result<Self, String> {
114        if !provided_formats.is_empty()
115            && provided_formats.len() != 1
116            && provided_formats.len() != actual_len
117        {
118            return Err(format!(
119                "format codes length {} is not 0, 1 or equal to actual length {}",
120                provided_formats.len(),
121                actual_len
122            ));
123        }
124
125        let default_format = provided_formats.first().copied().unwrap_or(Format::Text);
126
127        Ok(Self {
128            _formats: provided_formats,
129            default_format,
130            format_iter: provided_formats.iter(),
131            actual_len,
132        })
133    }
134}
135
136impl Iterator for FormatIterator<'_, '_> {
137    type Item = Format;
138
139    fn next(&mut self) -> Option<Self::Item> {
140        if self.actual_len == 0 {
141            return None;
142        }
143
144        self.actual_len -= 1;
145
146        Some(
147            self.format_iter
148                .next()
149                .copied()
150                .unwrap_or(self.default_format),
151        )
152    }
153
154    fn size_hint(&self) -> (usize, Option<usize>) {
155        (self.actual_len, Some(self.actual_len))
156    }
157}
158
// `size_hint` always reports `(actual_len, Some(actual_len))`, matching the exact number of
// items `next` still yields, so the default `len()` is correct.
impl ExactSizeIterator for FormatIterator<'_, '_> {}
// SAFETY: `TrustedLen` requires `size_hint` to report the exact remaining length. `next`
// returns `Some` exactly `actual_len` times, decrementing `actual_len` on each yield and
// returning `None` once it reaches 0. `size_hint` returns `(actual_len, Some(actual_len))`, so
// the reported length always equals the true remaining length.
unsafe impl TrustedLen for FormatIterator<'_, '_> {}