1use std::vec::IntoIter;
16
17use futures::stream::FusedStream;
18use futures::{StreamExt, TryStreamExt};
19use postgres_types::FromSql;
20
21use crate::error::{PsqlError, PsqlResult};
22use crate::pg_message::{BeCommandCompleteMessage, BeMessage};
23use crate::pg_protocol::{PgByteStream, PgStream};
24use crate::pg_response::{PgResponse, ValuesStream};
25use crate::types::{Format, Row};
26
27pub struct ResultCache<VS>
28where
29 VS: ValuesStream,
30{
31 result: PgResponse<VS>,
32 row_cache: IntoIter<Row>,
33}
34
35impl<VS> ResultCache<VS>
36where
37 VS: ValuesStream,
38{
39 pub fn new(result: PgResponse<VS>) -> Self {
40 ResultCache {
41 result,
42 row_cache: vec![].into_iter(),
43 }
44 }
45
46 pub async fn consume<S: PgByteStream>(
48 &mut self,
49 row_limit: usize,
50 msg_stream: &mut PgStream<S>,
51 ) -> PsqlResult<bool> {
52 for notice in self.result.notices() {
53 msg_stream.write_no_flush(&BeMessage::NoticeResponse(notice))?;
54 }
55
56 let status = self.result.status();
57 if let Some(ref application_name) = status.application_name {
58 msg_stream.write_no_flush(&BeMessage::ParameterStatus(
59 crate::pg_message::BeParameterStatusMessage::ApplicationName(application_name),
60 ))?;
61 }
62
63 if self.result.is_empty() {
64 self.result.run_callback().await?;
66
67 msg_stream.write_no_flush(&BeMessage::EmptyQueryResponse)?;
68 return Ok(true);
69 }
70
71 let mut query_end = false;
72 if self.result.is_query() {
73 let mut query_row_count = 0;
74
75 while row_limit == 0 || query_row_count < row_limit {
79 if self.row_cache.len() > 0 {
80 for row in self.row_cache.by_ref() {
81 msg_stream.write_no_flush(&BeMessage::DataRow(&row))?;
82 query_row_count += 1;
83 if row_limit > 0 && query_row_count >= row_limit {
84 break;
85 }
86 }
87 } else {
88 self.row_cache = match self
89 .result
90 .values_stream()
91 .try_next()
92 .await
93 .map_err(PsqlError::ExtendedExecuteError)?
94 {
95 Some(rows) => rows.into_iter(),
96 _ => {
97 query_end = true;
98 break;
99 }
100 };
101 }
102 }
103
104 if self.row_cache.len() == 0 && self.result.values_stream().peekable().is_terminated() {
107 query_end = true;
108 }
109 if query_end {
110 self.result.run_callback().await?;
112
113 msg_stream.write_no_flush(&BeMessage::CommandComplete(
114 BeCommandCompleteMessage {
115 stmt_type: self.result.stmt_type(),
116 rows_cnt: query_row_count as i32,
117 },
118 ))?;
119 } else {
120 msg_stream.write_no_flush(&BeMessage::PortalSuspended)?;
121 }
122 } else if self.result.stmt_type().is_dml() && !self.result.stmt_type().is_returning() {
123 let first_row_set = self.result.values_stream().next().await;
124 let first_row_set = match first_row_set {
125 None => {
126 return Err(PsqlError::Uncategorized(
127 anyhow::anyhow!("no affected rows in output").into(),
128 ));
129 }
130 Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
131 };
132 let affected_rows_str = first_row_set[0].values()[0]
133 .as_ref()
134 .expect("compute node should return affected rows in output");
135
136 let affected_rows_cnt: i32 = match self.result.row_cnt_format() {
137 Some(Format::Binary) => {
138 i64::from_sql(&postgres_types::Type::INT8, affected_rows_str)
139 .unwrap()
140 .try_into()
141 .expect("affected rows count large than i64")
142 }
143 Some(Format::Text) => String::from_utf8(affected_rows_str.to_vec())
144 .unwrap()
145 .parse()
146 .unwrap_or_default(),
147 None => panic!("affected rows count should be set"),
148 };
149
150 self.result.run_callback().await?;
152
153 msg_stream.write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
154 stmt_type: self.result.stmt_type(),
155 rows_cnt: affected_rows_cnt,
156 }))?;
157
158 query_end = true;
159 } else {
160 self.result.run_callback().await?;
162
163 msg_stream.write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
164 stmt_type: self.result.stmt_type(),
165 rows_cnt: self
166 .result
167 .affected_rows_cnt()
168 .expect("row count should be set"),
169 }))?;
170
171 query_end = true;
172 }
173
174 Ok(query_end)
175 }
176}