1use std::vec::IntoIter;
16
17use futures::stream::FusedStream;
18use futures::{StreamExt, TryStreamExt};
19use postgres_types::FromSql;
20
21use crate::error::{PsqlError, PsqlResult};
22use crate::pg_message::{BeCommandCompleteMessage, BeMessage};
23use crate::pg_protocol::{PgByteStream, PgStream};
24use crate::pg_response::{PgResponse, ValuesStream};
25use crate::types::{Format, Row};
26
27pub struct ResultCache<VS>
28where
29 VS: ValuesStream,
30{
31 result: PgResponse<VS>,
32 row_cache: IntoIter<Row>,
33}
34
35impl<VS> ResultCache<VS>
36where
37 VS: ValuesStream,
38{
39 pub fn new(result: PgResponse<VS>) -> Self {
40 ResultCache {
41 result,
42 row_cache: vec![].into_iter(),
43 }
44 }
45
46 pub async fn consume<S: PgByteStream>(
48 &mut self,
49 row_limit: usize,
50 msg_stream: &mut PgStream<S>,
51 ) -> PsqlResult<bool> {
52 for notice in self.result.notices() {
53 msg_stream.write_no_flush(BeMessage::NoticeResponse(notice))?;
54 }
55
56 let status = self.result.status();
57 if let Some(ref application_name) = status.application_name {
58 msg_stream.write_no_flush(BeMessage::ParameterStatus(
59 crate::pg_message::BeParameterStatusMessage::ApplicationName(application_name),
60 ))?;
61 }
62
63 if self.result.is_empty() {
64 self.result.run_callback().await?;
66
67 msg_stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
68 return Ok(true);
69 }
70
71 let mut query_end = false;
72 if self.result.is_copy_query_to_stdout() {
73 msg_stream.write_no_flush(BeMessage::CopyOutResponse(self.result.row_desc().len()))?;
74
75 let mut count = 0;
76 while let Some(row_set) = self.result.values_stream().next().await {
77 let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
78 for row in row_set {
79 msg_stream.write_no_flush(BeMessage::CopyData(&row))?;
80 count += 1;
81 }
82 }
83
84 msg_stream.write_no_flush(BeMessage::CopyDone)?;
85
86 self.result.run_callback().await?;
88
89 msg_stream.write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
90 stmt_type: self.result.stmt_type(),
91 rows_cnt: count,
92 }))?;
93
94 query_end = true;
95 } else if self.result.is_query() {
96 let mut query_row_count = 0;
97
98 while row_limit == 0 || query_row_count < row_limit {
102 if self.row_cache.len() > 0 {
103 for row in self.row_cache.by_ref() {
104 msg_stream.write_no_flush(BeMessage::DataRow(&row))?;
105 query_row_count += 1;
106 if row_limit > 0 && query_row_count >= row_limit {
107 break;
108 }
109 }
110 } else {
111 self.row_cache = match self
112 .result
113 .values_stream()
114 .try_next()
115 .await
116 .map_err(PsqlError::ExtendedExecuteError)?
117 {
118 Some(rows) => rows.into_iter(),
119 _ => {
120 query_end = true;
121 break;
122 }
123 };
124 }
125 }
126
127 if self.row_cache.len() == 0 && self.result.values_stream().peekable().is_terminated() {
130 query_end = true;
131 }
132 if query_end {
133 self.result.run_callback().await?;
135
136 msg_stream.write_no_flush(BeMessage::CommandComplete(
137 BeCommandCompleteMessage {
138 stmt_type: self.result.stmt_type(),
139 rows_cnt: query_row_count as i32,
140 },
141 ))?;
142 } else {
143 msg_stream.write_no_flush(BeMessage::PortalSuspended)?;
144 }
145 } else if self.result.stmt_type().is_dml() && !self.result.stmt_type().is_returning() {
146 let first_row_set = self.result.values_stream().next().await;
147 let first_row_set = match first_row_set {
148 None => {
149 return Err(PsqlError::Uncategorized(
150 "no affected rows in output".into(),
151 ));
152 }
153 Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
154 };
155 let affected_rows_str = first_row_set[0].values()[0]
156 .as_ref()
157 .expect("compute node should return affected rows in output");
158
159 let affected_rows_cnt: i32 = match self.result.row_cnt_format() {
160 Some(Format::Binary) => {
161 i64::from_sql(&postgres_types::Type::INT8, affected_rows_str)
162 .unwrap()
163 .try_into()
164 .expect("affected rows count large than i64")
165 }
166 Some(Format::Text) => String::from_utf8(affected_rows_str.to_vec())
167 .unwrap()
168 .parse()
169 .unwrap_or_default(),
170 None => panic!("affected rows count should be set"),
171 };
172
173 self.result.run_callback().await?;
175
176 msg_stream.write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
177 stmt_type: self.result.stmt_type(),
178 rows_cnt: affected_rows_cnt,
179 }))?;
180
181 query_end = true;
182 } else {
183 self.result.run_callback().await?;
185
186 msg_stream.write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
187 stmt_type: self.result.stmt_type(),
188 rows_cnt: self
189 .result
190 .affected_rows_cnt()
191 .expect("row count should be set"),
192 }))?;
193
194 query_end = true;
195 }
196
197 Ok(query_end)
198 }
199}