pgwire/
pg_response.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Formatter;
16use std::pin::Pin;
17
18use futures::{Future, FutureExt, Stream, StreamExt};
19use risingwave_sqlparser::ast::Statement;
20
21use crate::error::PsqlError;
22use crate::pg_field_descriptor::PgFieldDescriptor;
23use crate::pg_protocol::ParameterStatus;
24use crate::pg_server::BoxedError;
25use crate::types::{Format, Row};
26
/// A batch of rows.
pub type RowSet = Vec<Row>;
/// Either a [`RowSet`] or the boxed error that replaced it.
pub type RowSetResult = Result<RowSet, BoxedError>;

/// Trait alias for the row stream carried by a response: a `Stream` whose items are
/// [`RowSetResult`]s and which is both `Unpin` and `Send`.
pub trait ValuesStream = Stream<Item = RowSetResult> + Unpin + Send;
31
/// The kind of statement a response is tagged with.
///
/// [`StatementType::infer_from_statement`] maps a parsed `Statement` to one of these variants.
/// The `Display` implementation prints the variant name exactly as `Debug` does.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
// The variants are written in upper snake case, so both naming lints are expected to fire.
#[expect(non_camel_case_types, clippy::upper_case_acronyms)]
pub enum StatementType {
    INSERT,
    INSERT_RETURNING,
    DELETE,
    DELETE_RETURNING,
    UPDATE,
    UPDATE_RETURNING,
    SELECT,
    MOVE,
    FETCH,
    COPY,
    EXPLAIN,
    CLOSE_CURSOR,
    CREATE_TABLE,
    CREATE_MATERIALIZED_VIEW,
    CREATE_VIEW,
    CREATE_SOURCE,
    CREATE_SINK,
    CREATE_SUBSCRIPTION,
    CREATE_DATABASE,
    CREATE_SCHEMA,
    CREATE_USER,
    CREATE_INDEX,
    CREATE_AGGREGATE,
    CREATE_FUNCTION,
    CREATE_CONNECTION,
    CREATE_SECRET,
    COMMENT,
    DECLARE_CURSOR,
    DESCRIBE,
    GRANT_PRIVILEGE,
    DISCARD,
    DROP_TABLE,
    DROP_MATERIALIZED_VIEW,
    DROP_VIEW,
    DROP_INDEX,
    DROP_FUNCTION,
    DROP_AGGREGATE,
    DROP_SOURCE,
    DROP_SINK,
    DROP_SUBSCRIPTION,
    DROP_SCHEMA,
    DROP_DATABASE,
    DROP_USER,
    DROP_CONNECTION,
    DROP_SECRET,
    ALTER_DATABASE,
    ALTER_DEFAULT_PRIVILEGES,
    ALTER_SCHEMA,
    ALTER_INDEX,
    ALTER_VIEW,
    ALTER_TABLE,
    ALTER_MATERIALIZED_VIEW,
    ALTER_SINK,
    ALTER_SUBSCRIPTION,
    ALTER_SOURCE,
    ALTER_FUNCTION,
    ALTER_CONNECTION,
    ALTER_SYSTEM,
    ALTER_SECRET,
    REVOKE_PRIVILEGE,
    /// Introduced because the Calcite unvalidated AST has `SqlKind.ORDER_BY`. Note that
    /// `StatementType` is not designed to be a one-to-one mapping with `SqlKind`.
    ORDER_BY,
    SET_VARIABLE,
    SHOW_VARIABLE,
    SHOW_COMMAND,
    START_TRANSACTION,
    UPDATE_USER,
    ABORT,
    FLUSH,
    OTHER,
    /// Used when the query statement is empty (e.g. ";").
    EMPTY,
    BEGIN,
    COMMIT,
    ROLLBACK,
    SET_TRANSACTION,
    CANCEL_COMMAND,
    FETCH_CURSOR,
    WAIT,
    KILL,
    RECOVER,
    USE,
    PREPARE,
    DEALLOCATE,
}
121
122impl std::fmt::Display for StatementType {
123    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
124        write!(f, "{:?}", self)
125    }
126}
127
/// Trait alias for a `Send` future that resolves to `Ok(())` or a [`BoxedError`].
pub trait Callback = Future<Output = Result<(), BoxedError>> + Send;

/// A pinned, boxed [`Callback`] trait object.
pub type BoxedCallback = Pin<Box<dyn Callback>>;
131
/// The response to one statement: its type, affected-row count, notices, an optional row stream
/// with its row description, an optional callback and a parameter status.
///
/// Values are assembled with [`PgResponseBuilder`] and converted via `From`/`Into`.
pub struct PgResponse<VS> {
    /// Type of the statement this response belongs to.
    stmt_type: StatementType,
    /// Row count of affected rows. Used for INSERT, UPDATE, DELETE, COPY, and other statements
    /// that don't return rows.
    row_cnt: Option<i32>,
    /// Used for INSERT, UPDATE, DELETE to specify the format of the affected row count.
    row_cnt_format: Option<Format>,
    /// Notice messages attached to the response.
    notices: Vec<String>,
    /// Stream of result rows; `None` when the response carries no rows.
    values_stream: Option<VS>,
    /// Future awaited (at most once) by `run_callback`.
    callback: Option<BoxedCallback>,
    /// Field descriptors supplied together with the values stream.
    row_desc: Vec<PgFieldDescriptor>,
    /// Parameter status attached to the response.
    status: ParameterStatus,
}
145
/// Builder for [`PgResponse`]; it holds the same fields and is converted with `From`/`Into`.
pub struct PgResponseBuilder<VS> {
    /// Type of the statement the response is built for.
    stmt_type: StatementType,
    /// Row count of affected rows. Used for INSERT, UPDATE, DELETE, COPY, and other statements
    /// that don't return rows.
    row_cnt: Option<i32>,
    /// Used for INSERT, UPDATE, DELETE to specify the format of the affected row count.
    row_cnt_format: Option<Format>,
    /// Notice messages, appended one at a time by `notice`.
    notices: Vec<String>,
    /// Stream of result rows, set by `values`.
    values_stream: Option<VS>,
    /// Boxed callback, set by `callback`.
    callback: Option<BoxedCallback>,
    /// Field descriptors, set by `values` together with the stream.
    row_desc: Vec<PgFieldDescriptor>,
    /// Parameter status; `Default::default()` until overridden by `status`.
    status: ParameterStatus,
}
159
160impl<VS> From<PgResponseBuilder<VS>> for PgResponse<VS> {
161    fn from(builder: PgResponseBuilder<VS>) -> Self {
162        Self {
163            stmt_type: builder.stmt_type,
164            row_cnt: builder.row_cnt,
165            row_cnt_format: builder.row_cnt_format,
166            notices: builder.notices,
167            values_stream: builder.values_stream,
168            callback: builder.callback,
169            row_desc: builder.row_desc,
170            status: builder.status,
171        }
172    }
173}
174
175impl<VS> PgResponseBuilder<VS> {
176    pub fn empty(stmt_type: StatementType) -> Self {
177        let row_cnt = if stmt_type.is_query() { None } else { Some(0) };
178        Self {
179            stmt_type,
180            row_cnt,
181            row_cnt_format: None,
182            notices: vec![],
183            values_stream: None,
184            callback: None,
185            row_desc: vec![],
186            status: Default::default(),
187        }
188    }
189
190    pub fn row_cnt(self, row_cnt: i32) -> Self {
191        Self {
192            row_cnt: Some(row_cnt),
193            ..self
194        }
195    }
196
197    pub fn row_cnt_opt(self, row_cnt: Option<i32>) -> Self {
198        Self { row_cnt, ..self }
199    }
200
201    pub fn row_cnt_format_opt(self, row_cnt_format: Option<Format>) -> Self {
202        Self {
203            row_cnt_format,
204            ..self
205        }
206    }
207
208    pub fn values(self, values_stream: VS, row_desc: Vec<PgFieldDescriptor>) -> Self {
209        Self {
210            values_stream: Some(values_stream),
211            row_desc,
212            ..self
213        }
214    }
215
216    pub fn callback(self, callback: impl Callback + 'static) -> Self {
217        Self {
218            callback: Some(callback.boxed()),
219            ..self
220        }
221    }
222
223    pub fn notice(self, notice: impl ToString) -> Self {
224        let mut notices = self.notices;
225        notices.push(notice.to_string());
226        Self { notices, ..self }
227    }
228
229    pub fn status(self, status: ParameterStatus) -> Self {
230        Self { status, ..self }
231    }
232}
233
234impl<VS> std::fmt::Debug for PgResponse<VS> {
235    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
236        f.debug_struct("PgResponse")
237            .field("stmt_type", &self.stmt_type)
238            .field("row_cnt", &self.row_cnt)
239            .field("notices", &self.notices)
240            .field("row_desc", &self.row_desc)
241            .finish()
242    }
243}
244
245impl StatementType {
246    pub fn infer_from_statement(stmt: &Statement) -> Result<Self, String> {
247        match stmt {
248            Statement::Query(_) => Ok(StatementType::SELECT),
249            Statement::Insert { returning, .. } => {
250                if returning.is_empty() {
251                    Ok(StatementType::INSERT)
252                } else {
253                    Ok(StatementType::INSERT_RETURNING)
254                }
255            }
256            Statement::Delete { returning, .. } => {
257                if returning.is_empty() {
258                    Ok(StatementType::DELETE)
259                } else {
260                    Ok(StatementType::DELETE_RETURNING)
261                }
262            }
263            Statement::Update { returning, .. } => {
264                if returning.is_empty() {
265                    Ok(StatementType::UPDATE)
266                } else {
267                    Ok(StatementType::UPDATE_RETURNING)
268                }
269            }
270            Statement::Copy { .. } => Ok(StatementType::COPY),
271            Statement::CreateTable { .. } => Ok(StatementType::CREATE_TABLE),
272            Statement::CreateIndex { .. } => Ok(StatementType::CREATE_INDEX),
273            Statement::CreateSchema { .. } => Ok(StatementType::CREATE_SCHEMA),
274            Statement::CreateSource { .. } => Ok(StatementType::CREATE_SOURCE),
275            Statement::CreateSink { .. } => Ok(StatementType::CREATE_SINK),
276            Statement::CreateFunction { .. } => Ok(StatementType::CREATE_FUNCTION),
277            Statement::CreateDatabase { .. } => Ok(StatementType::CREATE_DATABASE),
278            Statement::CreateUser { .. } => Ok(StatementType::CREATE_USER),
279            Statement::CreateView { materialized, .. } => {
280                if *materialized {
281                    Ok(StatementType::CREATE_MATERIALIZED_VIEW)
282                } else {
283                    Ok(StatementType::CREATE_VIEW)
284                }
285            }
286            Statement::AlterTable { .. } => Ok(StatementType::ALTER_TABLE),
287            Statement::AlterSystem { .. } => Ok(StatementType::ALTER_SYSTEM),
288            Statement::DropFunction { .. } => Ok(StatementType::DROP_FUNCTION),
289            Statement::Discard(..) => Ok(StatementType::DISCARD),
290            Statement::SetVariable { .. } => Ok(StatementType::SET_VARIABLE),
291            Statement::ShowVariable { .. } => Ok(StatementType::SHOW_VARIABLE),
292            Statement::StartTransaction { .. } => Ok(StatementType::START_TRANSACTION),
293            Statement::Begin { .. } => Ok(StatementType::BEGIN),
294            Statement::Abort => Ok(StatementType::ABORT),
295            Statement::Commit { .. } => Ok(StatementType::COMMIT),
296            Statement::Rollback { .. } => Ok(StatementType::ROLLBACK),
297            Statement::Grant { .. } => Ok(StatementType::GRANT_PRIVILEGE),
298            Statement::Revoke { .. } => Ok(StatementType::REVOKE_PRIVILEGE),
299            Statement::Describe { .. } => Ok(StatementType::DESCRIBE),
300            Statement::ShowCreateObject { .. } | Statement::ShowObjects { .. } => {
301                Ok(StatementType::SHOW_COMMAND)
302            }
303            Statement::Drop(stmt) => match stmt.object_type {
304                risingwave_sqlparser::ast::ObjectType::Table => Ok(StatementType::DROP_TABLE),
305                risingwave_sqlparser::ast::ObjectType::View => Ok(StatementType::DROP_VIEW),
306                risingwave_sqlparser::ast::ObjectType::MaterializedView => {
307                    Ok(StatementType::DROP_MATERIALIZED_VIEW)
308                }
309                risingwave_sqlparser::ast::ObjectType::Index => Ok(StatementType::DROP_INDEX),
310                risingwave_sqlparser::ast::ObjectType::Schema => Ok(StatementType::DROP_SCHEMA),
311                risingwave_sqlparser::ast::ObjectType::Source => Ok(StatementType::DROP_SOURCE),
312                risingwave_sqlparser::ast::ObjectType::Sink => Ok(StatementType::DROP_SINK),
313                risingwave_sqlparser::ast::ObjectType::Database => Ok(StatementType::DROP_DATABASE),
314                risingwave_sqlparser::ast::ObjectType::User => Ok(StatementType::DROP_USER),
315                risingwave_sqlparser::ast::ObjectType::Connection => {
316                    Ok(StatementType::DROP_CONNECTION)
317                }
318                risingwave_sqlparser::ast::ObjectType::Secret => Ok(StatementType::DROP_SECRET),
319                risingwave_sqlparser::ast::ObjectType::Subscription => {
320                    Ok(StatementType::DROP_SUBSCRIPTION)
321                }
322            },
323            Statement::Explain { .. } => Ok(StatementType::EXPLAIN),
324            Statement::DeclareCursor { .. } => Ok(StatementType::DECLARE_CURSOR),
325            Statement::FetchCursor { .. } => Ok(StatementType::FETCH_CURSOR),
326            Statement::CloseCursor { .. } => Ok(StatementType::CLOSE_CURSOR),
327            Statement::Flush => Ok(StatementType::FLUSH),
328            Statement::Wait => Ok(StatementType::WAIT),
329            Statement::Use { .. } => Ok(StatementType::USE),
330            _ => Err("unsupported statement type".to_owned()),
331        }
332    }
333
334    pub fn is_command(&self) -> bool {
335        matches!(
336            self,
337            StatementType::INSERT
338                | StatementType::DELETE
339                | StatementType::UPDATE
340                | StatementType::MOVE
341                | StatementType::COPY
342                | StatementType::FETCH
343                | StatementType::SELECT
344                | StatementType::INSERT_RETURNING
345                | StatementType::DELETE_RETURNING
346                | StatementType::UPDATE_RETURNING
347        )
348    }
349
350    pub fn is_dml(&self) -> bool {
351        matches!(
352            self,
353            StatementType::INSERT
354                | StatementType::DELETE
355                | StatementType::UPDATE
356                | StatementType::INSERT_RETURNING
357                | StatementType::DELETE_RETURNING
358                | StatementType::UPDATE_RETURNING
359        )
360    }
361
362    pub fn is_query(&self) -> bool {
363        matches!(
364            self,
365            StatementType::SELECT
366                | StatementType::EXPLAIN
367                | StatementType::SHOW_COMMAND
368                | StatementType::SHOW_VARIABLE
369                | StatementType::DESCRIBE
370                | StatementType::INSERT_RETURNING
371                | StatementType::DELETE_RETURNING
372                | StatementType::UPDATE_RETURNING
373                | StatementType::CANCEL_COMMAND
374                | StatementType::FETCH_CURSOR
375        )
376    }
377
378    pub fn is_returning(&self) -> bool {
379        matches!(
380            self,
381            StatementType::INSERT_RETURNING
382                | StatementType::DELETE_RETURNING
383                | StatementType::UPDATE_RETURNING
384        )
385    }
386}
387
388impl<VS> PgResponse<VS>
389where
390    VS: ValuesStream,
391{
392    pub fn builder(stmt_type: StatementType) -> PgResponseBuilder<VS> {
393        PgResponseBuilder::empty(stmt_type)
394    }
395
396    pub fn empty_result(stmt_type: StatementType) -> Self {
397        PgResponseBuilder::empty(stmt_type).into()
398    }
399
400    pub fn stmt_type(&self) -> StatementType {
401        self.stmt_type
402    }
403
404    pub fn notices(&self) -> &[String] {
405        &self.notices
406    }
407
408    pub fn status(&self) -> &ParameterStatus {
409        &self.status
410    }
411
412    pub fn affected_rows_cnt(&self) -> Option<i32> {
413        self.row_cnt
414    }
415
416    pub fn row_cnt_format(&self) -> Option<Format> {
417        self.row_cnt_format
418    }
419
420    pub fn is_query(&self) -> bool {
421        self.stmt_type.is_query()
422    }
423
424    pub fn is_empty(&self) -> bool {
425        self.stmt_type == StatementType::EMPTY
426    }
427
428    pub fn row_desc(&self) -> Vec<PgFieldDescriptor> {
429        self.row_desc.clone()
430    }
431
432    pub fn values_stream(&mut self) -> &mut VS {
433        self.values_stream.as_mut().expect("no values stream")
434    }
435
436    /// Run the callback if there is one.
437    ///
438    /// This should only be called after the values stream has been exhausted. Multiple calls to
439    /// this function will be no-ops.
440    pub async fn run_callback(&mut self) -> Result<(), PsqlError> {
441        // Check if the stream is exhausted.
442        if let Some(values_stream) = &mut self.values_stream {
443            assert!(values_stream.next().await.is_none());
444        }
445
446        if let Some(callback) = self.callback.take() {
447            callback.await.map_err(PsqlError::SimpleQueryError)?;
448        }
449        Ok(())
450    }
451}