pgwire/
pg_response.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Formatter;
16use std::pin::Pin;
17
18use futures::{Future, FutureExt, Stream, StreamExt};
19use risingwave_sqlparser::ast::Statement;
20
21use crate::error::PsqlError;
22use crate::pg_field_descriptor::PgFieldDescriptor;
23use crate::pg_protocol::ParameterStatus;
24use crate::pg_server::BoxedError;
25use crate::types::{Format, Row};
26
27pub type RowSet = Vec<Row>;
28pub type RowSetResult = Result<RowSet, BoxedError>;
29
30pub trait ValuesStream = Stream<Item = RowSetResult> + Unpin + Send;
31
/// The kind of SQL statement a response corresponds to.
///
/// Variant names spell the SQL command in upper case with underscores (for example
/// `CREATE_MATERIALIZED_VIEW`), hence the suppressed naming lints. Formatting a value with
/// `{}` or `{:?}` prints exactly the variant name.
///
/// NOTE(review): variant order defines the implicit discriminants; do not reorder without
/// checking for `as` casts elsewhere.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[expect(non_camel_case_types, clippy::upper_case_acronyms)]
pub enum StatementType {
    // DML. The `*_RETURNING` variants are chosen when the statement has a non-empty
    // `RETURNING` list, and they are classified as queries (they return rows).
    INSERT,
    INSERT_RETURNING,
    DELETE,
    DELETE_RETURNING,
    UPDATE,
    UPDATE_RETURNING,
    SELECT,
    MOVE,
    FETCH,
    COPY,
    EXPLAIN,
    CLOSE_CURSOR,
    // CREATE statements.
    CREATE_TABLE,
    CREATE_MATERIALIZED_VIEW,
    CREATE_VIEW,
    CREATE_SOURCE,
    CREATE_SINK,
    CREATE_SUBSCRIPTION,
    CREATE_DATABASE,
    CREATE_SCHEMA,
    CREATE_USER,
    CREATE_INDEX,
    CREATE_AGGREGATE,
    CREATE_FUNCTION,
    CREATE_CONNECTION,
    CREATE_SECRET,
    COMMENT,
    DECLARE_CURSOR,
    DESCRIBE,
    GRANT_PRIVILEGE,
    DISCARD,
    // DROP statements, one per dropped object type.
    DROP_TABLE,
    DROP_MATERIALIZED_VIEW,
    DROP_VIEW,
    DROP_INDEX,
    DROP_FUNCTION,
    DROP_AGGREGATE,
    DROP_SOURCE,
    DROP_SINK,
    DROP_SUBSCRIPTION,
    DROP_SCHEMA,
    DROP_DATABASE,
    DROP_USER,
    DROP_CONNECTION,
    DROP_SECRET,
    // ALTER statements.
    ALTER_DATABASE,
    ALTER_SCHEMA,
    ALTER_INDEX,
    ALTER_VIEW,
    ALTER_TABLE,
    ALTER_MATERIALIZED_VIEW,
    ALTER_SINK,
    ALTER_SUBSCRIPTION,
    ALTER_SOURCE,
    ALTER_FUNCTION,
    ALTER_CONNECTION,
    ALTER_SYSTEM,
    ALTER_SECRET,
    REVOKE_PRIVILEGE,
    // ORDER_BY exists because Calcite's unvalidated AST has `SqlKind.ORDER_BY`. Note that
    // `StatementType` is not designed to map one-to-one onto `SqlKind`.
    ORDER_BY,
    SET_VARIABLE,
    SHOW_VARIABLE,
    SHOW_COMMAND,
    START_TRANSACTION,
    UPDATE_USER,
    ABORT,
    FLUSH,
    OTHER,
    // EMPTY is used when the query statement is empty (e.g. ";").
    EMPTY,
    // Transaction control.
    BEGIN,
    COMMIT,
    ROLLBACK,
    SET_TRANSACTION,
    CANCEL_COMMAND,
    FETCH_CURSOR,
    WAIT,
    KILL,
    RECOVER,
    USE,
}
118
119impl std::fmt::Display for StatementType {
120    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
121        write!(f, "{:?}", self)
122    }
123}
124
125pub trait Callback = Future<Output = Result<(), BoxedError>> + Send;
126
127pub type BoxedCallback = Pin<Box<dyn Callback>>;
128
129pub struct PgResponse<VS> {
130    stmt_type: StatementType,
131    // row count of affected row. Used for INSERT, UPDATE, DELETE, COPY, and other statements that
132    // don't return rows.
133    row_cnt: Option<i32>,
134    // Used for INSERT, UPDATE, DELETE to specify the format of the affected row count.
135    row_cnt_format: Option<Format>,
136    notices: Vec<String>,
137    values_stream: Option<VS>,
138    callback: Option<BoxedCallback>,
139    row_desc: Vec<PgFieldDescriptor>,
140    status: ParameterStatus,
141}
142
143pub struct PgResponseBuilder<VS> {
144    stmt_type: StatementType,
145    // row count of affected row. Used for INSERT, UPDATE, DELETE, COPY, and other statements that
146    // don't return rows.
147    row_cnt: Option<i32>,
148    // Used for INSERT, UPDATE, DELETE to specify the format of the affected row count.
149    row_cnt_format: Option<Format>,
150    notices: Vec<String>,
151    values_stream: Option<VS>,
152    callback: Option<BoxedCallback>,
153    row_desc: Vec<PgFieldDescriptor>,
154    status: ParameterStatus,
155}
156
157impl<VS> From<PgResponseBuilder<VS>> for PgResponse<VS> {
158    fn from(builder: PgResponseBuilder<VS>) -> Self {
159        Self {
160            stmt_type: builder.stmt_type,
161            row_cnt: builder.row_cnt,
162            row_cnt_format: builder.row_cnt_format,
163            notices: builder.notices,
164            values_stream: builder.values_stream,
165            callback: builder.callback,
166            row_desc: builder.row_desc,
167            status: builder.status,
168        }
169    }
170}
171
172impl<VS> PgResponseBuilder<VS> {
173    pub fn empty(stmt_type: StatementType) -> Self {
174        let row_cnt = if stmt_type.is_query() { None } else { Some(0) };
175        Self {
176            stmt_type,
177            row_cnt,
178            row_cnt_format: None,
179            notices: vec![],
180            values_stream: None,
181            callback: None,
182            row_desc: vec![],
183            status: Default::default(),
184        }
185    }
186
187    pub fn row_cnt(self, row_cnt: i32) -> Self {
188        Self {
189            row_cnt: Some(row_cnt),
190            ..self
191        }
192    }
193
194    pub fn row_cnt_opt(self, row_cnt: Option<i32>) -> Self {
195        Self { row_cnt, ..self }
196    }
197
198    pub fn row_cnt_format_opt(self, row_cnt_format: Option<Format>) -> Self {
199        Self {
200            row_cnt_format,
201            ..self
202        }
203    }
204
205    pub fn values(self, values_stream: VS, row_desc: Vec<PgFieldDescriptor>) -> Self {
206        Self {
207            values_stream: Some(values_stream),
208            row_desc,
209            ..self
210        }
211    }
212
213    pub fn callback(self, callback: impl Callback + 'static) -> Self {
214        Self {
215            callback: Some(callback.boxed()),
216            ..self
217        }
218    }
219
220    pub fn notice(self, notice: impl ToString) -> Self {
221        let mut notices = self.notices;
222        notices.push(notice.to_string());
223        Self { notices, ..self }
224    }
225
226    pub fn status(self, status: ParameterStatus) -> Self {
227        Self { status, ..self }
228    }
229}
230
231impl<VS> std::fmt::Debug for PgResponse<VS> {
232    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
233        f.debug_struct("PgResponse")
234            .field("stmt_type", &self.stmt_type)
235            .field("row_cnt", &self.row_cnt)
236            .field("notices", &self.notices)
237            .field("row_desc", &self.row_desc)
238            .finish()
239    }
240}
241
242impl StatementType {
243    pub fn infer_from_statement(stmt: &Statement) -> Result<Self, String> {
244        match stmt {
245            Statement::Query(_) => Ok(StatementType::SELECT),
246            Statement::Insert { returning, .. } => {
247                if returning.is_empty() {
248                    Ok(StatementType::INSERT)
249                } else {
250                    Ok(StatementType::INSERT_RETURNING)
251                }
252            }
253            Statement::Delete { returning, .. } => {
254                if returning.is_empty() {
255                    Ok(StatementType::DELETE)
256                } else {
257                    Ok(StatementType::DELETE_RETURNING)
258                }
259            }
260            Statement::Update { returning, .. } => {
261                if returning.is_empty() {
262                    Ok(StatementType::UPDATE)
263                } else {
264                    Ok(StatementType::UPDATE_RETURNING)
265                }
266            }
267            Statement::Copy { .. } => Ok(StatementType::COPY),
268            Statement::CreateTable { .. } => Ok(StatementType::CREATE_TABLE),
269            Statement::CreateIndex { .. } => Ok(StatementType::CREATE_INDEX),
270            Statement::CreateSchema { .. } => Ok(StatementType::CREATE_SCHEMA),
271            Statement::CreateSource { .. } => Ok(StatementType::CREATE_SOURCE),
272            Statement::CreateSink { .. } => Ok(StatementType::CREATE_SINK),
273            Statement::CreateFunction { .. } => Ok(StatementType::CREATE_FUNCTION),
274            Statement::CreateDatabase { .. } => Ok(StatementType::CREATE_DATABASE),
275            Statement::CreateUser { .. } => Ok(StatementType::CREATE_USER),
276            Statement::CreateView { materialized, .. } => {
277                if *materialized {
278                    Ok(StatementType::CREATE_MATERIALIZED_VIEW)
279                } else {
280                    Ok(StatementType::CREATE_VIEW)
281                }
282            }
283            Statement::AlterTable { .. } => Ok(StatementType::ALTER_TABLE),
284            Statement::AlterSystem { .. } => Ok(StatementType::ALTER_SYSTEM),
285            Statement::DropFunction { .. } => Ok(StatementType::DROP_FUNCTION),
286            Statement::Discard(..) => Ok(StatementType::DISCARD),
287            Statement::SetVariable { .. } => Ok(StatementType::SET_VARIABLE),
288            Statement::ShowVariable { .. } => Ok(StatementType::SHOW_VARIABLE),
289            Statement::StartTransaction { .. } => Ok(StatementType::START_TRANSACTION),
290            Statement::Begin { .. } => Ok(StatementType::BEGIN),
291            Statement::Abort => Ok(StatementType::ABORT),
292            Statement::Commit { .. } => Ok(StatementType::COMMIT),
293            Statement::Rollback { .. } => Ok(StatementType::ROLLBACK),
294            Statement::Grant { .. } => Ok(StatementType::GRANT_PRIVILEGE),
295            Statement::Revoke { .. } => Ok(StatementType::REVOKE_PRIVILEGE),
296            Statement::Describe { .. } => Ok(StatementType::DESCRIBE),
297            Statement::ShowCreateObject { .. } | Statement::ShowObjects { .. } => {
298                Ok(StatementType::SHOW_COMMAND)
299            }
300            Statement::Drop(stmt) => match stmt.object_type {
301                risingwave_sqlparser::ast::ObjectType::Table => Ok(StatementType::DROP_TABLE),
302                risingwave_sqlparser::ast::ObjectType::View => Ok(StatementType::DROP_VIEW),
303                risingwave_sqlparser::ast::ObjectType::MaterializedView => {
304                    Ok(StatementType::DROP_MATERIALIZED_VIEW)
305                }
306                risingwave_sqlparser::ast::ObjectType::Index => Ok(StatementType::DROP_INDEX),
307                risingwave_sqlparser::ast::ObjectType::Schema => Ok(StatementType::DROP_SCHEMA),
308                risingwave_sqlparser::ast::ObjectType::Source => Ok(StatementType::DROP_SOURCE),
309                risingwave_sqlparser::ast::ObjectType::Sink => Ok(StatementType::DROP_SINK),
310                risingwave_sqlparser::ast::ObjectType::Database => Ok(StatementType::DROP_DATABASE),
311                risingwave_sqlparser::ast::ObjectType::User => Ok(StatementType::DROP_USER),
312                risingwave_sqlparser::ast::ObjectType::Connection => {
313                    Ok(StatementType::DROP_CONNECTION)
314                }
315                risingwave_sqlparser::ast::ObjectType::Secret => Ok(StatementType::DROP_SECRET),
316                risingwave_sqlparser::ast::ObjectType::Subscription => {
317                    Ok(StatementType::DROP_SUBSCRIPTION)
318                }
319            },
320            Statement::Explain { .. } => Ok(StatementType::EXPLAIN),
321            Statement::DeclareCursor { .. } => Ok(StatementType::DECLARE_CURSOR),
322            Statement::FetchCursor { .. } => Ok(StatementType::FETCH_CURSOR),
323            Statement::CloseCursor { .. } => Ok(StatementType::CLOSE_CURSOR),
324            Statement::Flush => Ok(StatementType::FLUSH),
325            Statement::Wait => Ok(StatementType::WAIT),
326            Statement::Use { .. } => Ok(StatementType::USE),
327            _ => Err("unsupported statement type".to_owned()),
328        }
329    }
330
331    pub fn is_command(&self) -> bool {
332        matches!(
333            self,
334            StatementType::INSERT
335                | StatementType::DELETE
336                | StatementType::UPDATE
337                | StatementType::MOVE
338                | StatementType::COPY
339                | StatementType::FETCH
340                | StatementType::SELECT
341                | StatementType::INSERT_RETURNING
342                | StatementType::DELETE_RETURNING
343                | StatementType::UPDATE_RETURNING
344        )
345    }
346
347    pub fn is_dml(&self) -> bool {
348        matches!(
349            self,
350            StatementType::INSERT
351                | StatementType::DELETE
352                | StatementType::UPDATE
353                | StatementType::INSERT_RETURNING
354                | StatementType::DELETE_RETURNING
355                | StatementType::UPDATE_RETURNING
356        )
357    }
358
359    pub fn is_query(&self) -> bool {
360        matches!(
361            self,
362            StatementType::SELECT
363                | StatementType::EXPLAIN
364                | StatementType::SHOW_COMMAND
365                | StatementType::SHOW_VARIABLE
366                | StatementType::DESCRIBE
367                | StatementType::INSERT_RETURNING
368                | StatementType::DELETE_RETURNING
369                | StatementType::UPDATE_RETURNING
370                | StatementType::CANCEL_COMMAND
371                | StatementType::FETCH_CURSOR
372        )
373    }
374
375    pub fn is_returning(&self) -> bool {
376        matches!(
377            self,
378            StatementType::INSERT_RETURNING
379                | StatementType::DELETE_RETURNING
380                | StatementType::UPDATE_RETURNING
381        )
382    }
383}
384
385impl<VS> PgResponse<VS>
386where
387    VS: ValuesStream,
388{
389    pub fn builder(stmt_type: StatementType) -> PgResponseBuilder<VS> {
390        PgResponseBuilder::empty(stmt_type)
391    }
392
393    pub fn empty_result(stmt_type: StatementType) -> Self {
394        PgResponseBuilder::empty(stmt_type).into()
395    }
396
397    pub fn stmt_type(&self) -> StatementType {
398        self.stmt_type
399    }
400
401    pub fn notices(&self) -> &[String] {
402        &self.notices
403    }
404
405    pub fn status(&self) -> &ParameterStatus {
406        &self.status
407    }
408
409    pub fn affected_rows_cnt(&self) -> Option<i32> {
410        self.row_cnt
411    }
412
413    pub fn row_cnt_format(&self) -> Option<Format> {
414        self.row_cnt_format
415    }
416
417    pub fn is_query(&self) -> bool {
418        self.stmt_type.is_query()
419    }
420
421    pub fn is_empty(&self) -> bool {
422        self.stmt_type == StatementType::EMPTY
423    }
424
425    pub fn row_desc(&self) -> Vec<PgFieldDescriptor> {
426        self.row_desc.clone()
427    }
428
429    pub fn values_stream(&mut self) -> &mut VS {
430        self.values_stream.as_mut().expect("no values stream")
431    }
432
433    /// Run the callback if there is one.
434    ///
435    /// This should only be called after the values stream has been exhausted. Multiple calls to
436    /// this function will be no-ops.
437    pub async fn run_callback(&mut self) -> Result<(), PsqlError> {
438        // Check if the stream is exhausted.
439        if let Some(values_stream) = &mut self.values_stream {
440            assert!(values_stream.next().await.is_none());
441        }
442
443        if let Some(callback) = self.callback.take() {
444            callback.await.map_err(PsqlError::SimpleQueryError)?;
445        }
446        Ok(())
447    }
448}