1use std::fmt::Formatter;
16use std::pin::Pin;
17
18use futures::{Future, FutureExt, Stream, StreamExt};
19use risingwave_sqlparser::ast::Statement;
20
21use crate::error::PsqlError;
22use crate::pg_field_descriptor::PgFieldDescriptor;
23use crate::pg_protocol::ParameterStatus;
24use crate::pg_server::BoxedError;
25use crate::types::{Format, Row};
26
/// A batch of result rows.
pub type RowSet = Vec<Row>;
/// Either a [`RowSet`] batch or the boxed error raised instead of producing it.
pub type RowSetResult = Result<RowSet, BoxedError>;
29
/// Trait alias for the row stream carried by a response: a [`Stream`] that yields
/// [`RowSetResult`] items and is both `Unpin` and `Send`.
pub trait ValuesStream = Stream<Item = RowSetResult> + Unpin + Send;
31
/// The kind of statement a response was produced for.
///
/// Variant names are upper snake case, which is why the naming lints are expected below.
/// The `Display` impl of this type prints the derived `Debug` output, i.e. the variant
/// name exactly as it is spelled here.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[expect(non_camel_case_types, clippy::upper_case_acronyms)]
pub enum StatementType {
    INSERT,
    /// `INSERT` whose `returning` list is non-empty (see `infer_from_statement`).
    INSERT_RETURNING,
    DELETE,
    /// `DELETE` whose `returning` list is non-empty (see `infer_from_statement`).
    DELETE_RETURNING,
    UPDATE,
    /// `UPDATE` whose `returning` list is non-empty (see `infer_from_statement`).
    UPDATE_RETURNING,
    SELECT,
    MOVE,
    FETCH,
    COPY,
    EXPLAIN,
    CLOSE_CURSOR,
    CREATE_TABLE,
    /// `CREATE VIEW` with the `materialized` flag set (see `infer_from_statement`).
    CREATE_MATERIALIZED_VIEW,
    CREATE_VIEW,
    CREATE_SOURCE,
    CREATE_SINK,
    CREATE_SUBSCRIPTION,
    CREATE_DATABASE,
    CREATE_SCHEMA,
    CREATE_USER,
    CREATE_INDEX,
    CREATE_AGGREGATE,
    CREATE_FUNCTION,
    CREATE_CONNECTION,
    CREATE_SECRET,
    COMMENT,
    DECLARE_CURSOR,
    DESCRIBE,
    GRANT_PRIVILEGE,
    DISCARD,
    DROP_TABLE,
    DROP_MATERIALIZED_VIEW,
    DROP_VIEW,
    DROP_INDEX,
    DROP_FUNCTION,
    DROP_AGGREGATE,
    DROP_SOURCE,
    DROP_SINK,
    DROP_SUBSCRIPTION,
    DROP_SCHEMA,
    DROP_DATABASE,
    DROP_USER,
    DROP_CONNECTION,
    DROP_SECRET,
    ALTER_DATABASE,
    ALTER_SCHEMA,
    ALTER_INDEX,
    ALTER_VIEW,
    ALTER_TABLE,
    ALTER_MATERIALIZED_VIEW,
    ALTER_SINK,
    ALTER_SUBSCRIPTION,
    ALTER_SOURCE,
    ALTER_FUNCTION,
    ALTER_CONNECTION,
    ALTER_SYSTEM,
    ALTER_SECRET,
    REVOKE_PRIVILEGE,
    // NOTE(review): `infer_from_statement` never returns this variant; confirm whether it
    // is still produced anywhere.
    ORDER_BY,
    SET_VARIABLE,
    SHOW_VARIABLE,
    /// Returned by `infer_from_statement` for both `ShowCreateObject` and `ShowObjects`.
    SHOW_COMMAND,
    START_TRANSACTION,
    UPDATE_USER,
    ABORT,
    FLUSH,
    OTHER,
    /// The variant `PgResponse::is_empty` compares against.
    EMPTY,
    BEGIN,
    COMMIT,
    ROLLBACK,
    SET_TRANSACTION,
    CANCEL_COMMAND,
    FETCH_CURSOR,
    WAIT,
    KILL,
    RECOVER,
    USE,
}
118
119impl std::fmt::Display for StatementType {
120 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
121 write!(f, "{:?}", self)
122 }
123}
124
/// Trait alias for a `Send` future that resolves to `Ok(())` or a [`BoxedError`].
pub trait Callback = Future<Output = Result<(), BoxedError>> + Send;
126
/// A pinned, heap-allocated [`Callback`] trait object.
pub type BoxedCallback = Pin<Box<dyn Callback>>;
128
/// The response to one statement: its type, an optional row count, notices, an optional
/// stream of row batches with their column description, an optional callback and a
/// parameter status.
///
/// Values are assembled with a [`PgResponseBuilder`] and converted through `From`/`Into`.
pub struct PgResponse<VS> {
    /// Kind of statement this response belongs to.
    stmt_type: StatementType,
    /// Row count reported by `affected_rows_cnt`; `None` when no count is set.
    row_cnt: Option<i32>,
    /// NOTE(review): presumably the format used when the row count is sent to the
    /// client; confirm against the consumer of `row_cnt_format()`.
    row_cnt_format: Option<Format>,
    /// Notice messages attached to the response, in insertion order.
    notices: Vec<String>,
    /// Stream of row batches, if the response carries rows.
    values_stream: Option<VS>,
    /// Future awaited (at most once) by `run_callback`.
    callback: Option<BoxedCallback>,
    /// Descriptors of the columns of the returned rows.
    row_desc: Vec<PgFieldDescriptor>,
    /// Parameter status returned by `status()`.
    status: ParameterStatus,
}
142
/// Builder for [`PgResponse`]; it holds the same fields and is converted with `From`/`Into`.
pub struct PgResponseBuilder<VS> {
    /// Kind of statement the response belongs to.
    stmt_type: StatementType,
    /// Row count; `empty` initialises it to `None` for query statement types and to
    /// `Some(0)` otherwise.
    row_cnt: Option<i32>,
    /// NOTE(review): presumably the format used when the row count is sent to the
    /// client; confirm against the consumer of the built response.
    row_cnt_format: Option<Format>,
    /// Notice messages collected so far, in insertion order.
    notices: Vec<String>,
    /// Stream of row batches, set together with `row_desc` by `values`.
    values_stream: Option<VS>,
    /// Boxed callback set by `callback`.
    callback: Option<BoxedCallback>,
    /// Descriptors of the columns of the returned rows.
    row_desc: Vec<PgFieldDescriptor>,
    /// Parameter status; starts as `Default::default()`.
    status: ParameterStatus,
}
156
157impl<VS> From<PgResponseBuilder<VS>> for PgResponse<VS> {
158 fn from(builder: PgResponseBuilder<VS>) -> Self {
159 Self {
160 stmt_type: builder.stmt_type,
161 row_cnt: builder.row_cnt,
162 row_cnt_format: builder.row_cnt_format,
163 notices: builder.notices,
164 values_stream: builder.values_stream,
165 callback: builder.callback,
166 row_desc: builder.row_desc,
167 status: builder.status,
168 }
169 }
170}
171
172impl<VS> PgResponseBuilder<VS> {
173 pub fn empty(stmt_type: StatementType) -> Self {
174 let row_cnt = if stmt_type.is_query() { None } else { Some(0) };
175 Self {
176 stmt_type,
177 row_cnt,
178 row_cnt_format: None,
179 notices: vec![],
180 values_stream: None,
181 callback: None,
182 row_desc: vec![],
183 status: Default::default(),
184 }
185 }
186
187 pub fn row_cnt(self, row_cnt: i32) -> Self {
188 Self {
189 row_cnt: Some(row_cnt),
190 ..self
191 }
192 }
193
194 pub fn row_cnt_opt(self, row_cnt: Option<i32>) -> Self {
195 Self { row_cnt, ..self }
196 }
197
198 pub fn row_cnt_format_opt(self, row_cnt_format: Option<Format>) -> Self {
199 Self {
200 row_cnt_format,
201 ..self
202 }
203 }
204
205 pub fn values(self, values_stream: VS, row_desc: Vec<PgFieldDescriptor>) -> Self {
206 Self {
207 values_stream: Some(values_stream),
208 row_desc,
209 ..self
210 }
211 }
212
213 pub fn callback(self, callback: impl Callback + 'static) -> Self {
214 Self {
215 callback: Some(callback.boxed()),
216 ..self
217 }
218 }
219
220 pub fn notice(self, notice: impl ToString) -> Self {
221 let mut notices = self.notices;
222 notices.push(notice.to_string());
223 Self { notices, ..self }
224 }
225
226 pub fn status(self, status: ParameterStatus) -> Self {
227 Self { status, ..self }
228 }
229}
230
231impl<VS> std::fmt::Debug for PgResponse<VS> {
232 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
233 f.debug_struct("PgResponse")
234 .field("stmt_type", &self.stmt_type)
235 .field("row_cnt", &self.row_cnt)
236 .field("notices", &self.notices)
237 .field("row_desc", &self.row_desc)
238 .finish()
239 }
240}
241
242impl StatementType {
243 pub fn infer_from_statement(stmt: &Statement) -> Result<Self, String> {
244 match stmt {
245 Statement::Query(_) => Ok(StatementType::SELECT),
246 Statement::Insert { returning, .. } => {
247 if returning.is_empty() {
248 Ok(StatementType::INSERT)
249 } else {
250 Ok(StatementType::INSERT_RETURNING)
251 }
252 }
253 Statement::Delete { returning, .. } => {
254 if returning.is_empty() {
255 Ok(StatementType::DELETE)
256 } else {
257 Ok(StatementType::DELETE_RETURNING)
258 }
259 }
260 Statement::Update { returning, .. } => {
261 if returning.is_empty() {
262 Ok(StatementType::UPDATE)
263 } else {
264 Ok(StatementType::UPDATE_RETURNING)
265 }
266 }
267 Statement::Copy { .. } => Ok(StatementType::COPY),
268 Statement::CreateTable { .. } => Ok(StatementType::CREATE_TABLE),
269 Statement::CreateIndex { .. } => Ok(StatementType::CREATE_INDEX),
270 Statement::CreateSchema { .. } => Ok(StatementType::CREATE_SCHEMA),
271 Statement::CreateSource { .. } => Ok(StatementType::CREATE_SOURCE),
272 Statement::CreateSink { .. } => Ok(StatementType::CREATE_SINK),
273 Statement::CreateFunction { .. } => Ok(StatementType::CREATE_FUNCTION),
274 Statement::CreateDatabase { .. } => Ok(StatementType::CREATE_DATABASE),
275 Statement::CreateUser { .. } => Ok(StatementType::CREATE_USER),
276 Statement::CreateView { materialized, .. } => {
277 if *materialized {
278 Ok(StatementType::CREATE_MATERIALIZED_VIEW)
279 } else {
280 Ok(StatementType::CREATE_VIEW)
281 }
282 }
283 Statement::AlterTable { .. } => Ok(StatementType::ALTER_TABLE),
284 Statement::AlterSystem { .. } => Ok(StatementType::ALTER_SYSTEM),
285 Statement::DropFunction { .. } => Ok(StatementType::DROP_FUNCTION),
286 Statement::Discard(..) => Ok(StatementType::DISCARD),
287 Statement::SetVariable { .. } => Ok(StatementType::SET_VARIABLE),
288 Statement::ShowVariable { .. } => Ok(StatementType::SHOW_VARIABLE),
289 Statement::StartTransaction { .. } => Ok(StatementType::START_TRANSACTION),
290 Statement::Begin { .. } => Ok(StatementType::BEGIN),
291 Statement::Abort => Ok(StatementType::ABORT),
292 Statement::Commit { .. } => Ok(StatementType::COMMIT),
293 Statement::Rollback { .. } => Ok(StatementType::ROLLBACK),
294 Statement::Grant { .. } => Ok(StatementType::GRANT_PRIVILEGE),
295 Statement::Revoke { .. } => Ok(StatementType::REVOKE_PRIVILEGE),
296 Statement::Describe { .. } => Ok(StatementType::DESCRIBE),
297 Statement::ShowCreateObject { .. } | Statement::ShowObjects { .. } => {
298 Ok(StatementType::SHOW_COMMAND)
299 }
300 Statement::Drop(stmt) => match stmt.object_type {
301 risingwave_sqlparser::ast::ObjectType::Table => Ok(StatementType::DROP_TABLE),
302 risingwave_sqlparser::ast::ObjectType::View => Ok(StatementType::DROP_VIEW),
303 risingwave_sqlparser::ast::ObjectType::MaterializedView => {
304 Ok(StatementType::DROP_MATERIALIZED_VIEW)
305 }
306 risingwave_sqlparser::ast::ObjectType::Index => Ok(StatementType::DROP_INDEX),
307 risingwave_sqlparser::ast::ObjectType::Schema => Ok(StatementType::DROP_SCHEMA),
308 risingwave_sqlparser::ast::ObjectType::Source => Ok(StatementType::DROP_SOURCE),
309 risingwave_sqlparser::ast::ObjectType::Sink => Ok(StatementType::DROP_SINK),
310 risingwave_sqlparser::ast::ObjectType::Database => Ok(StatementType::DROP_DATABASE),
311 risingwave_sqlparser::ast::ObjectType::User => Ok(StatementType::DROP_USER),
312 risingwave_sqlparser::ast::ObjectType::Connection => {
313 Ok(StatementType::DROP_CONNECTION)
314 }
315 risingwave_sqlparser::ast::ObjectType::Secret => Ok(StatementType::DROP_SECRET),
316 risingwave_sqlparser::ast::ObjectType::Subscription => {
317 Ok(StatementType::DROP_SUBSCRIPTION)
318 }
319 },
320 Statement::Explain { .. } => Ok(StatementType::EXPLAIN),
321 Statement::DeclareCursor { .. } => Ok(StatementType::DECLARE_CURSOR),
322 Statement::FetchCursor { .. } => Ok(StatementType::FETCH_CURSOR),
323 Statement::CloseCursor { .. } => Ok(StatementType::CLOSE_CURSOR),
324 Statement::Flush => Ok(StatementType::FLUSH),
325 Statement::Wait => Ok(StatementType::WAIT),
326 Statement::Use { .. } => Ok(StatementType::USE),
327 _ => Err("unsupported statement type".to_owned()),
328 }
329 }
330
331 pub fn is_command(&self) -> bool {
332 matches!(
333 self,
334 StatementType::INSERT
335 | StatementType::DELETE
336 | StatementType::UPDATE
337 | StatementType::MOVE
338 | StatementType::COPY
339 | StatementType::FETCH
340 | StatementType::SELECT
341 | StatementType::INSERT_RETURNING
342 | StatementType::DELETE_RETURNING
343 | StatementType::UPDATE_RETURNING
344 )
345 }
346
347 pub fn is_dml(&self) -> bool {
348 matches!(
349 self,
350 StatementType::INSERT
351 | StatementType::DELETE
352 | StatementType::UPDATE
353 | StatementType::INSERT_RETURNING
354 | StatementType::DELETE_RETURNING
355 | StatementType::UPDATE_RETURNING
356 )
357 }
358
359 pub fn is_query(&self) -> bool {
360 matches!(
361 self,
362 StatementType::SELECT
363 | StatementType::EXPLAIN
364 | StatementType::SHOW_COMMAND
365 | StatementType::SHOW_VARIABLE
366 | StatementType::DESCRIBE
367 | StatementType::INSERT_RETURNING
368 | StatementType::DELETE_RETURNING
369 | StatementType::UPDATE_RETURNING
370 | StatementType::CANCEL_COMMAND
371 | StatementType::FETCH_CURSOR
372 )
373 }
374
375 pub fn is_returning(&self) -> bool {
376 matches!(
377 self,
378 StatementType::INSERT_RETURNING
379 | StatementType::DELETE_RETURNING
380 | StatementType::UPDATE_RETURNING
381 )
382 }
383}
384
385impl<VS> PgResponse<VS>
386where
387 VS: ValuesStream,
388{
389 pub fn builder(stmt_type: StatementType) -> PgResponseBuilder<VS> {
390 PgResponseBuilder::empty(stmt_type)
391 }
392
393 pub fn empty_result(stmt_type: StatementType) -> Self {
394 PgResponseBuilder::empty(stmt_type).into()
395 }
396
397 pub fn stmt_type(&self) -> StatementType {
398 self.stmt_type
399 }
400
401 pub fn notices(&self) -> &[String] {
402 &self.notices
403 }
404
405 pub fn status(&self) -> &ParameterStatus {
406 &self.status
407 }
408
409 pub fn affected_rows_cnt(&self) -> Option<i32> {
410 self.row_cnt
411 }
412
413 pub fn row_cnt_format(&self) -> Option<Format> {
414 self.row_cnt_format
415 }
416
417 pub fn is_query(&self) -> bool {
418 self.stmt_type.is_query()
419 }
420
421 pub fn is_empty(&self) -> bool {
422 self.stmt_type == StatementType::EMPTY
423 }
424
425 pub fn row_desc(&self) -> Vec<PgFieldDescriptor> {
426 self.row_desc.clone()
427 }
428
429 pub fn values_stream(&mut self) -> &mut VS {
430 self.values_stream.as_mut().expect("no values stream")
431 }
432
433 pub async fn run_callback(&mut self) -> Result<(), PsqlError> {
438 if let Some(values_stream) = &mut self.values_stream {
440 assert!(values_stream.next().await.is_none());
441 }
442
443 if let Some(callback) = self.callback.take() {
444 callback.await.map_err(PsqlError::SimpleQueryError)?;
445 }
446 Ok(())
447 }
448}