pgwire/
pg_extended.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::vec::IntoIter;
16
17use futures::stream::FusedStream;
18use futures::{StreamExt, TryStreamExt};
19use postgres_types::FromSql;
20
21use crate::error::{PsqlError, PsqlResult};
22use crate::pg_message::{BeCommandCompleteMessage, BeMessage};
23use crate::pg_protocol::{PgByteStream, PgStream};
24use crate::pg_response::{PgResponse, ValuesStream};
25use crate::types::{Format, Row};
26
27pub struct ResultCache<VS>
28where
29    VS: ValuesStream,
30{
31    result: PgResponse<VS>,
32    row_cache: IntoIter<Row>,
33}
34
35impl<VS> ResultCache<VS>
36where
37    VS: ValuesStream,
38{
39    pub fn new(result: PgResponse<VS>) -> Self {
40        ResultCache {
41            result,
42            row_cache: vec![].into_iter(),
43        }
44    }
45
46    /// Return indicate whether the result is consumed completely.
47    pub async fn consume<S: PgByteStream>(
48        &mut self,
49        row_limit: usize,
50        msg_stream: &mut PgStream<S>,
51    ) -> PsqlResult<bool> {
52        for notice in self.result.notices() {
53            msg_stream.write_no_flush(BeMessage::NoticeResponse(notice))?;
54        }
55
56        let status = self.result.status();
57        if let Some(ref application_name) = status.application_name {
58            msg_stream.write_no_flush(BeMessage::ParameterStatus(
59                crate::pg_message::BeParameterStatusMessage::ApplicationName(application_name),
60            ))?;
61        }
62
63        if self.result.is_empty() {
64            // Run the callback before sending the response.
65            self.result.run_callback().await?;
66
67            msg_stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
68            return Ok(true);
69        }
70
71        let mut query_end = false;
72        if self.result.is_copy_query_to_stdout() {
73            msg_stream.write_no_flush(BeMessage::CopyOutResponse(self.result.row_desc().len()))?;
74
75            let mut count = 0;
76            while let Some(row_set) = self.result.values_stream().next().await {
77                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
78                for row in row_set {
79                    msg_stream.write_no_flush(BeMessage::CopyData(&row))?;
80                    count += 1;
81                }
82            }
83
84            msg_stream.write_no_flush(BeMessage::CopyDone)?;
85
86            // Run the callback before sending the `CommandComplete` message.
87            self.result.run_callback().await?;
88
89            msg_stream.write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
90                stmt_type: self.result.stmt_type(),
91                rows_cnt: count,
92            }))?;
93
94            query_end = true;
95        } else if self.result.is_query() {
96            let mut query_row_count = 0;
97
98            // fetch row data
99            // if row_limit is 0, fetch all rows
100            // if row_limit > 0, fetch row_limit rows
101            while row_limit == 0 || query_row_count < row_limit {
102                if self.row_cache.len() > 0 {
103                    for row in self.row_cache.by_ref() {
104                        msg_stream.write_no_flush(BeMessage::DataRow(&row))?;
105                        query_row_count += 1;
106                        if row_limit > 0 && query_row_count >= row_limit {
107                            break;
108                        }
109                    }
110                } else {
111                    self.row_cache = match self
112                        .result
113                        .values_stream()
114                        .try_next()
115                        .await
116                        .map_err(PsqlError::ExtendedExecuteError)?
117                    {
118                        Some(rows) => rows.into_iter(),
119                        _ => {
120                            query_end = true;
121                            break;
122                        }
123                    };
124                }
125            }
126
127            // Check if the result is consumed completely.
128            // If not, cache the result.
129            if self.row_cache.len() == 0 && self.result.values_stream().peekable().is_terminated() {
130                query_end = true;
131            }
132            if query_end {
133                // Run the callback before sending the `CommandComplete` message.
134                self.result.run_callback().await?;
135
136                msg_stream.write_no_flush(BeMessage::CommandComplete(
137                    BeCommandCompleteMessage {
138                        stmt_type: self.result.stmt_type(),
139                        rows_cnt: query_row_count as i32,
140                    },
141                ))?;
142            } else {
143                msg_stream.write_no_flush(BeMessage::PortalSuspended)?;
144            }
145        } else if self.result.stmt_type().is_dml() && !self.result.stmt_type().is_returning() {
146            let first_row_set = self.result.values_stream().next().await;
147            let first_row_set = match first_row_set {
148                None => {
149                    return Err(PsqlError::Uncategorized(
150                        "no affected rows in output".into(),
151                    ));
152                }
153                Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
154            };
155            let affected_rows_str = first_row_set[0].values()[0]
156                .as_ref()
157                .expect("compute node should return affected rows in output");
158
159            let affected_rows_cnt: i32 = match self.result.row_cnt_format() {
160                Some(Format::Binary) => {
161                    i64::from_sql(&postgres_types::Type::INT8, affected_rows_str)
162                        .unwrap()
163                        .try_into()
164                        .expect("affected rows count large than i64")
165                }
166                Some(Format::Text) => String::from_utf8(affected_rows_str.to_vec())
167                    .unwrap()
168                    .parse()
169                    .unwrap_or_default(),
170                None => panic!("affected rows count should be set"),
171            };
172
173            // Run the callback before sending the `CommandComplete` message.
174            self.result.run_callback().await?;
175
176            msg_stream.write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
177                stmt_type: self.result.stmt_type(),
178                rows_cnt: affected_rows_cnt,
179            }))?;
180
181            query_end = true;
182        } else {
183            // Run the callback before sending the `CommandComplete` message.
184            self.result.run_callback().await?;
185
186            msg_stream.write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
187                stmt_type: self.result.stmt_type(),
188                rows_cnt: self
189                    .result
190                    .affected_rows_cnt()
191                    .expect("row count should be set"),
192            }))?;
193
194            query_end = true;
195        }
196
197        Ok(query_end)
198    }
199}