pgwire/
pg_extended.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::vec::IntoIter;
16
17use futures::stream::FusedStream;
18use futures::{StreamExt, TryStreamExt};
19use postgres_types::FromSql;
20
21use crate::error::{PsqlError, PsqlResult};
22use crate::pg_message::{BeCommandCompleteMessage, BeMessage};
23use crate::pg_protocol::{PgByteStream, PgStream};
24use crate::pg_response::{PgResponse, ValuesStream};
25use crate::types::{Format, Row};
26
27pub struct ResultCache<VS>
28where
29    VS: ValuesStream,
30{
31    result: PgResponse<VS>,
32    row_cache: IntoIter<Row>,
33}
34
35impl<VS> ResultCache<VS>
36where
37    VS: ValuesStream,
38{
39    pub fn new(result: PgResponse<VS>) -> Self {
40        ResultCache {
41            result,
42            row_cache: vec![].into_iter(),
43        }
44    }
45
46    /// Return indicate whether the result is consumed completely.
47    pub async fn consume<S: PgByteStream>(
48        &mut self,
49        row_limit: usize,
50        msg_stream: &mut PgStream<S>,
51    ) -> PsqlResult<bool> {
52        for notice in self.result.notices() {
53            msg_stream.write_no_flush(&BeMessage::NoticeResponse(notice))?;
54        }
55
56        let status = self.result.status();
57        if let Some(ref application_name) = status.application_name {
58            msg_stream.write_no_flush(&BeMessage::ParameterStatus(
59                crate::pg_message::BeParameterStatusMessage::ApplicationName(application_name),
60            ))?;
61        }
62
63        if self.result.is_empty() {
64            // Run the callback before sending the response.
65            self.result.run_callback().await?;
66
67            msg_stream.write_no_flush(&BeMessage::EmptyQueryResponse)?;
68            return Ok(true);
69        }
70
71        let mut query_end = false;
72        if self.result.is_query() {
73            let mut query_row_count = 0;
74
75            // fetch row data
76            // if row_limit is 0, fetch all rows
77            // if row_limit > 0, fetch row_limit rows
78            while row_limit == 0 || query_row_count < row_limit {
79                if self.row_cache.len() > 0 {
80                    for row in self.row_cache.by_ref() {
81                        msg_stream.write_no_flush(&BeMessage::DataRow(&row))?;
82                        query_row_count += 1;
83                        if row_limit > 0 && query_row_count >= row_limit {
84                            break;
85                        }
86                    }
87                } else {
88                    self.row_cache = match self
89                        .result
90                        .values_stream()
91                        .try_next()
92                        .await
93                        .map_err(PsqlError::ExtendedExecuteError)?
94                    {
95                        Some(rows) => rows.into_iter(),
96                        _ => {
97                            query_end = true;
98                            break;
99                        }
100                    };
101                }
102            }
103
104            // Check if the result is consumed completely.
105            // If not, cache the result.
106            if self.row_cache.len() == 0 && self.result.values_stream().peekable().is_terminated() {
107                query_end = true;
108            }
109            if query_end {
110                // Run the callback before sending the `CommandComplete` message.
111                self.result.run_callback().await?;
112
113                msg_stream.write_no_flush(&BeMessage::CommandComplete(
114                    BeCommandCompleteMessage {
115                        stmt_type: self.result.stmt_type(),
116                        rows_cnt: query_row_count as i32,
117                    },
118                ))?;
119            } else {
120                msg_stream.write_no_flush(&BeMessage::PortalSuspended)?;
121            }
122        } else if self.result.stmt_type().is_dml() && !self.result.stmt_type().is_returning() {
123            let first_row_set = self.result.values_stream().next().await;
124            let first_row_set = match first_row_set {
125                None => {
126                    return Err(PsqlError::Uncategorized(
127                        anyhow::anyhow!("no affected rows in output").into(),
128                    ));
129                }
130                Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
131            };
132            let affected_rows_str = first_row_set[0].values()[0]
133                .as_ref()
134                .expect("compute node should return affected rows in output");
135
136            let affected_rows_cnt: i32 = match self.result.row_cnt_format() {
137                Some(Format::Binary) => {
138                    i64::from_sql(&postgres_types::Type::INT8, affected_rows_str)
139                        .unwrap()
140                        .try_into()
141                        .expect("affected rows count large than i64")
142                }
143                Some(Format::Text) => String::from_utf8(affected_rows_str.to_vec())
144                    .unwrap()
145                    .parse()
146                    .unwrap_or_default(),
147                None => panic!("affected rows count should be set"),
148            };
149
150            // Run the callback before sending the `CommandComplete` message.
151            self.result.run_callback().await?;
152
153            msg_stream.write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
154                stmt_type: self.result.stmt_type(),
155                rows_cnt: affected_rows_cnt,
156            }))?;
157
158            query_end = true;
159        } else {
160            // Run the callback before sending the `CommandComplete` message.
161            self.result.run_callback().await?;
162
163            msg_stream.write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
164                stmt_type: self.result.stmt_type(),
165                rows_cnt: self
166                    .result
167                    .affected_rows_cnt()
168                    .expect("row count should be set"),
169            }))?;
170
171            query_end = true;
172        }
173
174        Ok(query_end)
175    }
176}