1use std::fmt::Formatter;
16use std::pin::Pin;
17
18use futures::{Future, FutureExt, Stream, StreamExt};
19use risingwave_sqlparser::ast::Statement;
20
21use crate::error::PsqlError;
22use crate::pg_field_descriptor::PgFieldDescriptor;
23use crate::pg_protocol::ParameterStatus;
24use crate::pg_server::BoxedError;
25use crate::types::{Format, Row};
26
27pub type RowSet = Vec<Row>;
28pub type RowSetResult = Result<RowSet, BoxedError>;
29
30pub trait ValuesStream = Stream<Item = RowSetResult> + Unpin + Send;
31
/// The kind of SQL statement a `PgResponse` answers.
///
/// Variant names are upper-case, SQL-style identifiers (hence the lint expectation
/// below). The `Display` impl prints exactly the `Debug` form, i.e. the variant name.
///
/// NOTE(review): many variants are not produced by `StatementType::infer_from_statement`;
/// presumably statement handlers elsewhere construct them directly — TODO confirm.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[expect(non_camel_case_types, clippy::upper_case_acronyms)]
pub enum StatementType {
    // DML. `infer_from_statement` picks a `*_RETURNING` variant when the
    // statement's RETURNING list is non-empty.
    INSERT,
    INSERT_RETURNING,
    DELETE,
    DELETE_RETURNING,
    UPDATE,
    UPDATE_RETURNING,
    SELECT,
    MOVE,
    // Counted by `is_command`; not to be confused with `FETCH_CURSOR`, which
    // `infer_from_statement` produces for `Statement::FetchCursor`.
    FETCH,
    COPY,
    EXPLAIN,
    CLOSE_CURSOR,
    CREATE_TABLE,
    CREATE_MATERIALIZED_VIEW,
    CREATE_VIEW,
    CREATE_SOURCE,
    CREATE_SINK,
    CREATE_SUBSCRIPTION,
    CREATE_DATABASE,
    CREATE_SCHEMA,
    CREATE_USER,
    CREATE_INDEX,
    CREATE_AGGREGATE,
    CREATE_FUNCTION,
    CREATE_CONNECTION,
    CREATE_SECRET,
    COMMENT,
    DECLARE_CURSOR,
    DESCRIBE,
    GRANT_PRIVILEGE,
    DISCARD,
    DROP_TABLE,
    DROP_MATERIALIZED_VIEW,
    DROP_VIEW,
    DROP_INDEX,
    DROP_FUNCTION,
    DROP_AGGREGATE,
    DROP_SOURCE,
    DROP_SINK,
    DROP_SUBSCRIPTION,
    DROP_SCHEMA,
    DROP_DATABASE,
    DROP_USER,
    DROP_CONNECTION,
    DROP_SECRET,
    ALTER_DATABASE,
    ALTER_DEFAULT_PRIVILEGES,
    ALTER_SCHEMA,
    ALTER_INDEX,
    ALTER_VIEW,
    ALTER_TABLE,
    ALTER_MATERIALIZED_VIEW,
    ALTER_SINK,
    ALTER_SUBSCRIPTION,
    ALTER_SOURCE,
    ALTER_FUNCTION,
    ALTER_CONNECTION,
    ALTER_SYSTEM,
    ALTER_SECRET,
    ALTER_FRAGMENT,
    REVOKE_PRIVILEGE,
    // NOTE(review): not produced by `infer_from_statement`; its origin is not
    // visible here — TODO confirm it is still used.
    ORDER_BY,
    SET_VARIABLE,
    SHOW_VARIABLE,
    SHOW_COMMAND,
    START_TRANSACTION,
    UPDATE_USER,
    ABORT,
    FLUSH,
    REFRESH_TABLE,
    OTHER,
    /// The variant `PgResponse::is_empty` checks for.
    /// NOTE(review): presumably used for an empty query string — confirm.
    EMPTY,
    // Transaction control.
    BEGIN,
    COMMIT,
    ROLLBACK,
    SET_TRANSACTION,
    CANCEL_COMMAND,
    FETCH_CURSOR,
    WAIT,
    KILL,
    RECOVER,
    USE,
    PREPARE,
    DEALLOCATE,
    VACUUM,
}
124
125impl std::fmt::Display for StatementType {
126 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
127 write!(f, "{:?}", self)
128 }
129}
130
131pub trait Callback = Future<Output = Result<(), BoxedError>> + Send;
132
133pub type BoxedCallback = Pin<Box<dyn Callback>>;
134
/// Response for one executed statement. It holds:
/// - the statement type and an optional affected-row count;
/// - notices;
/// - an optional stream of row batches with its row description;
/// - an optional completion callback and the parameter status.
///
/// Constructed from a [`PgResponseBuilder`] (via `From`) or with `PgResponse::empty_result`.
// Fields drop in declaration order, so `values_stream` is dropped before `callback`;
// keep that in mind before reordering.
pub struct PgResponse<VS> {
    /// Kind of statement this response answers.
    stmt_type: StatementType,
    /// Set only by `into_copy_query_to_stdout`; the conversion from a builder
    /// always initializes it to `false`.
    is_copy_query_to_stdout: bool,

    /// Count returned by `affected_rows_cnt()`. `PgResponseBuilder::empty` starts it
    /// as `None` for query-like statement types and `Some(0)` otherwise.
    row_cnt: Option<i32>,
    /// Optional `Format` exposed via `row_cnt_format()`.
    /// NOTE(review): presumably the encoding used when reporting `row_cnt` — confirm.
    row_cnt_format: Option<Format>,
    /// Notice messages accumulated through `PgResponseBuilder::notice`.
    notices: Vec<String>,
    /// Row batches, if any; `run_callback` asserts it is exhausted before the callback runs.
    values_stream: Option<VS>,
    /// Future awaited at most once, by `run_callback` (which `take`s it).
    callback: Option<BoxedCallback>,
    /// Row description attached together with `values_stream` by `PgResponseBuilder::values`.
    row_desc: Vec<PgFieldDescriptor>,
    /// Parameter status carried with the response; `ParameterStatus::default()` unless set.
    status: ParameterStatus,
}
150
/// Incrementally assembles a [`PgResponse`]; start with `PgResponseBuilder::empty`
/// and finish with `.into()`.
// Field order matches `PgResponse` (minus `is_copy_query_to_stdout`); fields drop in
// declaration order, so `values_stream` is dropped before `callback`.
pub struct PgResponseBuilder<VS> {
    /// Kind of statement the resulting response answers.
    stmt_type: StatementType,
    /// Affected-row count; `None` initially for query-like statement types, else `Some(0)`.
    row_cnt: Option<i32>,
    /// Optional `Format` for the row count (see `row_cnt_format_opt`).
    row_cnt_format: Option<Format>,
    /// Notices appended by `notice`.
    notices: Vec<String>,
    /// Row batches attached by `values`.
    values_stream: Option<VS>,
    /// Boxed future attached by `callback`; a later call replaces an earlier one.
    callback: Option<BoxedCallback>,
    /// Row description attached together with `values_stream` by `values`.
    row_desc: Vec<PgFieldDescriptor>,
    /// Parameter status; `ParameterStatus::default()` unless replaced by `status`.
    status: ParameterStatus,
}
164
165impl<VS> From<PgResponseBuilder<VS>> for PgResponse<VS> {
166 fn from(builder: PgResponseBuilder<VS>) -> Self {
167 Self {
168 stmt_type: builder.stmt_type,
169 is_copy_query_to_stdout: false, row_cnt: builder.row_cnt,
171 row_cnt_format: builder.row_cnt_format,
172 notices: builder.notices,
173 values_stream: builder.values_stream,
174 callback: builder.callback,
175 row_desc: builder.row_desc,
176 status: builder.status,
177 }
178 }
179}
180
181impl<VS> PgResponseBuilder<VS> {
182 pub fn empty(stmt_type: StatementType) -> Self {
183 let row_cnt = if stmt_type.is_query() { None } else { Some(0) };
184 Self {
185 stmt_type,
186 row_cnt,
187 row_cnt_format: None,
188 notices: vec![],
189 values_stream: None,
190 callback: None,
191 row_desc: vec![],
192 status: Default::default(),
193 }
194 }
195
196 pub fn row_cnt(self, row_cnt: i32) -> Self {
197 Self {
198 row_cnt: Some(row_cnt),
199 ..self
200 }
201 }
202
203 pub fn row_cnt_opt(self, row_cnt: Option<i32>) -> Self {
204 Self { row_cnt, ..self }
205 }
206
207 pub fn row_cnt_format_opt(self, row_cnt_format: Option<Format>) -> Self {
208 Self {
209 row_cnt_format,
210 ..self
211 }
212 }
213
214 pub fn values(self, values_stream: VS, row_desc: Vec<PgFieldDescriptor>) -> Self {
215 Self {
216 values_stream: Some(values_stream),
217 row_desc,
218 ..self
219 }
220 }
221
222 pub fn callback(self, callback: impl Callback + 'static) -> Self {
223 Self {
224 callback: Some(callback.boxed()),
225 ..self
226 }
227 }
228
229 pub fn notice(self, notice: impl ToString) -> Self {
230 let mut notices = self.notices;
231 notices.push(notice.to_string());
232 Self { notices, ..self }
233 }
234
235 pub fn status(self, status: ParameterStatus) -> Self {
236 Self { status, ..self }
237 }
238}
239
240impl<VS> std::fmt::Debug for PgResponse<VS> {
241 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
242 f.debug_struct("PgResponse")
243 .field("stmt_type", &self.stmt_type)
244 .field("row_cnt", &self.row_cnt)
245 .field("notices", &self.notices)
246 .field("row_desc", &self.row_desc)
247 .finish()
248 }
249}
250
251impl StatementType {
252 pub fn infer_from_statement(stmt: &Statement) -> Result<Self, String> {
253 match stmt {
254 Statement::Query(_) => Ok(StatementType::SELECT),
255 Statement::Insert { returning, .. } => {
256 if returning.is_empty() {
257 Ok(StatementType::INSERT)
258 } else {
259 Ok(StatementType::INSERT_RETURNING)
260 }
261 }
262 Statement::Delete { returning, .. } => {
263 if returning.is_empty() {
264 Ok(StatementType::DELETE)
265 } else {
266 Ok(StatementType::DELETE_RETURNING)
267 }
268 }
269 Statement::Update { returning, .. } => {
270 if returning.is_empty() {
271 Ok(StatementType::UPDATE)
272 } else {
273 Ok(StatementType::UPDATE_RETURNING)
274 }
275 }
276 Statement::Copy { .. } => Ok(StatementType::COPY),
277 Statement::CreateTable { .. } => Ok(StatementType::CREATE_TABLE),
278 Statement::CreateIndex { .. } => Ok(StatementType::CREATE_INDEX),
279 Statement::CreateSchema { .. } => Ok(StatementType::CREATE_SCHEMA),
280 Statement::CreateSource { .. } => Ok(StatementType::CREATE_SOURCE),
281 Statement::CreateSink { .. } => Ok(StatementType::CREATE_SINK),
282 Statement::CreateFunction { .. } => Ok(StatementType::CREATE_FUNCTION),
283 Statement::CreateDatabase { .. } => Ok(StatementType::CREATE_DATABASE),
284 Statement::CreateUser { .. } => Ok(StatementType::CREATE_USER),
285 Statement::CreateView { materialized, .. } => {
286 if *materialized {
287 Ok(StatementType::CREATE_MATERIALIZED_VIEW)
288 } else {
289 Ok(StatementType::CREATE_VIEW)
290 }
291 }
292 Statement::AlterTable { .. } => Ok(StatementType::ALTER_TABLE),
293 Statement::AlterSystem { .. } => Ok(StatementType::ALTER_SYSTEM),
294 Statement::AlterFragment { .. } => Ok(StatementType::ALTER_FRAGMENT),
295 Statement::DropFunction { .. } => Ok(StatementType::DROP_FUNCTION),
296 Statement::Discard(..) => Ok(StatementType::DISCARD),
297 Statement::SetVariable { .. } => Ok(StatementType::SET_VARIABLE),
298 Statement::ShowVariable { .. } => Ok(StatementType::SHOW_VARIABLE),
299 Statement::StartTransaction { .. } => Ok(StatementType::START_TRANSACTION),
300 Statement::Begin { .. } => Ok(StatementType::BEGIN),
301 Statement::Abort => Ok(StatementType::ABORT),
302 Statement::Commit { .. } => Ok(StatementType::COMMIT),
303 Statement::Rollback { .. } => Ok(StatementType::ROLLBACK),
304 Statement::Grant { .. } => Ok(StatementType::GRANT_PRIVILEGE),
305 Statement::Revoke { .. } => Ok(StatementType::REVOKE_PRIVILEGE),
306 Statement::Describe { .. } => Ok(StatementType::DESCRIBE),
307 Statement::ShowCreateObject { .. } | Statement::ShowObjects { .. } => {
308 Ok(StatementType::SHOW_COMMAND)
309 }
310 Statement::Drop(stmt) => match stmt.object_type {
311 risingwave_sqlparser::ast::ObjectType::Table => Ok(StatementType::DROP_TABLE),
312 risingwave_sqlparser::ast::ObjectType::View => Ok(StatementType::DROP_VIEW),
313 risingwave_sqlparser::ast::ObjectType::MaterializedView => {
314 Ok(StatementType::DROP_MATERIALIZED_VIEW)
315 }
316 risingwave_sqlparser::ast::ObjectType::Index => Ok(StatementType::DROP_INDEX),
317 risingwave_sqlparser::ast::ObjectType::Schema => Ok(StatementType::DROP_SCHEMA),
318 risingwave_sqlparser::ast::ObjectType::Source => Ok(StatementType::DROP_SOURCE),
319 risingwave_sqlparser::ast::ObjectType::Sink => Ok(StatementType::DROP_SINK),
320 risingwave_sqlparser::ast::ObjectType::Database => Ok(StatementType::DROP_DATABASE),
321 risingwave_sqlparser::ast::ObjectType::User => Ok(StatementType::DROP_USER),
322 risingwave_sqlparser::ast::ObjectType::Connection => {
323 Ok(StatementType::DROP_CONNECTION)
324 }
325 risingwave_sqlparser::ast::ObjectType::Secret => Ok(StatementType::DROP_SECRET),
326 risingwave_sqlparser::ast::ObjectType::Subscription => {
327 Ok(StatementType::DROP_SUBSCRIPTION)
328 }
329 },
330 Statement::Explain { .. } => Ok(StatementType::EXPLAIN),
331 Statement::DeclareCursor { .. } => Ok(StatementType::DECLARE_CURSOR),
332 Statement::FetchCursor { .. } => Ok(StatementType::FETCH_CURSOR),
333 Statement::CloseCursor { .. } => Ok(StatementType::CLOSE_CURSOR),
334 Statement::Flush => Ok(StatementType::FLUSH),
335 Statement::Wait => Ok(StatementType::WAIT),
336 Statement::Use { .. } => Ok(StatementType::USE),
337 Statement::Vacuum { .. } => Ok(StatementType::VACUUM),
338 _ => Err("unsupported statement type".to_owned()),
339 }
340 }
341
342 pub fn is_command(&self) -> bool {
343 matches!(
344 self,
345 StatementType::INSERT
346 | StatementType::DELETE
347 | StatementType::UPDATE
348 | StatementType::MOVE
349 | StatementType::COPY
350 | StatementType::FETCH
351 | StatementType::SELECT
352 | StatementType::INSERT_RETURNING
353 | StatementType::DELETE_RETURNING
354 | StatementType::UPDATE_RETURNING
355 )
356 }
357
358 pub fn is_dml(&self) -> bool {
359 matches!(
360 self,
361 StatementType::INSERT
362 | StatementType::DELETE
363 | StatementType::UPDATE
364 | StatementType::INSERT_RETURNING
365 | StatementType::DELETE_RETURNING
366 | StatementType::UPDATE_RETURNING
367 )
368 }
369
370 pub fn is_query(&self) -> bool {
371 matches!(
372 self,
373 StatementType::SELECT
374 | StatementType::EXPLAIN
375 | StatementType::SHOW_COMMAND
376 | StatementType::SHOW_VARIABLE
377 | StatementType::DESCRIBE
378 | StatementType::INSERT_RETURNING
379 | StatementType::DELETE_RETURNING
380 | StatementType::UPDATE_RETURNING
381 | StatementType::CANCEL_COMMAND
382 | StatementType::FETCH_CURSOR
383 )
384 }
385
386 pub fn is_returning(&self) -> bool {
387 matches!(
388 self,
389 StatementType::INSERT_RETURNING
390 | StatementType::DELETE_RETURNING
391 | StatementType::UPDATE_RETURNING
392 )
393 }
394}
395
396impl<VS> PgResponse<VS>
397where
398 VS: ValuesStream,
399{
400 pub fn builder(stmt_type: StatementType) -> PgResponseBuilder<VS> {
401 PgResponseBuilder::empty(stmt_type)
402 }
403
404 pub fn empty_result(stmt_type: StatementType) -> Self {
405 PgResponseBuilder::empty(stmt_type).into()
406 }
407
408 pub fn into_copy_query_to_stdout(mut self) -> Self {
409 self.is_copy_query_to_stdout = true;
410 self.stmt_type = StatementType::COPY;
411 self
412 }
413
414 pub fn stmt_type(&self) -> StatementType {
415 self.stmt_type
416 }
417
418 pub fn notices(&self) -> &[String] {
419 &self.notices
420 }
421
422 pub fn status(&self) -> &ParameterStatus {
423 &self.status
424 }
425
426 pub fn affected_rows_cnt(&self) -> Option<i32> {
427 self.row_cnt
428 }
429
430 pub fn row_cnt_format(&self) -> Option<Format> {
431 self.row_cnt_format
432 }
433
434 pub fn is_query(&self) -> bool {
435 self.stmt_type.is_query()
436 }
437
438 pub fn is_empty(&self) -> bool {
439 self.stmt_type == StatementType::EMPTY
440 }
441
442 pub fn is_copy_query_to_stdout(&self) -> bool {
443 self.is_copy_query_to_stdout
444 }
445
446 pub fn row_desc(&self) -> &[PgFieldDescriptor] {
447 &self.row_desc
448 }
449
450 pub fn values_stream(&mut self) -> &mut VS {
451 self.values_stream.as_mut().expect("no values stream")
452 }
453
454 pub async fn run_callback(&mut self) -> Result<(), PsqlError> {
459 if let Some(values_stream) = &mut self.values_stream {
461 assert!(values_stream.next().await.is_none());
462 }
463
464 if let Some(callback) = self.callback.take() {
465 callback.await.map_err(PsqlError::SimpleQueryError)?;
466 }
467 Ok(())
468 }
469}