1use std::fmt::Formatter;
16use std::pin::Pin;
17
18use futures::{Future, FutureExt, Stream, StreamExt};
19use risingwave_sqlparser::ast::Statement;
20
21use crate::error::PsqlError;
22use crate::pg_field_descriptor::PgFieldDescriptor;
23use crate::pg_protocol::ParameterStatus;
24use crate::pg_server::BoxedError;
25use crate::types::{Format, Row};
26
/// A batch of result rows.
pub type RowSet = Vec<Row>;
/// A batch of result rows, or the boxed error that prevented producing it.
pub type RowSetResult = Result<RowSet, BoxedError>;

/// Stream of row batches backing a response's values (see `PgResponse::values_stream`).
///
/// `Unpin` lets it be polled through `StreamExt::next` without pinning it first; `Send`
/// lets it be moved between threads.
pub trait ValuesStream = Stream<Item = RowSetResult> + Unpin + Send;
31
/// The kind of SQL statement a response belongs to.
///
/// Variant names are intentionally upper-case: [`Display`](std::fmt::Display) reuses the
/// `Debug` spelling, so `StatementType::CREATE_TABLE` prints as `CREATE_TABLE`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[expect(non_camel_case_types, clippy::upper_case_acronyms)]
pub enum StatementType {
    // Data manipulation, each with a `RETURNING` counterpart.
    INSERT,
    INSERT_RETURNING,
    DELETE,
    DELETE_RETURNING,
    UPDATE,
    UPDATE_RETURNING,
    // Queries, cursor movement, COPY and EXPLAIN.
    SELECT,
    MOVE,
    FETCH,
    COPY,
    EXPLAIN,
    CLOSE_CURSOR,
    // CREATE statements.
    CREATE_TABLE,
    CREATE_MATERIALIZED_VIEW,
    CREATE_VIEW,
    CREATE_SOURCE,
    CREATE_SINK,
    CREATE_SUBSCRIPTION,
    CREATE_DATABASE,
    CREATE_SCHEMA,
    CREATE_USER,
    CREATE_INDEX,
    CREATE_AGGREGATE,
    CREATE_FUNCTION,
    CREATE_CONNECTION,
    CREATE_SECRET,
    // Assorted utility statements.
    COMMENT,
    DECLARE_CURSOR,
    DESCRIBE,
    GRANT_PRIVILEGE,
    DISCARD,
    // DROP statements.
    DROP_TABLE,
    DROP_MATERIALIZED_VIEW,
    DROP_VIEW,
    DROP_INDEX,
    DROP_FUNCTION,
    DROP_AGGREGATE,
    DROP_SOURCE,
    DROP_SINK,
    DROP_SUBSCRIPTION,
    DROP_SCHEMA,
    DROP_DATABASE,
    DROP_USER,
    DROP_CONNECTION,
    DROP_SECRET,
    // ALTER statements.
    ALTER_DATABASE,
    ALTER_DEFAULT_PRIVILEGES,
    ALTER_SCHEMA,
    ALTER_INDEX,
    ALTER_VIEW,
    ALTER_TABLE,
    ALTER_MATERIALIZED_VIEW,
    ALTER_SINK,
    ALTER_SUBSCRIPTION,
    ALTER_SOURCE,
    ALTER_FUNCTION,
    ALTER_CONNECTION,
    ALTER_SYSTEM,
    ALTER_SECRET,
    // Privileges, session variables, transactions and other commands.
    REVOKE_PRIVILEGE,
    ORDER_BY,
    SET_VARIABLE,
    SHOW_VARIABLE,
    SHOW_COMMAND,
    START_TRANSACTION,
    UPDATE_USER,
    ABORT,
    FLUSH,
    REFRESH_TABLE,
    OTHER,
    EMPTY,
    BEGIN,
    COMMIT,
    ROLLBACK,
    SET_TRANSACTION,
    CANCEL_COMMAND,
    FETCH_CURSOR,
    WAIT,
    KILL,
    RECOVER,
    USE,
    PREPARE,
    DEALLOCATE,
    VACUUM,
}

impl std::fmt::Display for StatementType {
    /// Writes the variant name exactly as its derived `Debug` form does.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(self, f)
    }
}
129
/// A `Send` future attached to a response and awaited by `PgResponse::run_callback`,
/// resolving to `Ok(())` or a [`BoxedError`].
pub trait Callback = Future<Output = Result<(), BoxedError>> + Send;

/// A [`Callback`] boxed and pinned behind a trait object so a response can store it.
pub type BoxedCallback = Pin<Box<dyn Callback>>;
133
/// Response for a single statement: its type, an optional row count, notices, an optional
/// stream of result rows with their field descriptors, an optional callback, and a
/// parameter status.
///
/// Built through [`PgResponseBuilder`] (e.g. via `PgResponse::builder`) and finished with
/// the `From<PgResponseBuilder<VS>>` conversion.
pub struct PgResponse<VS> {
    /// Kind of statement this response belongs to.
    stmt_type: StatementType,
    /// Row count exposed through `affected_rows_cnt`. `PgResponseBuilder::empty` leaves it
    /// `None` for query-like statement types and `Some(0)` otherwise.
    row_cnt: Option<i32>,
    /// Optional [`Format`] exposed through `row_cnt_format`.
    /// NOTE(review): presumably the encoding used when reporting `row_cnt` — confirm.
    row_cnt_format: Option<Format>,
    /// Notice messages added with `PgResponseBuilder::notice`, in insertion order.
    notices: Vec<String>,
    /// Stream of row batches; `Some` only when the builder's `values` was called.
    values_stream: Option<VS>,
    /// Future awaited at most once by `run_callback`, after the values stream is drained.
    callback: Option<BoxedCallback>,
    /// Field descriptors supplied together with the values stream.
    row_desc: Vec<PgFieldDescriptor>,
    /// Parameter status; `Default::default()` unless set on the builder.
    status: ParameterStatus,
}
147
/// Builder for [`PgResponse`].
///
/// Start with `PgResponseBuilder::empty` (or `PgResponse::builder`), chain the setters, then
/// convert into a [`PgResponse`] with `From`/`into`. Fields mirror those of [`PgResponse`].
pub struct PgResponseBuilder<VS> {
    /// Kind of statement being answered.
    stmt_type: StatementType,
    /// Row count; `empty` initialises it to `None` for query-like types, `Some(0)` otherwise.
    row_cnt: Option<i32>,
    /// Optional [`Format`] set through `row_cnt_format_opt`.
    /// NOTE(review): presumably the encoding used when reporting `row_cnt` — confirm.
    row_cnt_format: Option<Format>,
    /// Notices accumulated by `notice`, in call order.
    notices: Vec<String>,
    /// Stream of row batches, set by `values`.
    values_stream: Option<VS>,
    /// Boxed future set by `callback`.
    callback: Option<BoxedCallback>,
    /// Field descriptors, set together with the values stream by `values`.
    row_desc: Vec<PgFieldDescriptor>,
    /// Parameter status, set by `status`.
    status: ParameterStatus,
}
161
162impl<VS> From<PgResponseBuilder<VS>> for PgResponse<VS> {
163 fn from(builder: PgResponseBuilder<VS>) -> Self {
164 Self {
165 stmt_type: builder.stmt_type,
166 row_cnt: builder.row_cnt,
167 row_cnt_format: builder.row_cnt_format,
168 notices: builder.notices,
169 values_stream: builder.values_stream,
170 callback: builder.callback,
171 row_desc: builder.row_desc,
172 status: builder.status,
173 }
174 }
175}
176
177impl<VS> PgResponseBuilder<VS> {
178 pub fn empty(stmt_type: StatementType) -> Self {
179 let row_cnt = if stmt_type.is_query() { None } else { Some(0) };
180 Self {
181 stmt_type,
182 row_cnt,
183 row_cnt_format: None,
184 notices: vec![],
185 values_stream: None,
186 callback: None,
187 row_desc: vec![],
188 status: Default::default(),
189 }
190 }
191
192 pub fn row_cnt(self, row_cnt: i32) -> Self {
193 Self {
194 row_cnt: Some(row_cnt),
195 ..self
196 }
197 }
198
199 pub fn row_cnt_opt(self, row_cnt: Option<i32>) -> Self {
200 Self { row_cnt, ..self }
201 }
202
203 pub fn row_cnt_format_opt(self, row_cnt_format: Option<Format>) -> Self {
204 Self {
205 row_cnt_format,
206 ..self
207 }
208 }
209
210 pub fn values(self, values_stream: VS, row_desc: Vec<PgFieldDescriptor>) -> Self {
211 Self {
212 values_stream: Some(values_stream),
213 row_desc,
214 ..self
215 }
216 }
217
218 pub fn callback(self, callback: impl Callback + 'static) -> Self {
219 Self {
220 callback: Some(callback.boxed()),
221 ..self
222 }
223 }
224
225 pub fn notice(self, notice: impl ToString) -> Self {
226 let mut notices = self.notices;
227 notices.push(notice.to_string());
228 Self { notices, ..self }
229 }
230
231 pub fn status(self, status: ParameterStatus) -> Self {
232 Self { status, ..self }
233 }
234}
235
236impl<VS> std::fmt::Debug for PgResponse<VS> {
237 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
238 f.debug_struct("PgResponse")
239 .field("stmt_type", &self.stmt_type)
240 .field("row_cnt", &self.row_cnt)
241 .field("notices", &self.notices)
242 .field("row_desc", &self.row_desc)
243 .finish()
244 }
245}
246
247impl StatementType {
248 pub fn infer_from_statement(stmt: &Statement) -> Result<Self, String> {
249 match stmt {
250 Statement::Query(_) => Ok(StatementType::SELECT),
251 Statement::Insert { returning, .. } => {
252 if returning.is_empty() {
253 Ok(StatementType::INSERT)
254 } else {
255 Ok(StatementType::INSERT_RETURNING)
256 }
257 }
258 Statement::Delete { returning, .. } => {
259 if returning.is_empty() {
260 Ok(StatementType::DELETE)
261 } else {
262 Ok(StatementType::DELETE_RETURNING)
263 }
264 }
265 Statement::Update { returning, .. } => {
266 if returning.is_empty() {
267 Ok(StatementType::UPDATE)
268 } else {
269 Ok(StatementType::UPDATE_RETURNING)
270 }
271 }
272 Statement::Copy { .. } => Ok(StatementType::COPY),
273 Statement::CreateTable { .. } => Ok(StatementType::CREATE_TABLE),
274 Statement::CreateIndex { .. } => Ok(StatementType::CREATE_INDEX),
275 Statement::CreateSchema { .. } => Ok(StatementType::CREATE_SCHEMA),
276 Statement::CreateSource { .. } => Ok(StatementType::CREATE_SOURCE),
277 Statement::CreateSink { .. } => Ok(StatementType::CREATE_SINK),
278 Statement::CreateFunction { .. } => Ok(StatementType::CREATE_FUNCTION),
279 Statement::CreateDatabase { .. } => Ok(StatementType::CREATE_DATABASE),
280 Statement::CreateUser { .. } => Ok(StatementType::CREATE_USER),
281 Statement::CreateView { materialized, .. } => {
282 if *materialized {
283 Ok(StatementType::CREATE_MATERIALIZED_VIEW)
284 } else {
285 Ok(StatementType::CREATE_VIEW)
286 }
287 }
288 Statement::AlterTable { .. } => Ok(StatementType::ALTER_TABLE),
289 Statement::AlterSystem { .. } => Ok(StatementType::ALTER_SYSTEM),
290 Statement::DropFunction { .. } => Ok(StatementType::DROP_FUNCTION),
291 Statement::Discard(..) => Ok(StatementType::DISCARD),
292 Statement::SetVariable { .. } => Ok(StatementType::SET_VARIABLE),
293 Statement::ShowVariable { .. } => Ok(StatementType::SHOW_VARIABLE),
294 Statement::StartTransaction { .. } => Ok(StatementType::START_TRANSACTION),
295 Statement::Begin { .. } => Ok(StatementType::BEGIN),
296 Statement::Abort => Ok(StatementType::ABORT),
297 Statement::Commit { .. } => Ok(StatementType::COMMIT),
298 Statement::Rollback { .. } => Ok(StatementType::ROLLBACK),
299 Statement::Grant { .. } => Ok(StatementType::GRANT_PRIVILEGE),
300 Statement::Revoke { .. } => Ok(StatementType::REVOKE_PRIVILEGE),
301 Statement::Describe { .. } => Ok(StatementType::DESCRIBE),
302 Statement::ShowCreateObject { .. } | Statement::ShowObjects { .. } => {
303 Ok(StatementType::SHOW_COMMAND)
304 }
305 Statement::Drop(stmt) => match stmt.object_type {
306 risingwave_sqlparser::ast::ObjectType::Table => Ok(StatementType::DROP_TABLE),
307 risingwave_sqlparser::ast::ObjectType::View => Ok(StatementType::DROP_VIEW),
308 risingwave_sqlparser::ast::ObjectType::MaterializedView => {
309 Ok(StatementType::DROP_MATERIALIZED_VIEW)
310 }
311 risingwave_sqlparser::ast::ObjectType::Index => Ok(StatementType::DROP_INDEX),
312 risingwave_sqlparser::ast::ObjectType::Schema => Ok(StatementType::DROP_SCHEMA),
313 risingwave_sqlparser::ast::ObjectType::Source => Ok(StatementType::DROP_SOURCE),
314 risingwave_sqlparser::ast::ObjectType::Sink => Ok(StatementType::DROP_SINK),
315 risingwave_sqlparser::ast::ObjectType::Database => Ok(StatementType::DROP_DATABASE),
316 risingwave_sqlparser::ast::ObjectType::User => Ok(StatementType::DROP_USER),
317 risingwave_sqlparser::ast::ObjectType::Connection => {
318 Ok(StatementType::DROP_CONNECTION)
319 }
320 risingwave_sqlparser::ast::ObjectType::Secret => Ok(StatementType::DROP_SECRET),
321 risingwave_sqlparser::ast::ObjectType::Subscription => {
322 Ok(StatementType::DROP_SUBSCRIPTION)
323 }
324 },
325 Statement::Explain { .. } => Ok(StatementType::EXPLAIN),
326 Statement::DeclareCursor { .. } => Ok(StatementType::DECLARE_CURSOR),
327 Statement::FetchCursor { .. } => Ok(StatementType::FETCH_CURSOR),
328 Statement::CloseCursor { .. } => Ok(StatementType::CLOSE_CURSOR),
329 Statement::Flush => Ok(StatementType::FLUSH),
330 Statement::Wait => Ok(StatementType::WAIT),
331 Statement::Use { .. } => Ok(StatementType::USE),
332 Statement::Vacuum { .. } => Ok(StatementType::VACUUM),
333 _ => Err("unsupported statement type".to_owned()),
334 }
335 }
336
337 pub fn is_command(&self) -> bool {
338 matches!(
339 self,
340 StatementType::INSERT
341 | StatementType::DELETE
342 | StatementType::UPDATE
343 | StatementType::MOVE
344 | StatementType::COPY
345 | StatementType::FETCH
346 | StatementType::SELECT
347 | StatementType::INSERT_RETURNING
348 | StatementType::DELETE_RETURNING
349 | StatementType::UPDATE_RETURNING
350 )
351 }
352
353 pub fn is_dml(&self) -> bool {
354 matches!(
355 self,
356 StatementType::INSERT
357 | StatementType::DELETE
358 | StatementType::UPDATE
359 | StatementType::INSERT_RETURNING
360 | StatementType::DELETE_RETURNING
361 | StatementType::UPDATE_RETURNING
362 )
363 }
364
365 pub fn is_query(&self) -> bool {
366 matches!(
367 self,
368 StatementType::SELECT
369 | StatementType::EXPLAIN
370 | StatementType::SHOW_COMMAND
371 | StatementType::SHOW_VARIABLE
372 | StatementType::DESCRIBE
373 | StatementType::INSERT_RETURNING
374 | StatementType::DELETE_RETURNING
375 | StatementType::UPDATE_RETURNING
376 | StatementType::CANCEL_COMMAND
377 | StatementType::FETCH_CURSOR
378 )
379 }
380
381 pub fn is_returning(&self) -> bool {
382 matches!(
383 self,
384 StatementType::INSERT_RETURNING
385 | StatementType::DELETE_RETURNING
386 | StatementType::UPDATE_RETURNING
387 )
388 }
389}
390
391impl<VS> PgResponse<VS>
392where
393 VS: ValuesStream,
394{
395 pub fn builder(stmt_type: StatementType) -> PgResponseBuilder<VS> {
396 PgResponseBuilder::empty(stmt_type)
397 }
398
399 pub fn empty_result(stmt_type: StatementType) -> Self {
400 PgResponseBuilder::empty(stmt_type).into()
401 }
402
403 pub fn stmt_type(&self) -> StatementType {
404 self.stmt_type
405 }
406
407 pub fn notices(&self) -> &[String] {
408 &self.notices
409 }
410
411 pub fn status(&self) -> &ParameterStatus {
412 &self.status
413 }
414
415 pub fn affected_rows_cnt(&self) -> Option<i32> {
416 self.row_cnt
417 }
418
419 pub fn row_cnt_format(&self) -> Option<Format> {
420 self.row_cnt_format
421 }
422
423 pub fn is_query(&self) -> bool {
424 self.stmt_type.is_query()
425 }
426
427 pub fn is_empty(&self) -> bool {
428 self.stmt_type == StatementType::EMPTY
429 }
430
431 pub fn row_desc(&self) -> Vec<PgFieldDescriptor> {
432 self.row_desc.clone()
433 }
434
435 pub fn values_stream(&mut self) -> &mut VS {
436 self.values_stream.as_mut().expect("no values stream")
437 }
438
439 pub async fn run_callback(&mut self) -> Result<(), PsqlError> {
444 if let Some(values_stream) = &mut self.values_stream {
446 assert!(values_stream.next().await.is_none());
447 }
448
449 if let Some(callback) = self.callback.take() {
450 callback.await.map_err(PsqlError::SimpleQueryError)?;
451 }
452 Ok(())
453 }
454}