// pgwire/pg_response.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::fmt::Formatter;
16use std::pin::Pin;
17
18use futures::{Future, FutureExt, Stream, StreamExt};
19use risingwave_sqlparser::ast::Statement;
20
21use crate::error::PsqlError;
22use crate::pg_field_descriptor::PgFieldDescriptor;
23use crate::pg_protocol::ParameterStatus;
24use crate::pg_server::BoxedError;
25use crate::types::{Format, Row};
26
27pub type RowSet = Vec<Row>;
28pub type RowSetResult = Result<RowSet, BoxedError>;
29
30pub trait ValuesStream = Stream<Item = RowSetResult> + Unpin + Send;
31
/// Kind of SQL statement a [`PgResponse`] answers.
///
/// Most variants are derived from a parsed statement by
/// [`StatementType::infer_from_statement`]; some (e.g. `MOVE`, `FETCH`, `OTHER`, `EMPTY`) are
/// never returned by it.
///
/// `Display` prints the variant name exactly as spelled here (e.g. `CREATE_TABLE`).
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[expect(non_camel_case_types, clippy::upper_case_acronyms)]
pub enum StatementType {
    INSERT,
    /// `INSERT` with a non-empty `RETURNING` clause.
    INSERT_RETURNING,
    DELETE,
    /// `DELETE` with a non-empty `RETURNING` clause.
    DELETE_RETURNING,
    UPDATE,
    /// `UPDATE` with a non-empty `RETURNING` clause.
    UPDATE_RETURNING,
    SELECT,
    MOVE,
    FETCH,
    COPY,
    EXPLAIN,
    CLOSE_CURSOR,
    CREATE_TABLE,
    CREATE_MATERIALIZED_VIEW,
    CREATE_VIEW,
    CREATE_SOURCE,
    CREATE_SINK,
    CREATE_SUBSCRIPTION,
    CREATE_DATABASE,
    CREATE_SCHEMA,
    CREATE_USER,
    CREATE_INDEX,
    CREATE_AGGREGATE,
    CREATE_FUNCTION,
    CREATE_CONNECTION,
    CREATE_SECRET,
    COMMENT,
    DECLARE_CURSOR,
    DESCRIBE,
    GRANT_PRIVILEGE,
    DISCARD,
    DROP_TABLE,
    DROP_MATERIALIZED_VIEW,
    DROP_VIEW,
    DROP_INDEX,
    DROP_FUNCTION,
    DROP_AGGREGATE,
    DROP_SOURCE,
    DROP_SINK,
    DROP_SUBSCRIPTION,
    DROP_SCHEMA,
    DROP_DATABASE,
    DROP_USER,
    DROP_CONNECTION,
    DROP_SECRET,
    ALTER_DATABASE,
    ALTER_DEFAULT_PRIVILEGES,
    ALTER_SCHEMA,
    ALTER_INDEX,
    ALTER_VIEW,
    ALTER_TABLE,
    ALTER_MATERIALIZED_VIEW,
    ALTER_SINK,
    ALTER_SUBSCRIPTION,
    ALTER_SOURCE,
    ALTER_FUNCTION,
    ALTER_CONNECTION,
    ALTER_SYSTEM,
    ALTER_SECRET,
    REVOKE_PRIVILEGE,
    /// Introduced because Calcite's unvalidated AST has `SqlKind.ORDER_BY`. Note that
    /// `StatementType` is not designed to be a one-to-one mapping with `SqlKind`.
    ORDER_BY,
    SET_VARIABLE,
    SHOW_VARIABLE,
    SHOW_COMMAND,
    START_TRANSACTION,
    UPDATE_USER,
    ABORT,
    FLUSH,
    REFRESH_TABLE,
    OTHER,
    /// Used when the query statement is empty (e.g. `;`).
    EMPTY,
    BEGIN,
    COMMIT,
    ROLLBACK,
    SET_TRANSACTION,
    CANCEL_COMMAND,
    /// Returned by `infer_from_statement` for `FETCH` on a cursor.
    /// NOTE(review): `is_command` covers `FETCH` but not this variant — confirm that is intended.
    FETCH_CURSOR,
    WAIT,
    KILL,
    RECOVER,
    USE,
    PREPARE,
    DEALLOCATE,
    VACUUM,
}
123
124impl std::fmt::Display for StatementType {
125    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
126        write!(f, "{:?}", self)
127    }
128}
129
/// Future attached to a response and awaited by `PgResponse::run_callback` once the values
/// stream has been exhausted; resolves to `Ok(())` or a [`BoxedError`].
pub trait Callback = Future<Output = Result<(), BoxedError>> + Send;

/// Pinned, heap-allocated, type-erased [`Callback`], as stored in a response.
pub type BoxedCallback = Pin<Box<dyn Callback>>;
133
/// Response to a single statement: its type, affected-row count, notices, an optional stream of
/// result rows with their field descriptors, an optional callback, and parameter status.
///
/// Built from a [`PgResponseBuilder`] via `From`, or with `PgResponse::empty_result`.
pub struct PgResponse<VS> {
    /// Type of the statement this response answers.
    stmt_type: StatementType,
    /// Row count of affected rows. Used for INSERT, UPDATE, DELETE, COPY, and other statements
    /// that don't return rows.
    row_cnt: Option<i32>,
    /// Used for INSERT, UPDATE, DELETE to specify the format of the affected row count.
    row_cnt_format: Option<Format>,
    /// Notice messages, in the order they were added to the builder.
    notices: Vec<String>,
    /// Stream of result row batches; `None` unless one was attached with
    /// `PgResponseBuilder::values`.
    values_stream: Option<VS>,
    /// Future taken and awaited (at most once) by `PgResponse::run_callback`.
    callback: Option<BoxedCallback>,
    /// Field descriptors for the rows in `values_stream`; empty when no stream was attached.
    row_desc: Vec<PgFieldDescriptor>,
    /// Parameter status carried with the response.
    status: ParameterStatus,
}
147
/// Builder for [`PgResponse`]: start from `PgResponseBuilder::empty`, chain setters, then convert
/// with `into()`.
///
/// Carries exactly the same fields as [`PgResponse`].
pub struct PgResponseBuilder<VS> {
    /// Type of the statement being answered.
    stmt_type: StatementType,
    /// Row count of affected rows. Used for INSERT, UPDATE, DELETE, COPY, and other statements
    /// that don't return rows.
    row_cnt: Option<i32>,
    /// Used for INSERT, UPDATE, DELETE to specify the format of the affected row count.
    row_cnt_format: Option<Format>,
    /// Notice messages, appended in order by `notice`.
    notices: Vec<String>,
    /// Stream of result row batches, set together with `row_desc` by `values`.
    values_stream: Option<VS>,
    /// Boxed future set by `callback`.
    callback: Option<BoxedCallback>,
    /// Field descriptors for the rows in `values_stream`, set by `values`.
    row_desc: Vec<PgFieldDescriptor>,
    /// Parameter status, replaced by `status`.
    status: ParameterStatus,
}
161
162impl<VS> From<PgResponseBuilder<VS>> for PgResponse<VS> {
163    fn from(builder: PgResponseBuilder<VS>) -> Self {
164        Self {
165            stmt_type: builder.stmt_type,
166            row_cnt: builder.row_cnt,
167            row_cnt_format: builder.row_cnt_format,
168            notices: builder.notices,
169            values_stream: builder.values_stream,
170            callback: builder.callback,
171            row_desc: builder.row_desc,
172            status: builder.status,
173        }
174    }
175}
176
177impl<VS> PgResponseBuilder<VS> {
178    pub fn empty(stmt_type: StatementType) -> Self {
179        let row_cnt = if stmt_type.is_query() { None } else { Some(0) };
180        Self {
181            stmt_type,
182            row_cnt,
183            row_cnt_format: None,
184            notices: vec![],
185            values_stream: None,
186            callback: None,
187            row_desc: vec![],
188            status: Default::default(),
189        }
190    }
191
192    pub fn row_cnt(self, row_cnt: i32) -> Self {
193        Self {
194            row_cnt: Some(row_cnt),
195            ..self
196        }
197    }
198
199    pub fn row_cnt_opt(self, row_cnt: Option<i32>) -> Self {
200        Self { row_cnt, ..self }
201    }
202
203    pub fn row_cnt_format_opt(self, row_cnt_format: Option<Format>) -> Self {
204        Self {
205            row_cnt_format,
206            ..self
207        }
208    }
209
210    pub fn values(self, values_stream: VS, row_desc: Vec<PgFieldDescriptor>) -> Self {
211        Self {
212            values_stream: Some(values_stream),
213            row_desc,
214            ..self
215        }
216    }
217
218    pub fn callback(self, callback: impl Callback + 'static) -> Self {
219        Self {
220            callback: Some(callback.boxed()),
221            ..self
222        }
223    }
224
225    pub fn notice(self, notice: impl ToString) -> Self {
226        let mut notices = self.notices;
227        notices.push(notice.to_string());
228        Self { notices, ..self }
229    }
230
231    pub fn status(self, status: ParameterStatus) -> Self {
232        Self { status, ..self }
233    }
234}
235
236impl<VS> std::fmt::Debug for PgResponse<VS> {
237    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
238        f.debug_struct("PgResponse")
239            .field("stmt_type", &self.stmt_type)
240            .field("row_cnt", &self.row_cnt)
241            .field("notices", &self.notices)
242            .field("row_desc", &self.row_desc)
243            .finish()
244    }
245}
246
247impl StatementType {
248    pub fn infer_from_statement(stmt: &Statement) -> Result<Self, String> {
249        match stmt {
250            Statement::Query(_) => Ok(StatementType::SELECT),
251            Statement::Insert { returning, .. } => {
252                if returning.is_empty() {
253                    Ok(StatementType::INSERT)
254                } else {
255                    Ok(StatementType::INSERT_RETURNING)
256                }
257            }
258            Statement::Delete { returning, .. } => {
259                if returning.is_empty() {
260                    Ok(StatementType::DELETE)
261                } else {
262                    Ok(StatementType::DELETE_RETURNING)
263                }
264            }
265            Statement::Update { returning, .. } => {
266                if returning.is_empty() {
267                    Ok(StatementType::UPDATE)
268                } else {
269                    Ok(StatementType::UPDATE_RETURNING)
270                }
271            }
272            Statement::Copy { .. } => Ok(StatementType::COPY),
273            Statement::CreateTable { .. } => Ok(StatementType::CREATE_TABLE),
274            Statement::CreateIndex { .. } => Ok(StatementType::CREATE_INDEX),
275            Statement::CreateSchema { .. } => Ok(StatementType::CREATE_SCHEMA),
276            Statement::CreateSource { .. } => Ok(StatementType::CREATE_SOURCE),
277            Statement::CreateSink { .. } => Ok(StatementType::CREATE_SINK),
278            Statement::CreateFunction { .. } => Ok(StatementType::CREATE_FUNCTION),
279            Statement::CreateDatabase { .. } => Ok(StatementType::CREATE_DATABASE),
280            Statement::CreateUser { .. } => Ok(StatementType::CREATE_USER),
281            Statement::CreateView { materialized, .. } => {
282                if *materialized {
283                    Ok(StatementType::CREATE_MATERIALIZED_VIEW)
284                } else {
285                    Ok(StatementType::CREATE_VIEW)
286                }
287            }
288            Statement::AlterTable { .. } => Ok(StatementType::ALTER_TABLE),
289            Statement::AlterSystem { .. } => Ok(StatementType::ALTER_SYSTEM),
290            Statement::DropFunction { .. } => Ok(StatementType::DROP_FUNCTION),
291            Statement::Discard(..) => Ok(StatementType::DISCARD),
292            Statement::SetVariable { .. } => Ok(StatementType::SET_VARIABLE),
293            Statement::ShowVariable { .. } => Ok(StatementType::SHOW_VARIABLE),
294            Statement::StartTransaction { .. } => Ok(StatementType::START_TRANSACTION),
295            Statement::Begin { .. } => Ok(StatementType::BEGIN),
296            Statement::Abort => Ok(StatementType::ABORT),
297            Statement::Commit { .. } => Ok(StatementType::COMMIT),
298            Statement::Rollback { .. } => Ok(StatementType::ROLLBACK),
299            Statement::Grant { .. } => Ok(StatementType::GRANT_PRIVILEGE),
300            Statement::Revoke { .. } => Ok(StatementType::REVOKE_PRIVILEGE),
301            Statement::Describe { .. } => Ok(StatementType::DESCRIBE),
302            Statement::ShowCreateObject { .. } | Statement::ShowObjects { .. } => {
303                Ok(StatementType::SHOW_COMMAND)
304            }
305            Statement::Drop(stmt) => match stmt.object_type {
306                risingwave_sqlparser::ast::ObjectType::Table => Ok(StatementType::DROP_TABLE),
307                risingwave_sqlparser::ast::ObjectType::View => Ok(StatementType::DROP_VIEW),
308                risingwave_sqlparser::ast::ObjectType::MaterializedView => {
309                    Ok(StatementType::DROP_MATERIALIZED_VIEW)
310                }
311                risingwave_sqlparser::ast::ObjectType::Index => Ok(StatementType::DROP_INDEX),
312                risingwave_sqlparser::ast::ObjectType::Schema => Ok(StatementType::DROP_SCHEMA),
313                risingwave_sqlparser::ast::ObjectType::Source => Ok(StatementType::DROP_SOURCE),
314                risingwave_sqlparser::ast::ObjectType::Sink => Ok(StatementType::DROP_SINK),
315                risingwave_sqlparser::ast::ObjectType::Database => Ok(StatementType::DROP_DATABASE),
316                risingwave_sqlparser::ast::ObjectType::User => Ok(StatementType::DROP_USER),
317                risingwave_sqlparser::ast::ObjectType::Connection => {
318                    Ok(StatementType::DROP_CONNECTION)
319                }
320                risingwave_sqlparser::ast::ObjectType::Secret => Ok(StatementType::DROP_SECRET),
321                risingwave_sqlparser::ast::ObjectType::Subscription => {
322                    Ok(StatementType::DROP_SUBSCRIPTION)
323                }
324            },
325            Statement::Explain { .. } => Ok(StatementType::EXPLAIN),
326            Statement::DeclareCursor { .. } => Ok(StatementType::DECLARE_CURSOR),
327            Statement::FetchCursor { .. } => Ok(StatementType::FETCH_CURSOR),
328            Statement::CloseCursor { .. } => Ok(StatementType::CLOSE_CURSOR),
329            Statement::Flush => Ok(StatementType::FLUSH),
330            Statement::Wait => Ok(StatementType::WAIT),
331            Statement::Use { .. } => Ok(StatementType::USE),
332            Statement::Vacuum { .. } => Ok(StatementType::VACUUM),
333            _ => Err("unsupported statement type".to_owned()),
334        }
335    }
336
337    pub fn is_command(&self) -> bool {
338        matches!(
339            self,
340            StatementType::INSERT
341                | StatementType::DELETE
342                | StatementType::UPDATE
343                | StatementType::MOVE
344                | StatementType::COPY
345                | StatementType::FETCH
346                | StatementType::SELECT
347                | StatementType::INSERT_RETURNING
348                | StatementType::DELETE_RETURNING
349                | StatementType::UPDATE_RETURNING
350        )
351    }
352
353    pub fn is_dml(&self) -> bool {
354        matches!(
355            self,
356            StatementType::INSERT
357                | StatementType::DELETE
358                | StatementType::UPDATE
359                | StatementType::INSERT_RETURNING
360                | StatementType::DELETE_RETURNING
361                | StatementType::UPDATE_RETURNING
362        )
363    }
364
365    pub fn is_query(&self) -> bool {
366        matches!(
367            self,
368            StatementType::SELECT
369                | StatementType::EXPLAIN
370                | StatementType::SHOW_COMMAND
371                | StatementType::SHOW_VARIABLE
372                | StatementType::DESCRIBE
373                | StatementType::INSERT_RETURNING
374                | StatementType::DELETE_RETURNING
375                | StatementType::UPDATE_RETURNING
376                | StatementType::CANCEL_COMMAND
377                | StatementType::FETCH_CURSOR
378        )
379    }
380
381    pub fn is_returning(&self) -> bool {
382        matches!(
383            self,
384            StatementType::INSERT_RETURNING
385                | StatementType::DELETE_RETURNING
386                | StatementType::UPDATE_RETURNING
387        )
388    }
389}
390
391impl<VS> PgResponse<VS>
392where
393    VS: ValuesStream,
394{
395    pub fn builder(stmt_type: StatementType) -> PgResponseBuilder<VS> {
396        PgResponseBuilder::empty(stmt_type)
397    }
398
399    pub fn empty_result(stmt_type: StatementType) -> Self {
400        PgResponseBuilder::empty(stmt_type).into()
401    }
402
403    pub fn stmt_type(&self) -> StatementType {
404        self.stmt_type
405    }
406
407    pub fn notices(&self) -> &[String] {
408        &self.notices
409    }
410
411    pub fn status(&self) -> &ParameterStatus {
412        &self.status
413    }
414
415    pub fn affected_rows_cnt(&self) -> Option<i32> {
416        self.row_cnt
417    }
418
419    pub fn row_cnt_format(&self) -> Option<Format> {
420        self.row_cnt_format
421    }
422
423    pub fn is_query(&self) -> bool {
424        self.stmt_type.is_query()
425    }
426
427    pub fn is_empty(&self) -> bool {
428        self.stmt_type == StatementType::EMPTY
429    }
430
431    pub fn row_desc(&self) -> Vec<PgFieldDescriptor> {
432        self.row_desc.clone()
433    }
434
435    pub fn values_stream(&mut self) -> &mut VS {
436        self.values_stream.as_mut().expect("no values stream")
437    }
438
439    /// Run the callback if there is one.
440    ///
441    /// This should only be called after the values stream has been exhausted. Multiple calls to
442    /// this function will be no-ops.
443    pub async fn run_callback(&mut self) -> Result<(), PsqlError> {
444        // Check if the stream is exhausted.
445        if let Some(values_stream) = &mut self.values_stream {
446            assert!(values_stream.next().await.is_none());
447        }
448
449        if let Some(callback) = self.callback.take() {
450            callback.await.map_err(PsqlError::SimpleQueryError)?;
451        }
452        Ok(())
453    }
454}