1use anyhow::anyhow;
16use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
17use pg_interval::Interval;
18use rust_decimal::Decimal;
19use tokio::select;
20use tokio_postgres::types::Type;
21use tokio_postgres::{Client, NoTls};
22
/// Driver for the extended-query-protocol end-to-end checks.
///
/// It holds only a `key=value` connection string; each check opens its own
/// connection(s) from it.
pub struct TestSuite {
    /// Connection string handed to `tokio_postgres::connect`.
    config: String,
}
26
/// Equality check for tests that reports a mismatch as an `Err` instead of panicking.
///
/// Only usable inside a function returning `anyhow::Result<_>`. On mismatch it does an
/// early `return Err(anyhow!(..))`. The message carries the invoking file and line and
/// the `Debug` form of both operands. Both operands are borrowed, never moved.
macro_rules! test_eq {
    ($left:expr, $right:expr $(,)?) => {{
        // Temporaries in the operands are kept alive by `let` lifetime extension.
        let (lhs, rhs) = (&$left, &$right);
        // `!(a == b)` (as in std's `assert_eq!`) so only `PartialEq::eq` is consulted.
        if !(*lhs == *rhs) {
            return Err(anyhow!(
                "{}:{} assertion failed: `(left == right)` \
                 (left: `{:?}`, right: `{:?}`)",
                file!(),
                line!(),
                lhs,
                rhs
            ));
        }
    }};
}
45
46impl TestSuite {
47 pub fn new(
48 db_name: String,
49 user_name: String,
50 server_host: String,
51 server_port: u16,
52 password: String,
53 ) -> Self {
54 let config = if !password.is_empty() {
55 format!(
56 "dbname={} user={} host={} port={} password={}",
57 db_name, user_name, server_host, server_port, password
58 )
59 } else {
60 format!(
61 "dbname={} user={} host={} port={}",
62 db_name, user_name, server_host, server_port
63 )
64 };
65 Self { config }
66 }
67
68 fn init_logger() {
69 let _ = tracing_subscriber::fmt()
70 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
71 .with_ansi(false)
72 .try_init();
73 }
74
75 pub async fn test(&self) -> anyhow::Result<()> {
76 Self::init_logger();
77 self.binary_param_and_result().await?;
78 self.dql_dml_with_param().await?;
79 self.max_row().await?;
80 self.multiple_on_going_portal().await?;
81 self.create_with_parameter().await?;
82 self.simple_cancel(false).await?;
83 self.simple_cancel(true).await?;
84 self.complex_cancel(false).await?;
85 self.complex_cancel(true).await?;
86 self.subquery_with_param().await?;
87 self.create_mview_with_parameter().await?;
88 Ok(())
89 }
90
91 async fn create_client(&self, is_distributed: bool) -> anyhow::Result<Client> {
92 let (client, connection) = tokio_postgres::connect(&self.config, NoTls).await?;
93
94 tokio::spawn(async move {
97 if let Err(e) = connection.await {
98 eprintln!("connection error: {}", e);
99 }
100 });
101
102 if is_distributed {
103 client.execute("set query_mode = distributed", &[]).await?;
104 } else {
105 client.execute("set query_mode = local", &[]).await?;
106 }
107
108 Ok(client)
109 }
110
111 pub async fn binary_param_and_result(&self) -> anyhow::Result<()> {
112 let client = self.create_client(false).await?;
113
114 for row in client.query("select $1::SMALLINT;", &[&1024_i16]).await? {
115 let data: i16 = row.try_get(0)?;
116 test_eq!(data, 1024);
117 }
118
119 for row in client.query("select $1::INT;", &[&144232_i32]).await? {
120 let data: i32 = row.try_get(0)?;
121 test_eq!(data, 144232);
122 }
123
124 for row in client.query("select $1::BIGINT;", &[&99999999_i64]).await? {
125 let data: i64 = row.try_get(0)?;
126 test_eq!(data, 99999999);
127 }
128
129 for row in client
130 .query(
131 "select $1::DECIMAL;",
132 &[&Decimal::try_from(2.33454_f32).ok()],
133 )
134 .await?
135 {
136 let data: Decimal = row.try_get(0)?;
137 test_eq!(data, Decimal::try_from(2.33454_f32).unwrap());
138 }
139
140 for row in client.query("select $1::BOOL;", &[&true]).await? {
141 let data: bool = row.try_get(0)?;
142 assert!(data);
143 }
144
145 for row in client.query("select $1::REAL;", &[&1.234234_f32]).await? {
146 let data: f32 = row.try_get(0)?;
147 test_eq!(data, 1.234234);
148 }
149
150 for row in client
151 .query("select $1::DOUBLE PRECISION;", &[&234234.23490238483_f64])
152 .await?
153 {
154 let data: f64 = row.try_get(0)?;
155 test_eq!(data, 234234.23490238483);
156 }
157
158 for row in client
159 .query(
160 "select $1::date;",
161 &[&NaiveDate::from_ymd_opt(2022, 1, 1).unwrap()],
162 )
163 .await?
164 {
165 let data: NaiveDate = row.try_get(0)?;
166 test_eq!(data, NaiveDate::from_ymd_opt(2022, 1, 1).unwrap());
167 }
168
169 for row in client
170 .query(
171 "select $1::time",
172 &[&NaiveTime::from_hms_opt(10, 0, 0).unwrap()],
173 )
174 .await?
175 {
176 let data: NaiveTime = row.try_get(0)?;
177 test_eq!(data, NaiveTime::from_hms_opt(10, 0, 0).unwrap());
178 }
179
180 for row in client
181 .query(
182 "select $1::timestamp",
183 &[&NaiveDate::from_ymd_opt(2022, 1, 1)
184 .unwrap()
185 .and_hms_opt(10, 0, 0)
186 .unwrap()],
187 )
188 .await?
189 {
190 let data: NaiveDateTime = row.try_get(0)?;
191 test_eq!(
192 data,
193 NaiveDate::from_ymd_opt(2022, 1, 1)
194 .unwrap()
195 .and_hms_opt(10, 0, 0)
196 .unwrap()
197 );
198 }
199
200 let timestamptz = DateTime::<Utc>::from_naive_utc_and_offset(
201 NaiveDate::from_ymd_opt(2022, 1, 1)
202 .unwrap()
203 .and_hms_opt(10, 0, 0)
204 .unwrap(),
205 Utc,
206 );
207 for row in client
208 .query("select $1::timestamptz", &[×tamptz])
209 .await?
210 {
211 let data: DateTime<Utc> = row.try_get(0)?;
212 test_eq!(data, timestamptz);
213 }
214
215 for row in client
216 .query("select $1::interval", &[&Interval::new(1, 1, 24000000)])
217 .await?
218 {
219 let data: Interval = row.try_get(0)?;
220 test_eq!(data, Interval::new(1, 1, 24000000));
221 }
222
223 Ok(())
224 }
225
226 async fn dql_dml_with_param(&self) -> anyhow::Result<()> {
227 let client = self.create_client(false).await?;
228
229 client.query("create table t(id int)", &[]).await?;
230
231 let insert_statement = client
232 .prepare_typed("insert INTO t (id) VALUES ($1)", &[])
233 .await?;
234
235 for i in 0..20 {
236 client.execute(&insert_statement, &[&i]).await?;
237 }
238 client.execute("flush", &[]).await?;
239
240 let update_statement = client
241 .prepare_typed(
242 "update t set id = $1 where id < $2",
243 &[Type::INT4, Type::INT4],
244 )
245 .await?;
246 let query_statement = client
247 .prepare_typed(
248 "select * FROM t where id < $1 order by id ASC",
249 &[Type::INT4],
250 )
251 .await?;
252 let delete_statement = client
253 .prepare_typed("delete FROM t where id < $1", &[Type::INT4])
254 .await?;
255
256 let mut i = 0;
257 for row in client.query(&query_statement, &[&10_i32]).await? {
258 let id: i32 = row.try_get(0)?;
259 test_eq!(id, i);
260 i += 1;
261 }
262 test_eq!(i, 10);
263
264 client
265 .execute(&update_statement, &[&100_i32, &10_i32])
266 .await?;
267 client.execute("flush", &[]).await?;
268
269 let mut i = 0;
270 for _ in client.query(&query_statement, &[&10_i32]).await? {
271 i += 1;
272 }
273 test_eq!(i, 0);
274
275 client.execute(&delete_statement, &[&20_i32]).await?;
276 client.execute("flush", &[]).await?;
277
278 let mut i = 0;
279 for row in client.query(&query_statement, &[&101_i32]).await? {
280 let id: i32 = row.try_get(0)?;
281 test_eq!(id, 100);
282 i += 1;
283 }
284 test_eq!(i, 10);
285
286 client.execute("drop table t", &[]).await?;
287
288 Ok(())
289 }
290
291 async fn max_row(&self) -> anyhow::Result<()> {
292 let mut client = self.create_client(false).await?;
293
294 client.query("create table t(id int)", &[]).await?;
295
296 let insert_statement = client
297 .prepare_typed("insert INTO t (id) VALUES ($1)", &[])
298 .await?;
299
300 for i in 0..10 {
301 client.execute(&insert_statement, &[&i]).await?;
302 }
303 client.execute("flush", &[]).await?;
304
305 let transaction = client.transaction().await?;
306 let statement = transaction
307 .prepare_typed("SELECT * FROM t order by id", &[])
308 .await?;
309 let portal = transaction.bind(&statement, &[]).await?;
310
311 for t in 0..5 {
312 let rows = transaction.query_portal(&portal, 1).await?;
313 test_eq!(rows.len(), 1);
314 let row = rows.first().unwrap();
315 let id: i32 = row.get(0);
316 test_eq!(id, t);
317 }
318
319 let mut i = 5;
320 for row in transaction.query_portal(&portal, 3).await? {
321 let id: i32 = row.get(0);
322 test_eq!(id, i);
323 i += 1;
324 }
325 test_eq!(i, 8);
326
327 for row in transaction.query_portal(&portal, 5).await? {
328 let id: i32 = row.get(0);
329 test_eq!(id, i);
330 i += 1;
331 }
332 test_eq!(i, 10);
333
334 transaction.rollback().await?;
335
336 client.execute("drop table t", &[]).await?;
337
338 Ok(())
339 }
340
341 async fn multiple_on_going_portal(&self) -> anyhow::Result<()> {
342 let mut client = self.create_client(false).await?;
343
344 let transaction = client.transaction().await?;
345 let statement = transaction
346 .prepare_typed("SELECT generate_series(1,5,1)", &[])
347 .await?;
348 let portal_1 = transaction.bind(&statement, &[]).await?;
349 let portal_2 = transaction.bind(&statement, &[]).await?;
350
351 let rows = transaction.query_portal(&portal_1, 1).await?;
352 test_eq!(rows.len(), 1);
353 test_eq!(rows.first().unwrap().get::<usize, i32>(0), 1);
354
355 let rows = transaction.query_portal(&portal_2, 1).await?;
356 test_eq!(rows.len(), 1);
357 test_eq!(rows.first().unwrap().get::<usize, i32>(0), 1);
358
359 let rows = transaction.query_portal(&portal_2, 3).await?;
360 test_eq!(rows.len(), 3);
361 test_eq!(rows.first().unwrap().get::<usize, i32>(0), 2);
362 test_eq!(rows.get(1).unwrap().get::<usize, i32>(0), 3);
363 test_eq!(rows.get(2).unwrap().get::<usize, i32>(0), 4);
364
365 let rows = transaction.query_portal(&portal_1, 1).await?;
366 test_eq!(rows.len(), 1);
367 test_eq!(rows.first().unwrap().get::<usize, i32>(0), 2);
368
369 Ok(())
370 }
371
372 async fn create_with_parameter(&self) -> anyhow::Result<()> {
374 let client = self.create_client(false).await?;
375
376 test_eq!(
377 client
378 .query("create table t as select $1", &[])
379 .await
380 .is_err(),
381 true
382 );
383 test_eq!(
384 client
385 .query("create view v as select $1", &[])
386 .await
387 .is_err(),
388 true
389 );
390
391 Ok(())
392 }
393
394 async fn create_mview_with_parameter(&self) -> anyhow::Result<()> {
395 let client = self.create_client(false).await?;
396
397 let statement = client
398 .prepare_typed(
399 "create materialized view mv as select $1 as x",
400 &[Type::INT4],
401 )
402 .await?;
403
404 client.execute(&statement, &[&42_i32]).await?;
405
406 let rows = client.query("select * from mv", &[]).await?;
407 test_eq!(rows.len(), 1);
408 test_eq!(rows.first().unwrap().get::<usize, i32>(0), 42);
409
410 client
412 .execute("alter materialized view mv rename to mv2", &[])
413 .await?;
414
415 let rows = client.query("select * from mv2", &[]).await?;
416 test_eq!(rows.len(), 1);
417 test_eq!(rows.first().unwrap().get::<usize, i32>(0), 42);
418
419 client.execute("drop materialized view mv2", &[]).await?;
420
421 Ok(())
422 }
423
424 async fn simple_cancel(&self, is_distributed: bool) -> anyhow::Result<()> {
425 let client = self.create_client(is_distributed).await?;
426 client.execute("create table t(id int)", &[]).await?;
427
428 let insert_statement = client
429 .prepare_typed("insert INTO t (id) VALUES ($1)", &[])
430 .await?;
431
432 for i in 0..1000 {
433 client.execute(&insert_statement, &[&i]).await?;
434 }
435
436 client.execute("flush", &[]).await?;
437
438 let cancel_token = client.cancel_token();
439
440 let query_handle = tokio::spawn(async move {
441 client.query("select * from t", &[]).await.unwrap();
442 });
443
444 select! {
445 _ = query_handle => {
446 tracing::error!("Failed to cancel query")
447 },
448 _ = cancel_token.cancel_query(NoTls) => {
449 tracing::trace!("Cancel query successfully")
450 },
451 }
452
453 let new_client = self.create_client(is_distributed).await?;
454
455 let rows = new_client
456 .query("select * from t order by id limit 10", &[])
457 .await?;
458
459 test_eq!(rows.len(), 10);
460 for (expect_id, row) in rows.iter().enumerate() {
461 let id: i32 = row.get(0);
462 test_eq!(id, expect_id as i32);
463 }
464
465 new_client.execute("drop table t", &[]).await?;
466
467 Ok(())
468 }
469
470 async fn complex_cancel(&self, is_distributed: bool) -> anyhow::Result<()> {
471 let client = self.create_client(is_distributed).await?;
472
473 client
474 .execute("create table t1(name varchar, id int)", &[])
475 .await?;
476 client
477 .execute("create table t2(name varchar, id int)", &[])
478 .await?;
479 client
480 .execute("create table t3(name varchar, id int)", &[])
481 .await?;
482
483 let insert_statement = client
484 .prepare_typed("insert INTO t1 (name, id) VALUES ($1, $2)", &[])
485 .await?;
486 let insert_statement2 = client
487 .prepare_typed("insert INTO t2 (name, id) VALUES ($1, $2)", &[])
488 .await?;
489 let insert_statement3 = client
490 .prepare_typed("insert INTO t3 (name, id) VALUES ($1, $2)", &[])
491 .await?;
492 for i in 0..1000 {
493 client
494 .execute(&insert_statement, &[&i.to_string(), &i])
495 .await?;
496 client
497 .execute(&insert_statement2, &[&i.to_string(), &i])
498 .await?;
499 client
500 .execute(&insert_statement3, &[&i.to_string(), &i])
501 .await?;
502 }
503
504 client.execute("flush", &[]).await?;
505
506 client.execute("set query_mode=local", &[]).await?;
507
508 let cancel_token = client.cancel_token();
509
510 let query_sql = "SELECT t1.name, t2.id, t3.name
511 FROM t1
512 INNER JOIN (
513 SELECT id, name
514 FROM t2
515 WHERE id IN (
516 SELECT id
517 FROM t1
518 WHERE name LIKE '%1%'
519 )
520 ) AS t2 ON t1.id = t2.id
521 LEFT JOIN t3 ON t2.name = t3.name
522 WHERE t3.id IN (
523 SELECT MAX(id)
524 FROM t3
525 GROUP BY name
526 )
527 ORDER BY t1.name ASC, t3.id DESC
528 ";
529
530 let query_handle = tokio::spawn(async move {
531 let result = client.query(query_sql, &[]).await;
532 match result {
533 Ok(_) => {
534 tracing::error!("Query should be canceled");
535 }
536 Err(e) => {
537 tracing::error!("Query failed with error: {:?}", e);
538 }
539 };
540 });
541
542 select! {
543 _ = query_handle => {
544 tracing::error!("Failed to cancel query")
545 },
546 _ = cancel_token.cancel_query(NoTls) => {
547 tracing::info!("Cancel query successfully")
548 },
549 }
550
551 let new_client = self.create_client(is_distributed).await?;
552
553 let rows = new_client
554 .query(&format!("{} LIMIT 10", query_sql), &[])
555 .await?;
556 let expect_ans = [
557 (1, 1, 1),
558 (10, 10, 10),
559 (100, 100, 100),
560 (101, 101, 101),
561 (102, 102, 102),
562 (103, 103, 103),
563 (104, 104, 104),
564 (105, 105, 105),
565 (106, 106, 106),
566 (107, 107, 107),
567 ];
568 for (i, row) in rows.iter().enumerate() {
569 test_eq!(
570 row.get::<_, String>(0).parse::<i32>().unwrap(),
571 expect_ans[i].0
572 );
573 test_eq!(row.get::<_, i32>(1), expect_ans[i].1);
574 test_eq!(
575 row.get::<_, String>(2).parse::<i32>().unwrap(),
576 expect_ans[i].2
577 );
578 }
579
580 new_client.execute("drop table t1", &[]).await?;
581 new_client.execute("drop table t2", &[]).await?;
582 new_client.execute("drop table t3", &[]).await?;
583 Ok(())
584 }
585
586 async fn subquery_with_param(&self) -> anyhow::Result<()> {
587 let client = self.create_client(false).await?;
588
589 let res = client
590 .query("select (select $1::SMALLINT)", &[&1024_i16])
591 .await
592 .unwrap();
593
594 assert_eq!(res[0].get::<usize, i16>(0), 1024_i16);
595
596 Ok(())
597 }
598}