1use std::fmt::Formatter;
16use std::pin::Pin;
17
18use futures::{Future, FutureExt, Stream, StreamExt};
19use risingwave_sqlparser::ast::Statement;
20
21use crate::error::PsqlError;
22use crate::pg_field_descriptor::PgFieldDescriptor;
23use crate::pg_protocol::ParameterStatus;
24use crate::pg_server::BoxedError;
25use crate::types::{Format, Row};
26
/// A batch of rows.
pub type RowSet = Vec<Row>;
/// One item yielded by a values stream: either a batch of rows or an error.
pub type RowSetResult = Result<RowSet, BoxedError>;

/// Stream of row batches attached to a `PgResponse`; must be `Unpin` and `Send`.
pub trait ValuesStream = Stream<Item = RowSetResult> + Unpin + Send;
31
/// Kind of SQL statement a `PgResponse` answers.
///
/// Variant names are upper snake case (e.g. `CREATE_TABLE`); the `#[expect]` below silences the
/// naming lints this triggers. The `Display` impl prints the same text as `Debug`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[expect(non_camel_case_types, clippy::upper_case_acronyms)]
pub enum StatementType {
    INSERT,
    // `*_RETURNING`: the DML statement has a non-empty `RETURNING` list
    // (see `StatementType::infer_from_statement`).
    INSERT_RETURNING,
    DELETE,
    DELETE_RETURNING,
    UPDATE,
    UPDATE_RETURNING,
    SELECT,
    MOVE,
    FETCH,
    COPY,
    EXPLAIN,
    CLOSE_CURSOR,
    CREATE_TABLE,
    CREATE_MATERIALIZED_VIEW,
    CREATE_VIEW,
    CREATE_SOURCE,
    CREATE_SINK,
    CREATE_SUBSCRIPTION,
    CREATE_DATABASE,
    CREATE_SCHEMA,
    CREATE_USER,
    CREATE_INDEX,
    CREATE_AGGREGATE,
    CREATE_FUNCTION,
    CREATE_CONNECTION,
    CREATE_SECRET,
    COMMENT,
    DECLARE_CURSOR,
    DESCRIBE,
    GRANT_PRIVILEGE,
    DISCARD,
    DROP_TABLE,
    DROP_MATERIALIZED_VIEW,
    DROP_VIEW,
    DROP_INDEX,
    DROP_FUNCTION,
    DROP_AGGREGATE,
    DROP_SOURCE,
    DROP_SINK,
    DROP_SUBSCRIPTION,
    DROP_SCHEMA,
    DROP_DATABASE,
    DROP_USER,
    DROP_CONNECTION,
    DROP_SECRET,
    ALTER_DATABASE,
    ALTER_DEFAULT_PRIVILEGES,
    ALTER_SCHEMA,
    ALTER_INDEX,
    ALTER_VIEW,
    ALTER_TABLE,
    ALTER_MATERIALIZED_VIEW,
    ALTER_SINK,
    ALTER_SUBSCRIPTION,
    ALTER_SOURCE,
    ALTER_FUNCTION,
    ALTER_CONNECTION,
    ALTER_SYSTEM,
    ALTER_SECRET,
    REVOKE_PRIVILEGE,
    ORDER_BY,
    SET_VARIABLE,
    SHOW_VARIABLE,
    SHOW_COMMAND,
    START_TRANSACTION,
    UPDATE_USER,
    ABORT,
    FLUSH,
    REFRESH_TABLE,
    OTHER,
    // Checked by `PgResponse::is_empty`.
    EMPTY,
    BEGIN,
    COMMIT,
    ROLLBACK,
    SET_TRANSACTION,
    CANCEL_COMMAND,
    FETCH_CURSOR,
    WAIT,
    KILL,
    RECOVER,
    USE,
    PREPARE,
    DEALLOCATE,
    VACUUM,
}
123
124impl std::fmt::Display for StatementType {
125 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
126 write!(f, "{:?}", self)
127 }
128}
129
/// A `Send` future resolving to `Result<(), BoxedError>`; attached to a response and awaited
/// by `PgResponse::run_callback`.
pub trait Callback = Future<Output = Result<(), BoxedError>> + Send;

/// Heap-allocated, pinned [`Callback`], so that differently-typed callbacks share one field type.
pub type BoxedCallback = Pin<Box<dyn Callback>>;
133
/// Response produced for one executed statement, usually assembled through
/// `PgResponseBuilder` and converted with `From`.
pub struct PgResponse<VS> {
    /// Kind of statement this response answers.
    stmt_type: StatementType,
    /// Set only by `into_copy_query_to_stdout`, which also retags `stmt_type` as `COPY`.
    is_copy_query_to_stdout: bool,

    /// Affected-row count. `PgResponseBuilder::empty` starts it at `None` for query statement
    /// types and at `Some(0)` otherwise.
    row_cnt: Option<i32>,
    /// NOTE(review): presumably the format used when reporting `row_cnt` — confirm.
    row_cnt_format: Option<Format>,
    /// Notice messages added through the builder's `notice`.
    notices: Vec<String>,
    /// Row batches attached through the builder's `values`; `None` otherwise.
    values_stream: Option<VS>,
    /// Future awaited (at most once, via `take`) by `run_callback`.
    callback: Option<BoxedCallback>,
    /// Column descriptors, set together with `values_stream`.
    row_desc: Vec<PgFieldDescriptor>,
    /// Parameter status carried with the response; meaning defined in `pg_protocol`.
    status: ParameterStatus,
}
149
/// Builder for [`PgResponse`]; start with `PgResponseBuilder::empty` (or `PgResponse::builder`)
/// and finish with `.into()`.
///
/// Mirrors every `PgResponse` field except `is_copy_query_to_stdout`, which the `From` conversion
/// always sets to `false`.
pub struct PgResponseBuilder<VS> {
    /// Kind of statement the response will answer.
    stmt_type: StatementType,
    /// Affected-row count; `empty` initializes it to `None` for queries, `Some(0)` otherwise.
    row_cnt: Option<i32>,
    /// NOTE(review): presumably the format used when reporting `row_cnt` — confirm.
    row_cnt_format: Option<Format>,
    /// Notices accumulated in call order by `notice`.
    notices: Vec<String>,
    /// Row batches set by `values`.
    values_stream: Option<VS>,
    /// Boxed future set by `callback`.
    callback: Option<BoxedCallback>,
    /// Column descriptors set by `values`.
    row_desc: Vec<PgFieldDescriptor>,
    /// Parameter status set by `status`; `Default::default()` initially.
    status: ParameterStatus,
}
163
164impl<VS> From<PgResponseBuilder<VS>> for PgResponse<VS> {
165 fn from(builder: PgResponseBuilder<VS>) -> Self {
166 Self {
167 stmt_type: builder.stmt_type,
168 is_copy_query_to_stdout: false, row_cnt: builder.row_cnt,
170 row_cnt_format: builder.row_cnt_format,
171 notices: builder.notices,
172 values_stream: builder.values_stream,
173 callback: builder.callback,
174 row_desc: builder.row_desc,
175 status: builder.status,
176 }
177 }
178}
179
180impl<VS> PgResponseBuilder<VS> {
181 pub fn empty(stmt_type: StatementType) -> Self {
182 let row_cnt = if stmt_type.is_query() { None } else { Some(0) };
183 Self {
184 stmt_type,
185 row_cnt,
186 row_cnt_format: None,
187 notices: vec![],
188 values_stream: None,
189 callback: None,
190 row_desc: vec![],
191 status: Default::default(),
192 }
193 }
194
195 pub fn row_cnt(self, row_cnt: i32) -> Self {
196 Self {
197 row_cnt: Some(row_cnt),
198 ..self
199 }
200 }
201
202 pub fn row_cnt_opt(self, row_cnt: Option<i32>) -> Self {
203 Self { row_cnt, ..self }
204 }
205
206 pub fn row_cnt_format_opt(self, row_cnt_format: Option<Format>) -> Self {
207 Self {
208 row_cnt_format,
209 ..self
210 }
211 }
212
213 pub fn values(self, values_stream: VS, row_desc: Vec<PgFieldDescriptor>) -> Self {
214 Self {
215 values_stream: Some(values_stream),
216 row_desc,
217 ..self
218 }
219 }
220
221 pub fn callback(self, callback: impl Callback + 'static) -> Self {
222 Self {
223 callback: Some(callback.boxed()),
224 ..self
225 }
226 }
227
228 pub fn notice(self, notice: impl ToString) -> Self {
229 let mut notices = self.notices;
230 notices.push(notice.to_string());
231 Self { notices, ..self }
232 }
233
234 pub fn status(self, status: ParameterStatus) -> Self {
235 Self { status, ..self }
236 }
237}
238
239impl<VS> std::fmt::Debug for PgResponse<VS> {
240 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
241 f.debug_struct("PgResponse")
242 .field("stmt_type", &self.stmt_type)
243 .field("row_cnt", &self.row_cnt)
244 .field("notices", &self.notices)
245 .field("row_desc", &self.row_desc)
246 .finish()
247 }
248}
249
250impl StatementType {
251 pub fn infer_from_statement(stmt: &Statement) -> Result<Self, String> {
252 match stmt {
253 Statement::Query(_) => Ok(StatementType::SELECT),
254 Statement::Insert { returning, .. } => {
255 if returning.is_empty() {
256 Ok(StatementType::INSERT)
257 } else {
258 Ok(StatementType::INSERT_RETURNING)
259 }
260 }
261 Statement::Delete { returning, .. } => {
262 if returning.is_empty() {
263 Ok(StatementType::DELETE)
264 } else {
265 Ok(StatementType::DELETE_RETURNING)
266 }
267 }
268 Statement::Update { returning, .. } => {
269 if returning.is_empty() {
270 Ok(StatementType::UPDATE)
271 } else {
272 Ok(StatementType::UPDATE_RETURNING)
273 }
274 }
275 Statement::Copy { .. } => Ok(StatementType::COPY),
276 Statement::CreateTable { .. } => Ok(StatementType::CREATE_TABLE),
277 Statement::CreateIndex { .. } => Ok(StatementType::CREATE_INDEX),
278 Statement::CreateSchema { .. } => Ok(StatementType::CREATE_SCHEMA),
279 Statement::CreateSource { .. } => Ok(StatementType::CREATE_SOURCE),
280 Statement::CreateSink { .. } => Ok(StatementType::CREATE_SINK),
281 Statement::CreateFunction { .. } => Ok(StatementType::CREATE_FUNCTION),
282 Statement::CreateDatabase { .. } => Ok(StatementType::CREATE_DATABASE),
283 Statement::CreateUser { .. } => Ok(StatementType::CREATE_USER),
284 Statement::CreateView { materialized, .. } => {
285 if *materialized {
286 Ok(StatementType::CREATE_MATERIALIZED_VIEW)
287 } else {
288 Ok(StatementType::CREATE_VIEW)
289 }
290 }
291 Statement::AlterTable { .. } => Ok(StatementType::ALTER_TABLE),
292 Statement::AlterSystem { .. } => Ok(StatementType::ALTER_SYSTEM),
293 Statement::DropFunction { .. } => Ok(StatementType::DROP_FUNCTION),
294 Statement::Discard(..) => Ok(StatementType::DISCARD),
295 Statement::SetVariable { .. } => Ok(StatementType::SET_VARIABLE),
296 Statement::ShowVariable { .. } => Ok(StatementType::SHOW_VARIABLE),
297 Statement::StartTransaction { .. } => Ok(StatementType::START_TRANSACTION),
298 Statement::Begin { .. } => Ok(StatementType::BEGIN),
299 Statement::Abort => Ok(StatementType::ABORT),
300 Statement::Commit { .. } => Ok(StatementType::COMMIT),
301 Statement::Rollback { .. } => Ok(StatementType::ROLLBACK),
302 Statement::Grant { .. } => Ok(StatementType::GRANT_PRIVILEGE),
303 Statement::Revoke { .. } => Ok(StatementType::REVOKE_PRIVILEGE),
304 Statement::Describe { .. } => Ok(StatementType::DESCRIBE),
305 Statement::ShowCreateObject { .. } | Statement::ShowObjects { .. } => {
306 Ok(StatementType::SHOW_COMMAND)
307 }
308 Statement::Drop(stmt) => match stmt.object_type {
309 risingwave_sqlparser::ast::ObjectType::Table => Ok(StatementType::DROP_TABLE),
310 risingwave_sqlparser::ast::ObjectType::View => Ok(StatementType::DROP_VIEW),
311 risingwave_sqlparser::ast::ObjectType::MaterializedView => {
312 Ok(StatementType::DROP_MATERIALIZED_VIEW)
313 }
314 risingwave_sqlparser::ast::ObjectType::Index => Ok(StatementType::DROP_INDEX),
315 risingwave_sqlparser::ast::ObjectType::Schema => Ok(StatementType::DROP_SCHEMA),
316 risingwave_sqlparser::ast::ObjectType::Source => Ok(StatementType::DROP_SOURCE),
317 risingwave_sqlparser::ast::ObjectType::Sink => Ok(StatementType::DROP_SINK),
318 risingwave_sqlparser::ast::ObjectType::Database => Ok(StatementType::DROP_DATABASE),
319 risingwave_sqlparser::ast::ObjectType::User => Ok(StatementType::DROP_USER),
320 risingwave_sqlparser::ast::ObjectType::Connection => {
321 Ok(StatementType::DROP_CONNECTION)
322 }
323 risingwave_sqlparser::ast::ObjectType::Secret => Ok(StatementType::DROP_SECRET),
324 risingwave_sqlparser::ast::ObjectType::Subscription => {
325 Ok(StatementType::DROP_SUBSCRIPTION)
326 }
327 },
328 Statement::Explain { .. } => Ok(StatementType::EXPLAIN),
329 Statement::DeclareCursor { .. } => Ok(StatementType::DECLARE_CURSOR),
330 Statement::FetchCursor { .. } => Ok(StatementType::FETCH_CURSOR),
331 Statement::CloseCursor { .. } => Ok(StatementType::CLOSE_CURSOR),
332 Statement::Flush => Ok(StatementType::FLUSH),
333 Statement::Wait => Ok(StatementType::WAIT),
334 Statement::Use { .. } => Ok(StatementType::USE),
335 Statement::Vacuum { .. } => Ok(StatementType::VACUUM),
336 _ => Err("unsupported statement type".to_owned()),
337 }
338 }
339
340 pub fn is_command(&self) -> bool {
341 matches!(
342 self,
343 StatementType::INSERT
344 | StatementType::DELETE
345 | StatementType::UPDATE
346 | StatementType::MOVE
347 | StatementType::COPY
348 | StatementType::FETCH
349 | StatementType::SELECT
350 | StatementType::INSERT_RETURNING
351 | StatementType::DELETE_RETURNING
352 | StatementType::UPDATE_RETURNING
353 )
354 }
355
356 pub fn is_dml(&self) -> bool {
357 matches!(
358 self,
359 StatementType::INSERT
360 | StatementType::DELETE
361 | StatementType::UPDATE
362 | StatementType::INSERT_RETURNING
363 | StatementType::DELETE_RETURNING
364 | StatementType::UPDATE_RETURNING
365 )
366 }
367
368 pub fn is_query(&self) -> bool {
369 matches!(
370 self,
371 StatementType::SELECT
372 | StatementType::EXPLAIN
373 | StatementType::SHOW_COMMAND
374 | StatementType::SHOW_VARIABLE
375 | StatementType::DESCRIBE
376 | StatementType::INSERT_RETURNING
377 | StatementType::DELETE_RETURNING
378 | StatementType::UPDATE_RETURNING
379 | StatementType::CANCEL_COMMAND
380 | StatementType::FETCH_CURSOR
381 )
382 }
383
384 pub fn is_returning(&self) -> bool {
385 matches!(
386 self,
387 StatementType::INSERT_RETURNING
388 | StatementType::DELETE_RETURNING
389 | StatementType::UPDATE_RETURNING
390 )
391 }
392}
393
394impl<VS> PgResponse<VS>
395where
396 VS: ValuesStream,
397{
398 pub fn builder(stmt_type: StatementType) -> PgResponseBuilder<VS> {
399 PgResponseBuilder::empty(stmt_type)
400 }
401
402 pub fn empty_result(stmt_type: StatementType) -> Self {
403 PgResponseBuilder::empty(stmt_type).into()
404 }
405
406 pub fn into_copy_query_to_stdout(mut self) -> Self {
407 self.is_copy_query_to_stdout = true;
408 self.stmt_type = StatementType::COPY;
409 self
410 }
411
412 pub fn stmt_type(&self) -> StatementType {
413 self.stmt_type
414 }
415
416 pub fn notices(&self) -> &[String] {
417 &self.notices
418 }
419
420 pub fn status(&self) -> &ParameterStatus {
421 &self.status
422 }
423
424 pub fn affected_rows_cnt(&self) -> Option<i32> {
425 self.row_cnt
426 }
427
428 pub fn row_cnt_format(&self) -> Option<Format> {
429 self.row_cnt_format
430 }
431
432 pub fn is_query(&self) -> bool {
433 self.stmt_type.is_query()
434 }
435
436 pub fn is_empty(&self) -> bool {
437 self.stmt_type == StatementType::EMPTY
438 }
439
440 pub fn is_copy_query_to_stdout(&self) -> bool {
441 self.is_copy_query_to_stdout
442 }
443
444 pub fn row_desc(&self) -> &[PgFieldDescriptor] {
445 &self.row_desc
446 }
447
448 pub fn values_stream(&mut self) -> &mut VS {
449 self.values_stream.as_mut().expect("no values stream")
450 }
451
452 pub async fn run_callback(&mut self) -> Result<(), PsqlError> {
457 if let Some(values_stream) = &mut self.values_stream {
459 assert!(values_stream.next().await.is_none());
460 }
461
462 if let Some(callback) = self.callback.take() {
463 callback.await.map_err(PsqlError::SimpleQueryError)?;
464 }
465 Ok(())
466 }
467}