1use std::fmt::Formatter;
16use std::pin::Pin;
17
18use futures::{Future, FutureExt, Stream, StreamExt};
19use risingwave_sqlparser::ast::Statement;
20
21use crate::error::PsqlError;
22use crate::pg_field_descriptor::PgFieldDescriptor;
23use crate::pg_protocol::ParameterStatus;
24use crate::pg_server::BoxedError;
25use crate::types::{Format, Row};
26
27pub type RowSet = Vec<Row>;
28pub type RowSetResult = Result<RowSet, BoxedError>;
29
30pub trait ValuesStream = Stream<Item = RowSetResult> + Unpin + Send;
31
/// Classification of a SQL statement, used when building and reporting a [`PgResponse`].
///
/// Variant names are written in SQL-style `UPPER_SNAKE_CASE` because the `Display`
/// implementation renders them verbatim (it reuses the derived `Debug` text).
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[expect(non_camel_case_types, clippy::upper_case_acronyms)]
pub enum StatementType {
    // DML (with and without RETURNING) and row-producing commands.
    INSERT,
    INSERT_RETURNING,
    DELETE,
    DELETE_RETURNING,
    UPDATE,
    UPDATE_RETURNING,
    SELECT,
    MOVE,
    FETCH,
    COPY,
    EXPLAIN,
    CLOSE_CURSOR,

    // CREATE statements.
    CREATE_TABLE,
    CREATE_MATERIALIZED_VIEW,
    CREATE_VIEW,
    CREATE_SOURCE,
    CREATE_SINK,
    CREATE_SUBSCRIPTION,
    CREATE_DATABASE,
    CREATE_SCHEMA,
    CREATE_USER,
    CREATE_INDEX,
    CREATE_AGGREGATE,
    CREATE_FUNCTION,
    CREATE_CONNECTION,
    CREATE_SECRET,

    COMMENT,
    DECLARE_CURSOR,
    DESCRIBE,
    GRANT_PRIVILEGE,
    DISCARD,

    // DROP statements.
    DROP_TABLE,
    DROP_MATERIALIZED_VIEW,
    DROP_VIEW,
    DROP_INDEX,
    DROP_FUNCTION,
    DROP_AGGREGATE,
    DROP_SOURCE,
    DROP_SINK,
    DROP_SUBSCRIPTION,
    DROP_SCHEMA,
    DROP_DATABASE,
    DROP_USER,
    DROP_CONNECTION,
    DROP_SECRET,

    // ALTER statements.
    ALTER_DATABASE,
    ALTER_DEFAULT_PRIVILEGES,
    ALTER_SCHEMA,
    ALTER_INDEX,
    ALTER_VIEW,
    ALTER_TABLE,
    ALTER_MATERIALIZED_VIEW,
    ALTER_SINK,
    ALTER_SUBSCRIPTION,
    ALTER_SOURCE,
    ALTER_FUNCTION,
    ALTER_CONNECTION,
    ALTER_SYSTEM,
    ALTER_SECRET,
    ALTER_FRAGMENT,

    REVOKE_PRIVILEGE,
    ORDER_BY,

    // Session variables, transactions and miscellaneous commands.
    SET_VARIABLE,
    SHOW_VARIABLE,
    SHOW_COMMAND,
    START_TRANSACTION,
    UPDATE_USER,
    ABORT,
    FLUSH,
    REFRESH_TABLE,
    OTHER,
    EMPTY,
    BEGIN,
    COMMIT,
    ROLLBACK,
    SET_TRANSACTION,
    CANCEL_COMMAND,
    FETCH_CURSOR,
    WAIT,
    KILL,
    BACKUP,
    DELETE_META_SNAPSHOTS,
    RECOVER,
    USE,
    PREPARE,
    DEALLOCATE,
    VACUUM,
}
126
127impl std::fmt::Display for StatementType {
128 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
129 write!(f, "{:?}", self)
130 }
131}
132
133pub trait Callback = Future<Output = Result<(), BoxedError>> + Send;
134
135pub type BoxedCallback = Pin<Box<dyn Callback>>;
136
137pub struct PgResponse<VS> {
138 stmt_type: StatementType,
139 is_copy_query_to_stdout: bool,
140
141 row_cnt: Option<i32>,
144 row_cnt_format: Option<Format>,
146 notices: Vec<String>,
147 values_stream: Option<VS>,
148 callback: Option<BoxedCallback>,
149 row_desc: Vec<PgFieldDescriptor>,
150 status: ParameterStatus,
151}
152
153pub struct PgResponseBuilder<VS> {
154 stmt_type: StatementType,
155 row_cnt: Option<i32>,
158 row_cnt_format: Option<Format>,
160 notices: Vec<String>,
161 values_stream: Option<VS>,
162 callback: Option<BoxedCallback>,
163 row_desc: Vec<PgFieldDescriptor>,
164 status: ParameterStatus,
165}
166
167impl<VS> From<PgResponseBuilder<VS>> for PgResponse<VS> {
168 fn from(builder: PgResponseBuilder<VS>) -> Self {
169 Self {
170 stmt_type: builder.stmt_type,
171 is_copy_query_to_stdout: false, row_cnt: builder.row_cnt,
173 row_cnt_format: builder.row_cnt_format,
174 notices: builder.notices,
175 values_stream: builder.values_stream,
176 callback: builder.callback,
177 row_desc: builder.row_desc,
178 status: builder.status,
179 }
180 }
181}
182
183impl<VS> PgResponseBuilder<VS> {
184 pub fn empty(stmt_type: StatementType) -> Self {
185 let row_cnt = if stmt_type.is_query() { None } else { Some(0) };
186 Self {
187 stmt_type,
188 row_cnt,
189 row_cnt_format: None,
190 notices: vec![],
191 values_stream: None,
192 callback: None,
193 row_desc: vec![],
194 status: Default::default(),
195 }
196 }
197
198 pub fn row_cnt(self, row_cnt: i32) -> Self {
199 Self {
200 row_cnt: Some(row_cnt),
201 ..self
202 }
203 }
204
205 pub fn row_cnt_opt(self, row_cnt: Option<i32>) -> Self {
206 Self { row_cnt, ..self }
207 }
208
209 pub fn row_cnt_format_opt(self, row_cnt_format: Option<Format>) -> Self {
210 Self {
211 row_cnt_format,
212 ..self
213 }
214 }
215
216 pub fn values(self, values_stream: VS, row_desc: Vec<PgFieldDescriptor>) -> Self {
217 Self {
218 values_stream: Some(values_stream),
219 row_desc,
220 ..self
221 }
222 }
223
224 pub fn callback(self, callback: impl Callback + 'static) -> Self {
225 Self {
226 callback: Some(callback.boxed()),
227 ..self
228 }
229 }
230
231 pub fn notice(self, notice: impl ToString) -> Self {
232 let mut notices = self.notices;
233 notices.push(notice.to_string());
234 Self { notices, ..self }
235 }
236
237 pub fn status(self, status: ParameterStatus) -> Self {
238 Self { status, ..self }
239 }
240}
241
242impl<VS> std::fmt::Debug for PgResponse<VS> {
243 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
244 f.debug_struct("PgResponse")
245 .field("stmt_type", &self.stmt_type)
246 .field("row_cnt", &self.row_cnt)
247 .field("notices", &self.notices)
248 .field("row_desc", &self.row_desc)
249 .finish()
250 }
251}
252
253impl StatementType {
254 pub fn infer_from_statement(stmt: &Statement) -> Result<Self, String> {
255 match stmt {
256 Statement::Query(_) => Ok(StatementType::SELECT),
257 Statement::Insert { returning, .. } => {
258 if returning.is_empty() {
259 Ok(StatementType::INSERT)
260 } else {
261 Ok(StatementType::INSERT_RETURNING)
262 }
263 }
264 Statement::Delete { returning, .. } => {
265 if returning.is_empty() {
266 Ok(StatementType::DELETE)
267 } else {
268 Ok(StatementType::DELETE_RETURNING)
269 }
270 }
271 Statement::Update { returning, .. } => {
272 if returning.is_empty() {
273 Ok(StatementType::UPDATE)
274 } else {
275 Ok(StatementType::UPDATE_RETURNING)
276 }
277 }
278 Statement::Copy { .. } => Ok(StatementType::COPY),
279 Statement::CreateTable { .. } => Ok(StatementType::CREATE_TABLE),
280 Statement::CreateIndex { .. } => Ok(StatementType::CREATE_INDEX),
281 Statement::CreateSchema { .. } => Ok(StatementType::CREATE_SCHEMA),
282 Statement::CreateSource { .. } => Ok(StatementType::CREATE_SOURCE),
283 Statement::CreateSink { .. } => Ok(StatementType::CREATE_SINK),
284 Statement::CreateFunction { .. } => Ok(StatementType::CREATE_FUNCTION),
285 Statement::CreateDatabase { .. } => Ok(StatementType::CREATE_DATABASE),
286 Statement::CreateUser { .. } => Ok(StatementType::CREATE_USER),
287 Statement::CreateView { materialized, .. } => {
288 if *materialized {
289 Ok(StatementType::CREATE_MATERIALIZED_VIEW)
290 } else {
291 Ok(StatementType::CREATE_VIEW)
292 }
293 }
294 Statement::AlterTable { .. } => Ok(StatementType::ALTER_TABLE),
295 Statement::AlterSystem { .. } => Ok(StatementType::ALTER_SYSTEM),
296 Statement::AlterFragment { .. } => Ok(StatementType::ALTER_FRAGMENT),
297 Statement::DropFunction { .. } => Ok(StatementType::DROP_FUNCTION),
298 Statement::Discard(..) => Ok(StatementType::DISCARD),
299 Statement::SetVariable { .. } => Ok(StatementType::SET_VARIABLE),
300 Statement::ShowVariable { .. } => Ok(StatementType::SHOW_VARIABLE),
301 Statement::StartTransaction { .. } => Ok(StatementType::START_TRANSACTION),
302 Statement::Begin { .. } => Ok(StatementType::BEGIN),
303 Statement::Abort => Ok(StatementType::ABORT),
304 Statement::Commit { .. } => Ok(StatementType::COMMIT),
305 Statement::Rollback { .. } => Ok(StatementType::ROLLBACK),
306 Statement::Grant { .. } => Ok(StatementType::GRANT_PRIVILEGE),
307 Statement::Revoke { .. } => Ok(StatementType::REVOKE_PRIVILEGE),
308 Statement::Describe { .. } => Ok(StatementType::DESCRIBE),
309 Statement::ShowCreateObject { .. } | Statement::ShowObjects { .. } => {
310 Ok(StatementType::SHOW_COMMAND)
311 }
312 Statement::Drop(stmt) => match stmt.object_type {
313 risingwave_sqlparser::ast::ObjectType::Table => Ok(StatementType::DROP_TABLE),
314 risingwave_sqlparser::ast::ObjectType::View => Ok(StatementType::DROP_VIEW),
315 risingwave_sqlparser::ast::ObjectType::MaterializedView => {
316 Ok(StatementType::DROP_MATERIALIZED_VIEW)
317 }
318 risingwave_sqlparser::ast::ObjectType::Index => Ok(StatementType::DROP_INDEX),
319 risingwave_sqlparser::ast::ObjectType::Schema => Ok(StatementType::DROP_SCHEMA),
320 risingwave_sqlparser::ast::ObjectType::Source => Ok(StatementType::DROP_SOURCE),
321 risingwave_sqlparser::ast::ObjectType::Sink => Ok(StatementType::DROP_SINK),
322 risingwave_sqlparser::ast::ObjectType::Database => Ok(StatementType::DROP_DATABASE),
323 risingwave_sqlparser::ast::ObjectType::User => Ok(StatementType::DROP_USER),
324 risingwave_sqlparser::ast::ObjectType::Connection => {
325 Ok(StatementType::DROP_CONNECTION)
326 }
327 risingwave_sqlparser::ast::ObjectType::Secret => Ok(StatementType::DROP_SECRET),
328 risingwave_sqlparser::ast::ObjectType::Subscription => {
329 Ok(StatementType::DROP_SUBSCRIPTION)
330 }
331 },
332 Statement::Explain { .. } => Ok(StatementType::EXPLAIN),
333 Statement::DeclareCursor { .. } => Ok(StatementType::DECLARE_CURSOR),
334 Statement::FetchCursor { .. } => Ok(StatementType::FETCH_CURSOR),
335 Statement::CloseCursor { .. } => Ok(StatementType::CLOSE_CURSOR),
336 Statement::Flush => Ok(StatementType::FLUSH),
337 Statement::Wait => Ok(StatementType::WAIT),
338 Statement::Backup => Ok(StatementType::BACKUP),
339 Statement::DeleteMetaSnapshots { .. } => Ok(StatementType::DELETE_META_SNAPSHOTS),
340 Statement::Recover => Ok(StatementType::RECOVER),
341 Statement::Use { .. } => Ok(StatementType::USE),
342 Statement::Vacuum { .. } => Ok(StatementType::VACUUM),
343 _ => Err("unsupported statement type".to_owned()),
344 }
345 }
346
347 pub fn is_command(&self) -> bool {
348 matches!(
349 self,
350 StatementType::INSERT
351 | StatementType::DELETE
352 | StatementType::UPDATE
353 | StatementType::MOVE
354 | StatementType::COPY
355 | StatementType::FETCH
356 | StatementType::SELECT
357 | StatementType::INSERT_RETURNING
358 | StatementType::DELETE_RETURNING
359 | StatementType::UPDATE_RETURNING
360 )
361 }
362
363 pub fn is_dml(&self) -> bool {
364 matches!(
365 self,
366 StatementType::INSERT
367 | StatementType::DELETE
368 | StatementType::UPDATE
369 | StatementType::INSERT_RETURNING
370 | StatementType::DELETE_RETURNING
371 | StatementType::UPDATE_RETURNING
372 )
373 }
374
375 pub fn is_query(&self) -> bool {
376 matches!(
377 self,
378 StatementType::SELECT
379 | StatementType::EXPLAIN
380 | StatementType::SHOW_COMMAND
381 | StatementType::SHOW_VARIABLE
382 | StatementType::DESCRIBE
383 | StatementType::INSERT_RETURNING
384 | StatementType::DELETE_RETURNING
385 | StatementType::UPDATE_RETURNING
386 | StatementType::CANCEL_COMMAND
387 | StatementType::BACKUP
388 | StatementType::FETCH_CURSOR
389 )
390 }
391
392 pub fn is_returning(&self) -> bool {
393 matches!(
394 self,
395 StatementType::INSERT_RETURNING
396 | StatementType::DELETE_RETURNING
397 | StatementType::UPDATE_RETURNING
398 )
399 }
400}
401
402impl<VS> PgResponse<VS>
403where
404 VS: ValuesStream,
405{
406 pub fn builder(stmt_type: StatementType) -> PgResponseBuilder<VS> {
407 PgResponseBuilder::empty(stmt_type)
408 }
409
410 pub fn empty_result(stmt_type: StatementType) -> Self {
411 PgResponseBuilder::empty(stmt_type).into()
412 }
413
414 pub fn into_copy_query_to_stdout(mut self) -> Self {
415 self.is_copy_query_to_stdout = true;
416 self.stmt_type = StatementType::COPY;
417 self
418 }
419
420 pub fn stmt_type(&self) -> StatementType {
421 self.stmt_type
422 }
423
424 pub fn notices(&self) -> &[String] {
425 &self.notices
426 }
427
428 pub fn status(&self) -> &ParameterStatus {
429 &self.status
430 }
431
432 pub fn affected_rows_cnt(&self) -> Option<i32> {
433 self.row_cnt
434 }
435
436 pub fn row_cnt_format(&self) -> Option<Format> {
437 self.row_cnt_format
438 }
439
440 pub fn is_query(&self) -> bool {
441 self.stmt_type.is_query()
442 }
443
444 pub fn is_empty(&self) -> bool {
445 self.stmt_type == StatementType::EMPTY
446 }
447
448 pub fn is_copy_query_to_stdout(&self) -> bool {
449 self.is_copy_query_to_stdout
450 }
451
452 pub fn row_desc(&self) -> &[PgFieldDescriptor] {
453 &self.row_desc
454 }
455
456 pub fn values_stream(&mut self) -> &mut VS {
457 self.values_stream.as_mut().expect("no values stream")
458 }
459
460 pub async fn run_callback(&mut self) -> Result<(), PsqlError> {
465 if let Some(values_stream) = &mut self.values_stream {
467 assert!(values_stream.next().await.is_none());
468 }
469
470 if let Some(callback) = self.callback.take() {
471 callback.await.map_err(PsqlError::SimpleQueryError)?;
472 }
473 Ok(())
474 }
475}