1use std::fmt::Formatter;
16use std::pin::Pin;
17
18use futures::{Future, FutureExt, Stream, StreamExt};
19use risingwave_sqlparser::ast::Statement;
20
21use crate::error::PsqlError;
22use crate::pg_field_descriptor::PgFieldDescriptor;
23use crate::pg_protocol::ParameterStatus;
24use crate::pg_server::BoxedError;
25use crate::types::{Format, Row};
26
27pub type RowSet = Vec<Row>;
28pub type RowSetResult = Result<RowSet, BoxedError>;
29
30pub trait ValuesStream = Stream<Item = RowSetResult> + Unpin + Send;
31
/// Classification of a SQL statement, carried by every `PgResponse`.
///
/// The `Display` impl for this type forwards to the derived `Debug`, so both print the variant
/// name verbatim (e.g. `CREATE_TABLE`).
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
// Variant names are upper-case with underscores, which these lints would otherwise reject.
// NOTE(review): presumably chosen so the printed name reads like an SQL command tag — confirm.
#[expect(non_camel_case_types, clippy::upper_case_acronyms)]
pub enum StatementType {
    // Data manipulation / retrieval. The `*_RETURNING` variants are chosen by
    // `infer_from_statement` when the statement has a non-empty RETURNING list.
    INSERT,
    INSERT_RETURNING,
    DELETE,
    DELETE_RETURNING,
    UPDATE,
    UPDATE_RETURNING,
    SELECT,
    MOVE,
    FETCH,
    COPY,
    EXPLAIN,
    CLOSE_CURSOR,
    // CREATE statements.
    CREATE_TABLE,
    CREATE_MATERIALIZED_VIEW,
    CREATE_VIEW,
    CREATE_SOURCE,
    CREATE_SINK,
    CREATE_SUBSCRIPTION,
    CREATE_DATABASE,
    CREATE_SCHEMA,
    CREATE_USER,
    CREATE_INDEX,
    CREATE_AGGREGATE,
    CREATE_FUNCTION,
    CREATE_CONNECTION,
    CREATE_SECRET,
    COMMENT,
    DECLARE_CURSOR,
    DESCRIBE,
    GRANT_PRIVILEGE,
    DISCARD,
    // DROP statements; `infer_from_statement` picks one from the dropped object's type.
    DROP_TABLE,
    DROP_MATERIALIZED_VIEW,
    DROP_VIEW,
    DROP_INDEX,
    DROP_FUNCTION,
    DROP_AGGREGATE,
    DROP_SOURCE,
    DROP_SINK,
    DROP_SUBSCRIPTION,
    DROP_SCHEMA,
    DROP_DATABASE,
    DROP_USER,
    DROP_CONNECTION,
    DROP_SECRET,
    // ALTER statements.
    ALTER_DATABASE,
    ALTER_DEFAULT_PRIVILEGES,
    ALTER_SCHEMA,
    ALTER_INDEX,
    ALTER_VIEW,
    ALTER_TABLE,
    ALTER_MATERIALIZED_VIEW,
    ALTER_SINK,
    ALTER_SUBSCRIPTION,
    ALTER_SOURCE,
    ALTER_FUNCTION,
    ALTER_CONNECTION,
    ALTER_SYSTEM,
    ALTER_SECRET,
    REVOKE_PRIVILEGE,
    // NOTE(review): `infer_from_statement` never produces ORDER_BY; its purpose is unclear here.
    ORDER_BY,
    SET_VARIABLE,
    SHOW_VARIABLE,
    SHOW_COMMAND,
    START_TRANSACTION,
    UPDATE_USER,
    ABORT,
    FLUSH,
    OTHER,
    /// The variant that `PgResponse::is_empty` checks for.
    EMPTY,
    // Transaction control.
    BEGIN,
    COMMIT,
    ROLLBACK,
    SET_TRANSACTION,
    // Miscellaneous.
    CANCEL_COMMAND,
    /// Produced by `infer_from_statement` for `Statement::FetchCursor`; unlike `FETCH`, it is
    /// counted by `is_query` rather than `is_command`.
    FETCH_CURSOR,
    WAIT,
    KILL,
    RECOVER,
    USE,
    PREPARE,
    DEALLOCATE,
}
121
122impl std::fmt::Display for StatementType {
123 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
124 write!(f, "{:?}", self)
125 }
126}
127
128pub trait Callback = Future<Output = Result<(), BoxedError>> + Send;
129
130pub type BoxedCallback = Pin<Box<dyn Callback>>;
131
/// A fully assembled response to one statement, normally produced from a [`PgResponseBuilder`].
///
/// `VS` is the type of the row stream; the inherent `impl` block requires `VS: ValuesStream`.
pub struct PgResponse<VS> {
    /// Kind of statement this response answers.
    stmt_type: StatementType,
    /// Affected-row count, exposed by `affected_rows_cnt()`. `PgResponseBuilder::empty` starts it
    /// at `None` when `stmt_type.is_query()` is true and at `Some(0)` otherwise.
    row_cnt: Option<i32>,
    /// Optional format associated with the row count, exposed by `row_cnt_format()`.
    /// NOTE(review): how the protocol layer consumes this is not visible here — confirm.
    row_cnt_format: Option<Format>,
    /// Notice messages, in the order `PgResponseBuilder::notice` appended them.
    notices: Vec<String>,
    /// Row stream set by `PgResponseBuilder::values`; `None` if it was never set.
    values_stream: Option<VS>,
    /// Future taken and awaited (at most once) by `run_callback`.
    callback: Option<BoxedCallback>,
    /// Column descriptors; set together with the stream by `PgResponseBuilder::values`,
    /// empty otherwise.
    row_desc: Vec<PgFieldDescriptor>,
    /// Parameter status attached to the response.
    status: ParameterStatus,
}
145
/// Incrementally assembles a [`PgResponse`]; convert it with `From`/`Into` once configured.
pub struct PgResponseBuilder<VS> {
    /// Kind of statement being answered.
    stmt_type: StatementType,
    /// Affected-row count; `empty` starts it at `None` when `stmt_type.is_query()` is true and
    /// at `Some(0)` otherwise. Overwritten by `row_cnt` / `row_cnt_opt`.
    row_cnt: Option<i32>,
    /// Optional row-count format, set by `row_cnt_format_opt`.
    row_cnt_format: Option<Format>,
    /// Notices appended by `notice`, in call order.
    notices: Vec<String>,
    /// Row stream set by `values`.
    values_stream: Option<VS>,
    /// Boxed future set by `callback`.
    callback: Option<BoxedCallback>,
    /// Column descriptors, set together with the stream by `values`.
    row_desc: Vec<PgFieldDescriptor>,
    /// Parameter status set by `status`; starts as `Default::default()`.
    status: ParameterStatus,
}
159
160impl<VS> From<PgResponseBuilder<VS>> for PgResponse<VS> {
161 fn from(builder: PgResponseBuilder<VS>) -> Self {
162 Self {
163 stmt_type: builder.stmt_type,
164 row_cnt: builder.row_cnt,
165 row_cnt_format: builder.row_cnt_format,
166 notices: builder.notices,
167 values_stream: builder.values_stream,
168 callback: builder.callback,
169 row_desc: builder.row_desc,
170 status: builder.status,
171 }
172 }
173}
174
175impl<VS> PgResponseBuilder<VS> {
176 pub fn empty(stmt_type: StatementType) -> Self {
177 let row_cnt = if stmt_type.is_query() { None } else { Some(0) };
178 Self {
179 stmt_type,
180 row_cnt,
181 row_cnt_format: None,
182 notices: vec![],
183 values_stream: None,
184 callback: None,
185 row_desc: vec![],
186 status: Default::default(),
187 }
188 }
189
190 pub fn row_cnt(self, row_cnt: i32) -> Self {
191 Self {
192 row_cnt: Some(row_cnt),
193 ..self
194 }
195 }
196
197 pub fn row_cnt_opt(self, row_cnt: Option<i32>) -> Self {
198 Self { row_cnt, ..self }
199 }
200
201 pub fn row_cnt_format_opt(self, row_cnt_format: Option<Format>) -> Self {
202 Self {
203 row_cnt_format,
204 ..self
205 }
206 }
207
208 pub fn values(self, values_stream: VS, row_desc: Vec<PgFieldDescriptor>) -> Self {
209 Self {
210 values_stream: Some(values_stream),
211 row_desc,
212 ..self
213 }
214 }
215
216 pub fn callback(self, callback: impl Callback + 'static) -> Self {
217 Self {
218 callback: Some(callback.boxed()),
219 ..self
220 }
221 }
222
223 pub fn notice(self, notice: impl ToString) -> Self {
224 let mut notices = self.notices;
225 notices.push(notice.to_string());
226 Self { notices, ..self }
227 }
228
229 pub fn status(self, status: ParameterStatus) -> Self {
230 Self { status, ..self }
231 }
232}
233
234impl<VS> std::fmt::Debug for PgResponse<VS> {
235 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
236 f.debug_struct("PgResponse")
237 .field("stmt_type", &self.stmt_type)
238 .field("row_cnt", &self.row_cnt)
239 .field("notices", &self.notices)
240 .field("row_desc", &self.row_desc)
241 .finish()
242 }
243}
244
245impl StatementType {
246 pub fn infer_from_statement(stmt: &Statement) -> Result<Self, String> {
247 match stmt {
248 Statement::Query(_) => Ok(StatementType::SELECT),
249 Statement::Insert { returning, .. } => {
250 if returning.is_empty() {
251 Ok(StatementType::INSERT)
252 } else {
253 Ok(StatementType::INSERT_RETURNING)
254 }
255 }
256 Statement::Delete { returning, .. } => {
257 if returning.is_empty() {
258 Ok(StatementType::DELETE)
259 } else {
260 Ok(StatementType::DELETE_RETURNING)
261 }
262 }
263 Statement::Update { returning, .. } => {
264 if returning.is_empty() {
265 Ok(StatementType::UPDATE)
266 } else {
267 Ok(StatementType::UPDATE_RETURNING)
268 }
269 }
270 Statement::Copy { .. } => Ok(StatementType::COPY),
271 Statement::CreateTable { .. } => Ok(StatementType::CREATE_TABLE),
272 Statement::CreateIndex { .. } => Ok(StatementType::CREATE_INDEX),
273 Statement::CreateSchema { .. } => Ok(StatementType::CREATE_SCHEMA),
274 Statement::CreateSource { .. } => Ok(StatementType::CREATE_SOURCE),
275 Statement::CreateSink { .. } => Ok(StatementType::CREATE_SINK),
276 Statement::CreateFunction { .. } => Ok(StatementType::CREATE_FUNCTION),
277 Statement::CreateDatabase { .. } => Ok(StatementType::CREATE_DATABASE),
278 Statement::CreateUser { .. } => Ok(StatementType::CREATE_USER),
279 Statement::CreateView { materialized, .. } => {
280 if *materialized {
281 Ok(StatementType::CREATE_MATERIALIZED_VIEW)
282 } else {
283 Ok(StatementType::CREATE_VIEW)
284 }
285 }
286 Statement::AlterTable { .. } => Ok(StatementType::ALTER_TABLE),
287 Statement::AlterSystem { .. } => Ok(StatementType::ALTER_SYSTEM),
288 Statement::DropFunction { .. } => Ok(StatementType::DROP_FUNCTION),
289 Statement::Discard(..) => Ok(StatementType::DISCARD),
290 Statement::SetVariable { .. } => Ok(StatementType::SET_VARIABLE),
291 Statement::ShowVariable { .. } => Ok(StatementType::SHOW_VARIABLE),
292 Statement::StartTransaction { .. } => Ok(StatementType::START_TRANSACTION),
293 Statement::Begin { .. } => Ok(StatementType::BEGIN),
294 Statement::Abort => Ok(StatementType::ABORT),
295 Statement::Commit { .. } => Ok(StatementType::COMMIT),
296 Statement::Rollback { .. } => Ok(StatementType::ROLLBACK),
297 Statement::Grant { .. } => Ok(StatementType::GRANT_PRIVILEGE),
298 Statement::Revoke { .. } => Ok(StatementType::REVOKE_PRIVILEGE),
299 Statement::Describe { .. } => Ok(StatementType::DESCRIBE),
300 Statement::ShowCreateObject { .. } | Statement::ShowObjects { .. } => {
301 Ok(StatementType::SHOW_COMMAND)
302 }
303 Statement::Drop(stmt) => match stmt.object_type {
304 risingwave_sqlparser::ast::ObjectType::Table => Ok(StatementType::DROP_TABLE),
305 risingwave_sqlparser::ast::ObjectType::View => Ok(StatementType::DROP_VIEW),
306 risingwave_sqlparser::ast::ObjectType::MaterializedView => {
307 Ok(StatementType::DROP_MATERIALIZED_VIEW)
308 }
309 risingwave_sqlparser::ast::ObjectType::Index => Ok(StatementType::DROP_INDEX),
310 risingwave_sqlparser::ast::ObjectType::Schema => Ok(StatementType::DROP_SCHEMA),
311 risingwave_sqlparser::ast::ObjectType::Source => Ok(StatementType::DROP_SOURCE),
312 risingwave_sqlparser::ast::ObjectType::Sink => Ok(StatementType::DROP_SINK),
313 risingwave_sqlparser::ast::ObjectType::Database => Ok(StatementType::DROP_DATABASE),
314 risingwave_sqlparser::ast::ObjectType::User => Ok(StatementType::DROP_USER),
315 risingwave_sqlparser::ast::ObjectType::Connection => {
316 Ok(StatementType::DROP_CONNECTION)
317 }
318 risingwave_sqlparser::ast::ObjectType::Secret => Ok(StatementType::DROP_SECRET),
319 risingwave_sqlparser::ast::ObjectType::Subscription => {
320 Ok(StatementType::DROP_SUBSCRIPTION)
321 }
322 },
323 Statement::Explain { .. } => Ok(StatementType::EXPLAIN),
324 Statement::DeclareCursor { .. } => Ok(StatementType::DECLARE_CURSOR),
325 Statement::FetchCursor { .. } => Ok(StatementType::FETCH_CURSOR),
326 Statement::CloseCursor { .. } => Ok(StatementType::CLOSE_CURSOR),
327 Statement::Flush => Ok(StatementType::FLUSH),
328 Statement::Wait => Ok(StatementType::WAIT),
329 Statement::Use { .. } => Ok(StatementType::USE),
330 _ => Err("unsupported statement type".to_owned()),
331 }
332 }
333
334 pub fn is_command(&self) -> bool {
335 matches!(
336 self,
337 StatementType::INSERT
338 | StatementType::DELETE
339 | StatementType::UPDATE
340 | StatementType::MOVE
341 | StatementType::COPY
342 | StatementType::FETCH
343 | StatementType::SELECT
344 | StatementType::INSERT_RETURNING
345 | StatementType::DELETE_RETURNING
346 | StatementType::UPDATE_RETURNING
347 )
348 }
349
350 pub fn is_dml(&self) -> bool {
351 matches!(
352 self,
353 StatementType::INSERT
354 | StatementType::DELETE
355 | StatementType::UPDATE
356 | StatementType::INSERT_RETURNING
357 | StatementType::DELETE_RETURNING
358 | StatementType::UPDATE_RETURNING
359 )
360 }
361
362 pub fn is_query(&self) -> bool {
363 matches!(
364 self,
365 StatementType::SELECT
366 | StatementType::EXPLAIN
367 | StatementType::SHOW_COMMAND
368 | StatementType::SHOW_VARIABLE
369 | StatementType::DESCRIBE
370 | StatementType::INSERT_RETURNING
371 | StatementType::DELETE_RETURNING
372 | StatementType::UPDATE_RETURNING
373 | StatementType::CANCEL_COMMAND
374 | StatementType::FETCH_CURSOR
375 )
376 }
377
378 pub fn is_returning(&self) -> bool {
379 matches!(
380 self,
381 StatementType::INSERT_RETURNING
382 | StatementType::DELETE_RETURNING
383 | StatementType::UPDATE_RETURNING
384 )
385 }
386}
387
388impl<VS> PgResponse<VS>
389where
390 VS: ValuesStream,
391{
392 pub fn builder(stmt_type: StatementType) -> PgResponseBuilder<VS> {
393 PgResponseBuilder::empty(stmt_type)
394 }
395
396 pub fn empty_result(stmt_type: StatementType) -> Self {
397 PgResponseBuilder::empty(stmt_type).into()
398 }
399
400 pub fn stmt_type(&self) -> StatementType {
401 self.stmt_type
402 }
403
404 pub fn notices(&self) -> &[String] {
405 &self.notices
406 }
407
408 pub fn status(&self) -> &ParameterStatus {
409 &self.status
410 }
411
412 pub fn affected_rows_cnt(&self) -> Option<i32> {
413 self.row_cnt
414 }
415
416 pub fn row_cnt_format(&self) -> Option<Format> {
417 self.row_cnt_format
418 }
419
420 pub fn is_query(&self) -> bool {
421 self.stmt_type.is_query()
422 }
423
424 pub fn is_empty(&self) -> bool {
425 self.stmt_type == StatementType::EMPTY
426 }
427
428 pub fn row_desc(&self) -> Vec<PgFieldDescriptor> {
429 self.row_desc.clone()
430 }
431
432 pub fn values_stream(&mut self) -> &mut VS {
433 self.values_stream.as_mut().expect("no values stream")
434 }
435
436 pub async fn run_callback(&mut self) -> Result<(), PsqlError> {
441 if let Some(values_stream) = &mut self.values_stream {
443 assert!(values_stream.next().await.is_none());
444 }
445
446 if let Some(callback) = self.callback.take() {
447 callback.await.map_err(PsqlError::SimpleQueryError)?;
448 }
449 Ok(())
450 }
451}