1use std::fmt::Formatter;
16use std::pin::Pin;
17
18use futures::{Future, FutureExt, Stream, StreamExt};
19use risingwave_sqlparser::ast::Statement;
20
21use crate::error::PsqlError;
22use crate::pg_field_descriptor::PgFieldDescriptor;
23use crate::pg_protocol::ParameterStatus;
24use crate::pg_server::BoxedError;
25use crate::types::{Format, Row};
26
27pub type RowSet = Vec<Row>;
28pub type RowSetResult = Result<RowSet, BoxedError>;
29
30pub trait ValuesStream = Stream<Item = RowSetResult> + Unpin + Send;
31
/// Kind of SQL statement that a response answers.
///
/// Variant names are deliberately SCREAMING_SNAKE_CASE. The `Display` impl for this type
/// prints the same text as `Debug`, i.e. the variant name verbatim.
///
/// NOTE(review): the variant order fixes the discriminant values. Append new variants rather
/// than reordering, in case any caller casts with `as`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[expect(non_camel_case_types, clippy::upper_case_acronyms)]
pub enum StatementType {
    // DML (plain and `RETURNING` forms) and row-fetching statements.
    INSERT,
    INSERT_RETURNING,
    DELETE,
    DELETE_RETURNING,
    UPDATE,
    UPDATE_RETURNING,
    SELECT,
    MOVE,
    FETCH,
    COPY,
    EXPLAIN,
    CLOSE_CURSOR,

    // CREATE statements.
    CREATE_TABLE,
    CREATE_MATERIALIZED_VIEW,
    CREATE_VIEW,
    CREATE_SOURCE,
    CREATE_SINK,
    CREATE_SUBSCRIPTION,
    CREATE_DATABASE,
    CREATE_SCHEMA,
    CREATE_USER,
    CREATE_INDEX,
    CREATE_AGGREGATE,
    CREATE_FUNCTION,
    CREATE_CONNECTION,
    CREATE_SECRET,

    // Miscellaneous statements.
    COMMENT,
    DECLARE_CURSOR,
    DESCRIBE,
    GRANT_PRIVILEGE,
    DISCARD,

    // DROP statements.
    DROP_TABLE,
    DROP_MATERIALIZED_VIEW,
    DROP_VIEW,
    DROP_INDEX,
    DROP_FUNCTION,
    DROP_AGGREGATE,
    DROP_SOURCE,
    DROP_SINK,
    DROP_SUBSCRIPTION,
    DROP_SCHEMA,
    DROP_DATABASE,
    DROP_USER,
    DROP_CONNECTION,
    DROP_SECRET,

    // ALTER statements.
    ALTER_DATABASE,
    ALTER_DEFAULT_PRIVILEGES,
    ALTER_SCHEMA,
    ALTER_INDEX,
    ALTER_VIEW,
    ALTER_TABLE,
    ALTER_MATERIALIZED_VIEW,
    ALTER_SINK,
    ALTER_SUBSCRIPTION,
    ALTER_SOURCE,
    ALTER_FUNCTION,
    ALTER_CONNECTION,
    ALTER_SYSTEM,
    ALTER_SECRET,
    ALTER_FRAGMENT,
    ALTER_COMPACTION_GROUP,
    REVOKE_PRIVILEGE,
    ORDER_BY,

    // Session variables, SHOW, and housekeeping.
    SET_VARIABLE,
    SHOW_VARIABLE,
    SHOW_COMMAND,
    START_TRANSACTION,
    UPDATE_USER,
    ABORT,
    FLUSH,
    REFRESH_TABLE,
    OTHER,
    EMPTY,

    // Transaction control and utility commands.
    BEGIN,
    COMMIT,
    ROLLBACK,
    SET_TRANSACTION,
    CANCEL_COMMAND,
    FETCH_CURSOR,
    WAIT,
    KILL,
    BACKUP,
    DELETE_META_SNAPSHOTS,
    RECOVER,
    USE,
    PREPARE,
    DEALLOCATE,
    VACUUM,
}
127
128impl std::fmt::Display for StatementType {
129 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
130 write!(f, "{:?}", self)
131 }
132}
133
134pub trait Callback = Future<Output = Result<(), BoxedError>> + Send;
135
136pub type BoxedCallback = Pin<Box<dyn Callback>>;
137
138pub struct PgResponse<VS> {
139 stmt_type: StatementType,
140 is_copy_query_to_stdout: bool,
141
142 row_cnt: Option<i32>,
145 row_cnt_format: Option<Format>,
147 notices: Vec<String>,
148 values_stream: Option<VS>,
149 callback: Option<BoxedCallback>,
150 row_desc: Vec<PgFieldDescriptor>,
151 status: ParameterStatus,
152}
153
154pub struct PgResponseBuilder<VS> {
155 stmt_type: StatementType,
156 row_cnt: Option<i32>,
159 row_cnt_format: Option<Format>,
161 notices: Vec<String>,
162 values_stream: Option<VS>,
163 callback: Option<BoxedCallback>,
164 row_desc: Vec<PgFieldDescriptor>,
165 status: ParameterStatus,
166}
167
168impl<VS> From<PgResponseBuilder<VS>> for PgResponse<VS> {
169 fn from(builder: PgResponseBuilder<VS>) -> Self {
170 Self {
171 stmt_type: builder.stmt_type,
172 is_copy_query_to_stdout: false, row_cnt: builder.row_cnt,
174 row_cnt_format: builder.row_cnt_format,
175 notices: builder.notices,
176 values_stream: builder.values_stream,
177 callback: builder.callback,
178 row_desc: builder.row_desc,
179 status: builder.status,
180 }
181 }
182}
183
184impl<VS> PgResponseBuilder<VS> {
185 pub fn empty(stmt_type: StatementType) -> Self {
186 let row_cnt = if stmt_type.is_query() { None } else { Some(0) };
187 Self {
188 stmt_type,
189 row_cnt,
190 row_cnt_format: None,
191 notices: vec![],
192 values_stream: None,
193 callback: None,
194 row_desc: vec![],
195 status: Default::default(),
196 }
197 }
198
199 pub fn row_cnt(self, row_cnt: i32) -> Self {
200 Self {
201 row_cnt: Some(row_cnt),
202 ..self
203 }
204 }
205
206 pub fn row_cnt_opt(self, row_cnt: Option<i32>) -> Self {
207 Self { row_cnt, ..self }
208 }
209
210 pub fn row_cnt_format_opt(self, row_cnt_format: Option<Format>) -> Self {
211 Self {
212 row_cnt_format,
213 ..self
214 }
215 }
216
217 pub fn values(self, values_stream: VS, row_desc: Vec<PgFieldDescriptor>) -> Self {
218 Self {
219 values_stream: Some(values_stream),
220 row_desc,
221 ..self
222 }
223 }
224
225 pub fn callback(self, callback: impl Callback + 'static) -> Self {
226 Self {
227 callback: Some(callback.boxed()),
228 ..self
229 }
230 }
231
232 pub fn notice(self, notice: impl ToString) -> Self {
233 let mut notices = self.notices;
234 notices.push(notice.to_string());
235 Self { notices, ..self }
236 }
237
238 pub fn status(self, status: ParameterStatus) -> Self {
239 Self { status, ..self }
240 }
241}
242
243impl<VS> std::fmt::Debug for PgResponse<VS> {
244 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
245 f.debug_struct("PgResponse")
246 .field("stmt_type", &self.stmt_type)
247 .field("row_cnt", &self.row_cnt)
248 .field("notices", &self.notices)
249 .field("row_desc", &self.row_desc)
250 .finish()
251 }
252}
253
254impl StatementType {
255 pub fn infer_from_statement(stmt: &Statement) -> Result<Self, String> {
256 match stmt {
257 Statement::Query(_) => Ok(StatementType::SELECT),
258 Statement::Insert { returning, .. } => {
259 if returning.is_empty() {
260 Ok(StatementType::INSERT)
261 } else {
262 Ok(StatementType::INSERT_RETURNING)
263 }
264 }
265 Statement::Delete { returning, .. } => {
266 if returning.is_empty() {
267 Ok(StatementType::DELETE)
268 } else {
269 Ok(StatementType::DELETE_RETURNING)
270 }
271 }
272 Statement::Update { returning, .. } => {
273 if returning.is_empty() {
274 Ok(StatementType::UPDATE)
275 } else {
276 Ok(StatementType::UPDATE_RETURNING)
277 }
278 }
279 Statement::Copy { .. } => Ok(StatementType::COPY),
280 Statement::CreateTable { .. } => Ok(StatementType::CREATE_TABLE),
281 Statement::CreateIndex { .. } => Ok(StatementType::CREATE_INDEX),
282 Statement::CreateSchema { .. } => Ok(StatementType::CREATE_SCHEMA),
283 Statement::CreateSource { .. } => Ok(StatementType::CREATE_SOURCE),
284 Statement::CreateSink { .. } => Ok(StatementType::CREATE_SINK),
285 Statement::CreateFunction { .. } => Ok(StatementType::CREATE_FUNCTION),
286 Statement::CreateDatabase { .. } => Ok(StatementType::CREATE_DATABASE),
287 Statement::CreateUser { .. } => Ok(StatementType::CREATE_USER),
288 Statement::CreateView { materialized, .. } => {
289 if *materialized {
290 Ok(StatementType::CREATE_MATERIALIZED_VIEW)
291 } else {
292 Ok(StatementType::CREATE_VIEW)
293 }
294 }
295 Statement::AlterTable { .. } => Ok(StatementType::ALTER_TABLE),
296 Statement::AlterSystem { .. } => Ok(StatementType::ALTER_SYSTEM),
297 Statement::AlterFragment { .. } => Ok(StatementType::ALTER_FRAGMENT),
298 Statement::AlterCompactionGroup { .. } => Ok(StatementType::ALTER_COMPACTION_GROUP),
299 Statement::DropFunction { .. } => Ok(StatementType::DROP_FUNCTION),
300 Statement::Discard(..) => Ok(StatementType::DISCARD),
301 Statement::SetVariable { .. } => Ok(StatementType::SET_VARIABLE),
302 Statement::ShowVariable { .. } => Ok(StatementType::SHOW_VARIABLE),
303 Statement::StartTransaction { .. } => Ok(StatementType::START_TRANSACTION),
304 Statement::Begin { .. } => Ok(StatementType::BEGIN),
305 Statement::Abort => Ok(StatementType::ABORT),
306 Statement::Commit { .. } => Ok(StatementType::COMMIT),
307 Statement::Rollback { .. } => Ok(StatementType::ROLLBACK),
308 Statement::Grant { .. } => Ok(StatementType::GRANT_PRIVILEGE),
309 Statement::Revoke { .. } => Ok(StatementType::REVOKE_PRIVILEGE),
310 Statement::Describe { .. } => Ok(StatementType::DESCRIBE),
311 Statement::ShowCreateObject { .. } | Statement::ShowObjects { .. } => {
312 Ok(StatementType::SHOW_COMMAND)
313 }
314 Statement::Drop(stmt) => match stmt.object_type {
315 risingwave_sqlparser::ast::ObjectType::Table => Ok(StatementType::DROP_TABLE),
316 risingwave_sqlparser::ast::ObjectType::View => Ok(StatementType::DROP_VIEW),
317 risingwave_sqlparser::ast::ObjectType::MaterializedView => {
318 Ok(StatementType::DROP_MATERIALIZED_VIEW)
319 }
320 risingwave_sqlparser::ast::ObjectType::Index => Ok(StatementType::DROP_INDEX),
321 risingwave_sqlparser::ast::ObjectType::Schema => Ok(StatementType::DROP_SCHEMA),
322 risingwave_sqlparser::ast::ObjectType::Source => Ok(StatementType::DROP_SOURCE),
323 risingwave_sqlparser::ast::ObjectType::Sink => Ok(StatementType::DROP_SINK),
324 risingwave_sqlparser::ast::ObjectType::Database => Ok(StatementType::DROP_DATABASE),
325 risingwave_sqlparser::ast::ObjectType::User => Ok(StatementType::DROP_USER),
326 risingwave_sqlparser::ast::ObjectType::Connection => {
327 Ok(StatementType::DROP_CONNECTION)
328 }
329 risingwave_sqlparser::ast::ObjectType::Secret => Ok(StatementType::DROP_SECRET),
330 risingwave_sqlparser::ast::ObjectType::Subscription => {
331 Ok(StatementType::DROP_SUBSCRIPTION)
332 }
333 },
334 Statement::Explain { .. } => Ok(StatementType::EXPLAIN),
335 Statement::DeclareCursor { .. } => Ok(StatementType::DECLARE_CURSOR),
336 Statement::FetchCursor { .. } => Ok(StatementType::FETCH_CURSOR),
337 Statement::CloseCursor { .. } => Ok(StatementType::CLOSE_CURSOR),
338 Statement::Flush => Ok(StatementType::FLUSH),
339 Statement::Wait(_) => Ok(StatementType::WAIT),
340 Statement::Backup => Ok(StatementType::BACKUP),
341 Statement::DeleteMetaSnapshots { .. } => Ok(StatementType::DELETE_META_SNAPSHOTS),
342 Statement::Recover => Ok(StatementType::RECOVER),
343 Statement::Use { .. } => Ok(StatementType::USE),
344 Statement::Vacuum { .. } => Ok(StatementType::VACUUM),
345 _ => Err("unsupported statement type".to_owned()),
346 }
347 }
348
349 pub fn is_command(&self) -> bool {
350 matches!(
351 self,
352 StatementType::INSERT
353 | StatementType::DELETE
354 | StatementType::UPDATE
355 | StatementType::MOVE
356 | StatementType::COPY
357 | StatementType::FETCH
358 | StatementType::SELECT
359 | StatementType::INSERT_RETURNING
360 | StatementType::DELETE_RETURNING
361 | StatementType::UPDATE_RETURNING
362 )
363 }
364
365 pub fn is_dml(&self) -> bool {
366 matches!(
367 self,
368 StatementType::INSERT
369 | StatementType::DELETE
370 | StatementType::UPDATE
371 | StatementType::INSERT_RETURNING
372 | StatementType::DELETE_RETURNING
373 | StatementType::UPDATE_RETURNING
374 )
375 }
376
377 pub fn is_query(&self) -> bool {
378 matches!(
379 self,
380 StatementType::SELECT
381 | StatementType::EXPLAIN
382 | StatementType::SHOW_COMMAND
383 | StatementType::SHOW_VARIABLE
384 | StatementType::DESCRIBE
385 | StatementType::INSERT_RETURNING
386 | StatementType::DELETE_RETURNING
387 | StatementType::UPDATE_RETURNING
388 | StatementType::CANCEL_COMMAND
389 | StatementType::BACKUP
390 | StatementType::FETCH_CURSOR
391 )
392 }
393
394 pub fn is_returning(&self) -> bool {
395 matches!(
396 self,
397 StatementType::INSERT_RETURNING
398 | StatementType::DELETE_RETURNING
399 | StatementType::UPDATE_RETURNING
400 )
401 }
402}
403
404impl<VS> PgResponse<VS>
405where
406 VS: ValuesStream,
407{
408 pub fn builder(stmt_type: StatementType) -> PgResponseBuilder<VS> {
409 PgResponseBuilder::empty(stmt_type)
410 }
411
412 pub fn empty_result(stmt_type: StatementType) -> Self {
413 PgResponseBuilder::empty(stmt_type).into()
414 }
415
416 pub fn into_copy_query_to_stdout(mut self) -> Self {
417 self.is_copy_query_to_stdout = true;
418 self.stmt_type = StatementType::COPY;
419 self
420 }
421
422 pub fn stmt_type(&self) -> StatementType {
423 self.stmt_type
424 }
425
426 pub fn notices(&self) -> &[String] {
427 &self.notices
428 }
429
430 pub fn status(&self) -> &ParameterStatus {
431 &self.status
432 }
433
434 pub fn affected_rows_cnt(&self) -> Option<i32> {
435 self.row_cnt
436 }
437
438 pub fn row_cnt_format(&self) -> Option<Format> {
439 self.row_cnt_format
440 }
441
442 pub fn is_query(&self) -> bool {
443 self.stmt_type.is_query()
444 }
445
446 pub fn is_empty(&self) -> bool {
447 self.stmt_type == StatementType::EMPTY
448 }
449
450 pub fn is_copy_query_to_stdout(&self) -> bool {
451 self.is_copy_query_to_stdout
452 }
453
454 pub fn row_desc(&self) -> &[PgFieldDescriptor] {
455 &self.row_desc
456 }
457
458 pub fn values_stream(&mut self) -> &mut VS {
459 self.values_stream.as_mut().expect("no values stream")
460 }
461
462 pub async fn run_callback(&mut self) -> Result<(), PsqlError> {
467 if let Some(values_stream) = &mut self.values_stream {
469 assert!(values_stream.next().await.is_none());
470 }
471
472 if let Some(callback) = self.callback.take() {
473 callback.await.map_err(PsqlError::SimpleQueryError)?;
474 }
475 Ok(())
476 }
477}