pgwire/
pg_response.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Formatter;
16use std::pin::Pin;
17
18use futures::{Future, FutureExt, Stream, StreamExt};
19use risingwave_sqlparser::ast::Statement;
20
21use crate::error::PsqlError;
22use crate::pg_field_descriptor::PgFieldDescriptor;
23use crate::pg_protocol::ParameterStatus;
24use crate::pg_server::BoxedError;
25use crate::types::{Format, Row};
26
27pub type RowSet = Vec<Row>;
28pub type RowSetResult = Result<RowSet, BoxedError>;
29
30pub trait ValuesStream = Stream<Item = RowSetResult> + Unpin + Send;
31
/// Kind of SQL statement that a [`PgResponse`] answers.
///
/// [`StatementType::infer_from_statement`] maps parsed statements onto these variants, and
/// [`StatementType::is_query`], [`StatementType::is_command`], [`StatementType::is_dml`] and
/// [`StatementType::is_returning`] group them. The `Display` implementation prints the variant
/// name verbatim (e.g. `CREATE_TABLE`).
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
// Upper-case variant names are deliberate: `Display` prints them verbatim.
#[expect(non_camel_case_types, clippy::upper_case_acronyms)]
pub enum StatementType {
    INSERT,
    // The `*_RETURNING` variants are DML statements with a non-empty `RETURNING` list.
    INSERT_RETURNING,
    DELETE,
    DELETE_RETURNING,
    UPDATE,
    UPDATE_RETURNING,
    SELECT,
    MOVE,
    FETCH,
    COPY,
    EXPLAIN,
    CLOSE_CURSOR,
    CREATE_TABLE,
    CREATE_MATERIALIZED_VIEW,
    CREATE_VIEW,
    CREATE_SOURCE,
    CREATE_SINK,
    CREATE_SUBSCRIPTION,
    CREATE_DATABASE,
    CREATE_SCHEMA,
    CREATE_USER,
    CREATE_INDEX,
    CREATE_AGGREGATE,
    CREATE_FUNCTION,
    CREATE_CONNECTION,
    CREATE_SECRET,
    COMMENT,
    DECLARE_CURSOR,
    DESCRIBE,
    GRANT_PRIVILEGE,
    DISCARD,
    DROP_TABLE,
    DROP_MATERIALIZED_VIEW,
    DROP_VIEW,
    DROP_INDEX,
    DROP_FUNCTION,
    DROP_AGGREGATE,
    DROP_SOURCE,
    DROP_SINK,
    DROP_SUBSCRIPTION,
    DROP_SCHEMA,
    DROP_DATABASE,
    DROP_USER,
    DROP_CONNECTION,
    DROP_SECRET,
    ALTER_DATABASE,
    ALTER_DEFAULT_PRIVILEGES,
    ALTER_SCHEMA,
    ALTER_INDEX,
    ALTER_VIEW,
    ALTER_TABLE,
    ALTER_MATERIALIZED_VIEW,
    ALTER_SINK,
    ALTER_SUBSCRIPTION,
    ALTER_SOURCE,
    ALTER_FUNCTION,
    ALTER_CONNECTION,
    ALTER_SYSTEM,
    ALTER_SECRET,
    ALTER_FRAGMENT,
    REVOKE_PRIVILEGE,
    // The ORDER_BY statement type exists because Calcite's unvalidated AST has
    // `SqlKind.ORDER_BY`. Note that `StatementType` is not designed to be a one-to-one mapping
    // of `SqlKind`.
    ORDER_BY,
    SET_VARIABLE,
    SHOW_VARIABLE,
    SHOW_COMMAND,
    START_TRANSACTION,
    UPDATE_USER,
    ABORT,
    FLUSH,
    REFRESH_TABLE,
    OTHER,
    // EMPTY is used when the query string contains no statement (e.g. ";").
    EMPTY,
    BEGIN,
    COMMIT,
    ROLLBACK,
    SET_TRANSACTION,
    CANCEL_COMMAND,
    FETCH_CURSOR,
    WAIT,
    KILL,
    BACKUP,
    DELETE_META_SNAPSHOTS,
    RECOVER,
    USE,
    PREPARE,
    DEALLOCATE,
    VACUUM,
}
126
127impl std::fmt::Display for StatementType {
128    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
129        write!(f, "{:?}", self)
130    }
131}
132
133pub trait Callback = Future<Output = Result<(), BoxedError>> + Send;
134
135pub type BoxedCallback = Pin<Box<dyn Callback>>;
136
/// The response produced for a single SQL statement.
///
/// Usually assembled with [`PgResponseBuilder`] and converted via `From`/`Into`.
///
/// Fields are dropped in declaration order (so `values_stream` is dropped before `callback`);
/// keep that in mind before reordering them.
pub struct PgResponse<VS> {
    /// Kind of statement this response answers.
    stmt_type: StatementType,
    /// Set by [`PgResponse::into_copy_query_to_stdout`], which also re-tags `stmt_type` as
    /// `COPY`; always `false` right after conversion from a builder.
    is_copy_query_to_stdout: bool,

    /// Number of affected rows. Used for INSERT, UPDATE, DELETE, COPY, and other statements
    /// that don't return rows.
    row_cnt: Option<i32>,
    /// Used for INSERT, UPDATE, DELETE to specify the format of the affected row count.
    row_cnt_format: Option<Format>,
    /// Notice messages attached via [`PgResponseBuilder::notice`].
    notices: Vec<String>,
    /// Stream of result rows, if this response carries one.
    values_stream: Option<VS>,
    /// Future awaited by [`PgResponse::run_callback`]; taken the first time it runs.
    callback: Option<BoxedCallback>,
    /// Field descriptors supplied together with `values_stream` (see
    /// [`PgResponseBuilder::values`]).
    row_desc: Vec<PgFieldDescriptor>,
    /// Set via [`PgResponseBuilder::status`]; `Default::default()` otherwise.
    status: ParameterStatus,
}
152
/// Incrementally assembles a [`PgResponse`]; start with [`PgResponseBuilder::empty`] and
/// finish with `.into()`.
///
/// Fields are dropped in declaration order; keep that in mind before reordering them.
pub struct PgResponseBuilder<VS> {
    /// Kind of statement the response will answer.
    stmt_type: StatementType,
    /// Number of affected rows. Used for INSERT, UPDATE, DELETE, COPY, and other statements
    /// that don't return rows.
    row_cnt: Option<i32>,
    /// Used for INSERT, UPDATE, DELETE to specify the format of the affected row count.
    row_cnt_format: Option<Format>,
    /// Notice messages, appended in call order by [`PgResponseBuilder::notice`].
    notices: Vec<String>,
    /// Stream of result rows, set together with `row_desc` by [`PgResponseBuilder::values`].
    values_stream: Option<VS>,
    /// Future later awaited by [`PgResponse::run_callback`].
    callback: Option<BoxedCallback>,
    /// Field descriptors for the rows in `values_stream`.
    row_desc: Vec<PgFieldDescriptor>,
    /// Set via [`PgResponseBuilder::status`]; `Default::default()` otherwise.
    status: ParameterStatus,
}
166
167impl<VS> From<PgResponseBuilder<VS>> for PgResponse<VS> {
168    fn from(builder: PgResponseBuilder<VS>) -> Self {
169        Self {
170            stmt_type: builder.stmt_type,
171            is_copy_query_to_stdout: false, // set a false from builder, alter later
172            row_cnt: builder.row_cnt,
173            row_cnt_format: builder.row_cnt_format,
174            notices: builder.notices,
175            values_stream: builder.values_stream,
176            callback: builder.callback,
177            row_desc: builder.row_desc,
178            status: builder.status,
179        }
180    }
181}
182
183impl<VS> PgResponseBuilder<VS> {
184    pub fn empty(stmt_type: StatementType) -> Self {
185        let row_cnt = if stmt_type.is_query() { None } else { Some(0) };
186        Self {
187            stmt_type,
188            row_cnt,
189            row_cnt_format: None,
190            notices: vec![],
191            values_stream: None,
192            callback: None,
193            row_desc: vec![],
194            status: Default::default(),
195        }
196    }
197
198    pub fn row_cnt(self, row_cnt: i32) -> Self {
199        Self {
200            row_cnt: Some(row_cnt),
201            ..self
202        }
203    }
204
205    pub fn row_cnt_opt(self, row_cnt: Option<i32>) -> Self {
206        Self { row_cnt, ..self }
207    }
208
209    pub fn row_cnt_format_opt(self, row_cnt_format: Option<Format>) -> Self {
210        Self {
211            row_cnt_format,
212            ..self
213        }
214    }
215
216    pub fn values(self, values_stream: VS, row_desc: Vec<PgFieldDescriptor>) -> Self {
217        Self {
218            values_stream: Some(values_stream),
219            row_desc,
220            ..self
221        }
222    }
223
224    pub fn callback(self, callback: impl Callback + 'static) -> Self {
225        Self {
226            callback: Some(callback.boxed()),
227            ..self
228        }
229    }
230
231    pub fn notice(self, notice: impl ToString) -> Self {
232        let mut notices = self.notices;
233        notices.push(notice.to_string());
234        Self { notices, ..self }
235    }
236
237    pub fn status(self, status: ParameterStatus) -> Self {
238        Self { status, ..self }
239    }
240}
241
242impl<VS> std::fmt::Debug for PgResponse<VS> {
243    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
244        f.debug_struct("PgResponse")
245            .field("stmt_type", &self.stmt_type)
246            .field("row_cnt", &self.row_cnt)
247            .field("notices", &self.notices)
248            .field("row_desc", &self.row_desc)
249            .finish()
250    }
251}
252
253impl StatementType {
254    pub fn infer_from_statement(stmt: &Statement) -> Result<Self, String> {
255        match stmt {
256            Statement::Query(_) => Ok(StatementType::SELECT),
257            Statement::Insert { returning, .. } => {
258                if returning.is_empty() {
259                    Ok(StatementType::INSERT)
260                } else {
261                    Ok(StatementType::INSERT_RETURNING)
262                }
263            }
264            Statement::Delete { returning, .. } => {
265                if returning.is_empty() {
266                    Ok(StatementType::DELETE)
267                } else {
268                    Ok(StatementType::DELETE_RETURNING)
269                }
270            }
271            Statement::Update { returning, .. } => {
272                if returning.is_empty() {
273                    Ok(StatementType::UPDATE)
274                } else {
275                    Ok(StatementType::UPDATE_RETURNING)
276                }
277            }
278            Statement::Copy { .. } => Ok(StatementType::COPY),
279            Statement::CreateTable { .. } => Ok(StatementType::CREATE_TABLE),
280            Statement::CreateIndex { .. } => Ok(StatementType::CREATE_INDEX),
281            Statement::CreateSchema { .. } => Ok(StatementType::CREATE_SCHEMA),
282            Statement::CreateSource { .. } => Ok(StatementType::CREATE_SOURCE),
283            Statement::CreateSink { .. } => Ok(StatementType::CREATE_SINK),
284            Statement::CreateFunction { .. } => Ok(StatementType::CREATE_FUNCTION),
285            Statement::CreateDatabase { .. } => Ok(StatementType::CREATE_DATABASE),
286            Statement::CreateUser { .. } => Ok(StatementType::CREATE_USER),
287            Statement::CreateView { materialized, .. } => {
288                if *materialized {
289                    Ok(StatementType::CREATE_MATERIALIZED_VIEW)
290                } else {
291                    Ok(StatementType::CREATE_VIEW)
292                }
293            }
294            Statement::AlterTable { .. } => Ok(StatementType::ALTER_TABLE),
295            Statement::AlterSystem { .. } => Ok(StatementType::ALTER_SYSTEM),
296            Statement::AlterFragment { .. } => Ok(StatementType::ALTER_FRAGMENT),
297            Statement::DropFunction { .. } => Ok(StatementType::DROP_FUNCTION),
298            Statement::Discard(..) => Ok(StatementType::DISCARD),
299            Statement::SetVariable { .. } => Ok(StatementType::SET_VARIABLE),
300            Statement::ShowVariable { .. } => Ok(StatementType::SHOW_VARIABLE),
301            Statement::StartTransaction { .. } => Ok(StatementType::START_TRANSACTION),
302            Statement::Begin { .. } => Ok(StatementType::BEGIN),
303            Statement::Abort => Ok(StatementType::ABORT),
304            Statement::Commit { .. } => Ok(StatementType::COMMIT),
305            Statement::Rollback { .. } => Ok(StatementType::ROLLBACK),
306            Statement::Grant { .. } => Ok(StatementType::GRANT_PRIVILEGE),
307            Statement::Revoke { .. } => Ok(StatementType::REVOKE_PRIVILEGE),
308            Statement::Describe { .. } => Ok(StatementType::DESCRIBE),
309            Statement::ShowCreateObject { .. } | Statement::ShowObjects { .. } => {
310                Ok(StatementType::SHOW_COMMAND)
311            }
312            Statement::Drop(stmt) => match stmt.object_type {
313                risingwave_sqlparser::ast::ObjectType::Table => Ok(StatementType::DROP_TABLE),
314                risingwave_sqlparser::ast::ObjectType::View => Ok(StatementType::DROP_VIEW),
315                risingwave_sqlparser::ast::ObjectType::MaterializedView => {
316                    Ok(StatementType::DROP_MATERIALIZED_VIEW)
317                }
318                risingwave_sqlparser::ast::ObjectType::Index => Ok(StatementType::DROP_INDEX),
319                risingwave_sqlparser::ast::ObjectType::Schema => Ok(StatementType::DROP_SCHEMA),
320                risingwave_sqlparser::ast::ObjectType::Source => Ok(StatementType::DROP_SOURCE),
321                risingwave_sqlparser::ast::ObjectType::Sink => Ok(StatementType::DROP_SINK),
322                risingwave_sqlparser::ast::ObjectType::Database => Ok(StatementType::DROP_DATABASE),
323                risingwave_sqlparser::ast::ObjectType::User => Ok(StatementType::DROP_USER),
324                risingwave_sqlparser::ast::ObjectType::Connection => {
325                    Ok(StatementType::DROP_CONNECTION)
326                }
327                risingwave_sqlparser::ast::ObjectType::Secret => Ok(StatementType::DROP_SECRET),
328                risingwave_sqlparser::ast::ObjectType::Subscription => {
329                    Ok(StatementType::DROP_SUBSCRIPTION)
330                }
331            },
332            Statement::Explain { .. } => Ok(StatementType::EXPLAIN),
333            Statement::DeclareCursor { .. } => Ok(StatementType::DECLARE_CURSOR),
334            Statement::FetchCursor { .. } => Ok(StatementType::FETCH_CURSOR),
335            Statement::CloseCursor { .. } => Ok(StatementType::CLOSE_CURSOR),
336            Statement::Flush => Ok(StatementType::FLUSH),
337            Statement::Wait => Ok(StatementType::WAIT),
338            Statement::Backup => Ok(StatementType::BACKUP),
339            Statement::DeleteMetaSnapshots { .. } => Ok(StatementType::DELETE_META_SNAPSHOTS),
340            Statement::Recover => Ok(StatementType::RECOVER),
341            Statement::Use { .. } => Ok(StatementType::USE),
342            Statement::Vacuum { .. } => Ok(StatementType::VACUUM),
343            _ => Err("unsupported statement type".to_owned()),
344        }
345    }
346
347    pub fn is_command(&self) -> bool {
348        matches!(
349            self,
350            StatementType::INSERT
351                | StatementType::DELETE
352                | StatementType::UPDATE
353                | StatementType::MOVE
354                | StatementType::COPY
355                | StatementType::FETCH
356                | StatementType::SELECT
357                | StatementType::INSERT_RETURNING
358                | StatementType::DELETE_RETURNING
359                | StatementType::UPDATE_RETURNING
360        )
361    }
362
363    pub fn is_dml(&self) -> bool {
364        matches!(
365            self,
366            StatementType::INSERT
367                | StatementType::DELETE
368                | StatementType::UPDATE
369                | StatementType::INSERT_RETURNING
370                | StatementType::DELETE_RETURNING
371                | StatementType::UPDATE_RETURNING
372        )
373    }
374
375    pub fn is_query(&self) -> bool {
376        matches!(
377            self,
378            StatementType::SELECT
379                | StatementType::EXPLAIN
380                | StatementType::SHOW_COMMAND
381                | StatementType::SHOW_VARIABLE
382                | StatementType::DESCRIBE
383                | StatementType::INSERT_RETURNING
384                | StatementType::DELETE_RETURNING
385                | StatementType::UPDATE_RETURNING
386                | StatementType::CANCEL_COMMAND
387                | StatementType::BACKUP
388                | StatementType::FETCH_CURSOR
389        )
390    }
391
392    pub fn is_returning(&self) -> bool {
393        matches!(
394            self,
395            StatementType::INSERT_RETURNING
396                | StatementType::DELETE_RETURNING
397                | StatementType::UPDATE_RETURNING
398        )
399    }
400}
401
402impl<VS> PgResponse<VS>
403where
404    VS: ValuesStream,
405{
406    pub fn builder(stmt_type: StatementType) -> PgResponseBuilder<VS> {
407        PgResponseBuilder::empty(stmt_type)
408    }
409
410    pub fn empty_result(stmt_type: StatementType) -> Self {
411        PgResponseBuilder::empty(stmt_type).into()
412    }
413
414    pub fn into_copy_query_to_stdout(mut self) -> Self {
415        self.is_copy_query_to_stdout = true;
416        self.stmt_type = StatementType::COPY;
417        self
418    }
419
420    pub fn stmt_type(&self) -> StatementType {
421        self.stmt_type
422    }
423
424    pub fn notices(&self) -> &[String] {
425        &self.notices
426    }
427
428    pub fn status(&self) -> &ParameterStatus {
429        &self.status
430    }
431
432    pub fn affected_rows_cnt(&self) -> Option<i32> {
433        self.row_cnt
434    }
435
436    pub fn row_cnt_format(&self) -> Option<Format> {
437        self.row_cnt_format
438    }
439
440    pub fn is_query(&self) -> bool {
441        self.stmt_type.is_query()
442    }
443
444    pub fn is_empty(&self) -> bool {
445        self.stmt_type == StatementType::EMPTY
446    }
447
448    pub fn is_copy_query_to_stdout(&self) -> bool {
449        self.is_copy_query_to_stdout
450    }
451
452    pub fn row_desc(&self) -> &[PgFieldDescriptor] {
453        &self.row_desc
454    }
455
456    pub fn values_stream(&mut self) -> &mut VS {
457        self.values_stream.as_mut().expect("no values stream")
458    }
459
460    /// Run the callback if there is one.
461    ///
462    /// This should only be called after the values stream has been exhausted. Multiple calls to
463    /// this function will be no-ops.
464    pub async fn run_callback(&mut self) -> Result<(), PsqlError> {
465        // Check if the stream is exhausted.
466        if let Some(values_stream) = &mut self.values_stream {
467            assert!(values_stream.next().await.is_none());
468        }
469
470        if let Some(callback) = self.callback.take() {
471            callback.await.map_err(PsqlError::SimpleQueryError)?;
472        }
473        Ok(())
474    }
475}