risingwave_e2e_extended_mode_test/
test.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::anyhow;
16use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
17use pg_interval::Interval;
18use rust_decimal::Decimal;
19use tokio::select;
20use tokio_postgres::types::Type;
21use tokio_postgres::{Client, NoTls};
22
/// End-to-end test suite for RisingWave's extended query protocol support
/// (prepared statements, bound parameters and portals).
pub struct TestSuite {
    /// `key=value` connection string handed to `tokio_postgres::connect` for
    /// every client the suite opens.
    config: String,
}
26
// Non-panicking counterpart of `assert_eq!`: when the two operands differ it
// returns early from the enclosing function with `Err(anyhow!(..))`, naming
// the call site (`file!()`/`line!()`) and both operands' `Debug` output.
// The enclosing function therefore has to return a `Result` whose error type
// accepts what `anyhow!` produces (an `anyhow::Result` in this file).
macro_rules! test_eq {
    ($left:expr, $right:expr $(,)?) => {
        // Borrow both operands inside a `match` (as `assert_eq!` does) so any
        // temporaries they create stay alive for the comparison and message.
        match (&$left, &$right) {
            (lhs, rhs) if !(*lhs == *rhs) => {
                return Err(anyhow!(
                    "{file}:{line} assertion failed: `(left == right)` \
                     (left: `{l:?}`, right: `{r:?}`)",
                    file = file!(),
                    line = line!(),
                    l = lhs,
                    r = rhs,
                ));
            }
            _ => {}
        }
    };
}
45
46impl TestSuite {
47    pub fn new(
48        db_name: String,
49        user_name: String,
50        server_host: String,
51        server_port: u16,
52        password: String,
53    ) -> Self {
54        let config = if !password.is_empty() {
55            format!(
56                "dbname={} user={} host={} port={} password={}",
57                db_name, user_name, server_host, server_port, password
58            )
59        } else {
60            format!(
61                "dbname={} user={} host={} port={}",
62                db_name, user_name, server_host, server_port
63            )
64        };
65        Self { config }
66    }
67
68    fn init_logger() {
69        let _ = tracing_subscriber::fmt()
70            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
71            .with_ansi(false)
72            .try_init();
73    }
74
75    pub async fn test(&self) -> anyhow::Result<()> {
76        Self::init_logger();
77        self.binary_param_and_result().await?;
78        self.dql_dml_with_param().await?;
79        self.max_row().await?;
80        self.multiple_on_going_portal().await?;
81        self.create_with_parameter().await?;
82        self.simple_cancel(false).await?;
83        self.simple_cancel(true).await?;
84        self.complex_cancel(false).await?;
85        self.complex_cancel(true).await?;
86        self.subquery_with_param().await?;
87        self.create_mview_with_parameter().await?;
88        Ok(())
89    }
90
91    async fn create_client(&self, is_distributed: bool) -> anyhow::Result<Client> {
92        let (client, connection) = tokio_postgres::connect(&self.config, NoTls).await?;
93
94        // The connection object performs the actual communication with the database,
95        // so spawn it off to run on its own.
96        tokio::spawn(async move {
97            if let Err(e) = connection.await {
98                eprintln!("connection error: {}", e);
99            }
100        });
101
102        if is_distributed {
103            client.execute("set query_mode = distributed", &[]).await?;
104        } else {
105            client.execute("set query_mode = local", &[]).await?;
106        }
107
108        Ok(client)
109    }
110
111    pub async fn binary_param_and_result(&self) -> anyhow::Result<()> {
112        let client = self.create_client(false).await?;
113
114        for row in client.query("select $1::SMALLINT;", &[&1024_i16]).await? {
115            let data: i16 = row.try_get(0)?;
116            test_eq!(data, 1024);
117        }
118
119        for row in client.query("select $1::INT;", &[&144232_i32]).await? {
120            let data: i32 = row.try_get(0)?;
121            test_eq!(data, 144232);
122        }
123
124        for row in client.query("select $1::BIGINT;", &[&99999999_i64]).await? {
125            let data: i64 = row.try_get(0)?;
126            test_eq!(data, 99999999);
127        }
128
129        for row in client
130            .query(
131                "select $1::DECIMAL;",
132                &[&Decimal::try_from(2.33454_f32).ok()],
133            )
134            .await?
135        {
136            let data: Decimal = row.try_get(0)?;
137            test_eq!(data, Decimal::try_from(2.33454_f32).unwrap());
138        }
139
140        for row in client.query("select $1::BOOL;", &[&true]).await? {
141            let data: bool = row.try_get(0)?;
142            assert!(data);
143        }
144
145        for row in client.query("select $1::REAL;", &[&1.234234_f32]).await? {
146            let data: f32 = row.try_get(0)?;
147            test_eq!(data, 1.234234);
148        }
149
150        for row in client
151            .query("select $1::DOUBLE PRECISION;", &[&234234.23490238483_f64])
152            .await?
153        {
154            let data: f64 = row.try_get(0)?;
155            test_eq!(data, 234234.23490238483);
156        }
157
158        for row in client
159            .query(
160                "select $1::date;",
161                &[&NaiveDate::from_ymd_opt(2022, 1, 1).unwrap()],
162            )
163            .await?
164        {
165            let data: NaiveDate = row.try_get(0)?;
166            test_eq!(data, NaiveDate::from_ymd_opt(2022, 1, 1).unwrap());
167        }
168
169        for row in client
170            .query(
171                "select $1::time",
172                &[&NaiveTime::from_hms_opt(10, 0, 0).unwrap()],
173            )
174            .await?
175        {
176            let data: NaiveTime = row.try_get(0)?;
177            test_eq!(data, NaiveTime::from_hms_opt(10, 0, 0).unwrap());
178        }
179
180        for row in client
181            .query(
182                "select $1::timestamp",
183                &[&NaiveDate::from_ymd_opt(2022, 1, 1)
184                    .unwrap()
185                    .and_hms_opt(10, 0, 0)
186                    .unwrap()],
187            )
188            .await?
189        {
190            let data: NaiveDateTime = row.try_get(0)?;
191            test_eq!(
192                data,
193                NaiveDate::from_ymd_opt(2022, 1, 1)
194                    .unwrap()
195                    .and_hms_opt(10, 0, 0)
196                    .unwrap()
197            );
198        }
199
200        let timestamptz = DateTime::<Utc>::from_naive_utc_and_offset(
201            NaiveDate::from_ymd_opt(2022, 1, 1)
202                .unwrap()
203                .and_hms_opt(10, 0, 0)
204                .unwrap(),
205            Utc,
206        );
207        for row in client
208            .query("select $1::timestamptz", &[&timestamptz])
209            .await?
210        {
211            let data: DateTime<Utc> = row.try_get(0)?;
212            test_eq!(data, timestamptz);
213        }
214
215        for row in client
216            .query("select $1::interval", &[&Interval::new(1, 1, 24000000)])
217            .await?
218        {
219            let data: Interval = row.try_get(0)?;
220            test_eq!(data, Interval::new(1, 1, 24000000));
221        }
222
223        Ok(())
224    }
225
226    async fn dql_dml_with_param(&self) -> anyhow::Result<()> {
227        let client = self.create_client(false).await?;
228
229        client.query("create table t(id int)", &[]).await?;
230
231        let insert_statement = client
232            .prepare_typed("insert INTO t (id) VALUES ($1)", &[])
233            .await?;
234
235        for i in 0..20 {
236            client.execute(&insert_statement, &[&i]).await?;
237        }
238        client.execute("flush", &[]).await?;
239
240        let update_statement = client
241            .prepare_typed(
242                "update t set id = $1 where id < $2",
243                &[Type::INT4, Type::INT4],
244            )
245            .await?;
246        let query_statement = client
247            .prepare_typed(
248                "select * FROM t where id < $1 order by id ASC",
249                &[Type::INT4],
250            )
251            .await?;
252        let delete_statement = client
253            .prepare_typed("delete FROM t where id < $1", &[Type::INT4])
254            .await?;
255
256        let mut i = 0;
257        for row in client.query(&query_statement, &[&10_i32]).await? {
258            let id: i32 = row.try_get(0)?;
259            test_eq!(id, i);
260            i += 1;
261        }
262        test_eq!(i, 10);
263
264        client
265            .execute(&update_statement, &[&100_i32, &10_i32])
266            .await?;
267        client.execute("flush", &[]).await?;
268
269        let mut i = 0;
270        for _ in client.query(&query_statement, &[&10_i32]).await? {
271            i += 1;
272        }
273        test_eq!(i, 0);
274
275        client.execute(&delete_statement, &[&20_i32]).await?;
276        client.execute("flush", &[]).await?;
277
278        let mut i = 0;
279        for row in client.query(&query_statement, &[&101_i32]).await? {
280            let id: i32 = row.try_get(0)?;
281            test_eq!(id, 100);
282            i += 1;
283        }
284        test_eq!(i, 10);
285
286        client.execute("drop table t", &[]).await?;
287
288        Ok(())
289    }
290
291    async fn max_row(&self) -> anyhow::Result<()> {
292        let mut client = self.create_client(false).await?;
293
294        client.query("create table t(id int)", &[]).await?;
295
296        let insert_statement = client
297            .prepare_typed("insert INTO t (id) VALUES ($1)", &[])
298            .await?;
299
300        for i in 0..10 {
301            client.execute(&insert_statement, &[&i]).await?;
302        }
303        client.execute("flush", &[]).await?;
304
305        let transaction = client.transaction().await?;
306        let statement = transaction
307            .prepare_typed("SELECT * FROM t order by id", &[])
308            .await?;
309        let portal = transaction.bind(&statement, &[]).await?;
310
311        for t in 0..5 {
312            let rows = transaction.query_portal(&portal, 1).await?;
313            test_eq!(rows.len(), 1);
314            let row = rows.first().unwrap();
315            let id: i32 = row.get(0);
316            test_eq!(id, t);
317        }
318
319        let mut i = 5;
320        for row in transaction.query_portal(&portal, 3).await? {
321            let id: i32 = row.get(0);
322            test_eq!(id, i);
323            i += 1;
324        }
325        test_eq!(i, 8);
326
327        for row in transaction.query_portal(&portal, 5).await? {
328            let id: i32 = row.get(0);
329            test_eq!(id, i);
330            i += 1;
331        }
332        test_eq!(i, 10);
333
334        transaction.rollback().await?;
335
336        client.execute("drop table t", &[]).await?;
337
338        Ok(())
339    }
340
341    async fn multiple_on_going_portal(&self) -> anyhow::Result<()> {
342        let mut client = self.create_client(false).await?;
343
344        let transaction = client.transaction().await?;
345        let statement = transaction
346            .prepare_typed("SELECT generate_series(1,5,1)", &[])
347            .await?;
348        let portal_1 = transaction.bind(&statement, &[]).await?;
349        let portal_2 = transaction.bind(&statement, &[]).await?;
350
351        let rows = transaction.query_portal(&portal_1, 1).await?;
352        test_eq!(rows.len(), 1);
353        test_eq!(rows.first().unwrap().get::<usize, i32>(0), 1);
354
355        let rows = transaction.query_portal(&portal_2, 1).await?;
356        test_eq!(rows.len(), 1);
357        test_eq!(rows.first().unwrap().get::<usize, i32>(0), 1);
358
359        let rows = transaction.query_portal(&portal_2, 3).await?;
360        test_eq!(rows.len(), 3);
361        test_eq!(rows.first().unwrap().get::<usize, i32>(0), 2);
362        test_eq!(rows.get(1).unwrap().get::<usize, i32>(0), 3);
363        test_eq!(rows.get(2).unwrap().get::<usize, i32>(0), 4);
364
365        let rows = transaction.query_portal(&portal_1, 1).await?;
366        test_eq!(rows.len(), 1);
367        test_eq!(rows.first().unwrap().get::<usize, i32>(0), 2);
368
369        Ok(())
370    }
371
372    // Can't support these sql
373    async fn create_with_parameter(&self) -> anyhow::Result<()> {
374        let client = self.create_client(false).await?;
375
376        test_eq!(
377            client
378                .query("create table t as select $1", &[])
379                .await
380                .is_err(),
381            true
382        );
383        test_eq!(
384            client
385                .query("create view v as select $1", &[])
386                .await
387                .is_err(),
388            true
389        );
390
391        Ok(())
392    }
393
394    async fn create_mview_with_parameter(&self) -> anyhow::Result<()> {
395        let client = self.create_client(false).await?;
396
397        let statement = client
398            .prepare_typed(
399                "create materialized view mv as select $1 as x",
400                &[Type::INT4],
401            )
402            .await?;
403
404        client.execute(&statement, &[&42_i32]).await?;
405
406        let rows = client.query("select * from mv", &[]).await?;
407        test_eq!(rows.len(), 1);
408        test_eq!(rows.first().unwrap().get::<usize, i32>(0), 42);
409
410        // Test renaming mv because it relies on parsing and rewrite the `create MV` query
411        client
412            .execute("alter materialized view mv rename to mv2", &[])
413            .await?;
414
415        let rows = client.query("select * from mv2", &[]).await?;
416        test_eq!(rows.len(), 1);
417        test_eq!(rows.first().unwrap().get::<usize, i32>(0), 42);
418
419        client.execute("drop materialized view mv2", &[]).await?;
420
421        Ok(())
422    }
423
424    async fn simple_cancel(&self, is_distributed: bool) -> anyhow::Result<()> {
425        let client = self.create_client(is_distributed).await?;
426        client.execute("create table t(id int)", &[]).await?;
427
428        let insert_statement = client
429            .prepare_typed("insert INTO t (id) VALUES ($1)", &[])
430            .await?;
431
432        for i in 0..1000 {
433            client.execute(&insert_statement, &[&i]).await?;
434        }
435
436        client.execute("flush", &[]).await?;
437
438        let cancel_token = client.cancel_token();
439
440        let query_handle = tokio::spawn(async move {
441            client.query("select * from t", &[]).await.unwrap();
442        });
443
444        select! {
445            _ = query_handle => {
446                tracing::error!("Failed to cancel query")
447            },
448            _ = cancel_token.cancel_query(NoTls) => {
449                tracing::trace!("Cancel query successfully")
450            },
451        }
452
453        let new_client = self.create_client(is_distributed).await?;
454
455        let rows = new_client
456            .query("select * from t order by id limit 10", &[])
457            .await?;
458
459        test_eq!(rows.len(), 10);
460        for (expect_id, row) in rows.iter().enumerate() {
461            let id: i32 = row.get(0);
462            test_eq!(id, expect_id as i32);
463        }
464
465        new_client.execute("drop table t", &[]).await?;
466
467        Ok(())
468    }
469
470    async fn complex_cancel(&self, is_distributed: bool) -> anyhow::Result<()> {
471        let client = self.create_client(is_distributed).await?;
472
473        client
474            .execute("create table t1(name varchar, id int)", &[])
475            .await?;
476        client
477            .execute("create table t2(name varchar, id int)", &[])
478            .await?;
479        client
480            .execute("create table t3(name varchar, id int)", &[])
481            .await?;
482
483        let insert_statement = client
484            .prepare_typed("insert INTO t1 (name, id) VALUES ($1, $2)", &[])
485            .await?;
486        let insert_statement2 = client
487            .prepare_typed("insert INTO t2 (name, id) VALUES ($1, $2)", &[])
488            .await?;
489        let insert_statement3 = client
490            .prepare_typed("insert INTO t3 (name, id) VALUES ($1, $2)", &[])
491            .await?;
492        for i in 0..1000 {
493            client
494                .execute(&insert_statement, &[&i.to_string(), &i])
495                .await?;
496            client
497                .execute(&insert_statement2, &[&i.to_string(), &i])
498                .await?;
499            client
500                .execute(&insert_statement3, &[&i.to_string(), &i])
501                .await?;
502        }
503
504        client.execute("flush", &[]).await?;
505
506        client.execute("set query_mode=local", &[]).await?;
507
508        let cancel_token = client.cancel_token();
509
510        let query_sql = "SELECT t1.name, t2.id, t3.name
511        FROM t1
512        INNER JOIN (
513          SELECT id, name
514          FROM t2
515          WHERE id IN (
516            SELECT id
517            FROM t1
518            WHERE name LIKE '%1%'
519          )
520        ) AS t2 ON t1.id = t2.id
521        LEFT JOIN t3 ON t2.name = t3.name
522        WHERE t3.id IN (
523          SELECT MAX(id)
524          FROM t3
525          GROUP BY name
526        )
527        ORDER BY t1.name ASC, t3.id DESC
528        ";
529
530        let query_handle = tokio::spawn(async move {
531            let result = client.query(query_sql, &[]).await;
532            match result {
533                Ok(_) => {
534                    tracing::error!("Query should be canceled");
535                }
536                Err(e) => {
537                    tracing::error!("Query failed with error: {:?}", e);
538                }
539            };
540        });
541
542        select! {
543            _ = query_handle => {
544                tracing::error!("Failed to cancel query")
545            },
546            _ = cancel_token.cancel_query(NoTls) => {
547                tracing::info!("Cancel query successfully")
548            },
549        }
550
551        let new_client = self.create_client(is_distributed).await?;
552
553        let rows = new_client
554            .query(&format!("{} LIMIT 10", query_sql), &[])
555            .await?;
556        let expect_ans = [
557            (1, 1, 1),
558            (10, 10, 10),
559            (100, 100, 100),
560            (101, 101, 101),
561            (102, 102, 102),
562            (103, 103, 103),
563            (104, 104, 104),
564            (105, 105, 105),
565            (106, 106, 106),
566            (107, 107, 107),
567        ];
568        for (i, row) in rows.iter().enumerate() {
569            test_eq!(
570                row.get::<_, String>(0).parse::<i32>().unwrap(),
571                expect_ans[i].0
572            );
573            test_eq!(row.get::<_, i32>(1), expect_ans[i].1);
574            test_eq!(
575                row.get::<_, String>(2).parse::<i32>().unwrap(),
576                expect_ans[i].2
577            );
578        }
579
580        new_client.execute("drop table t1", &[]).await?;
581        new_client.execute("drop table t2", &[]).await?;
582        new_client.execute("drop table t3", &[]).await?;
583        Ok(())
584    }
585
586    async fn subquery_with_param(&self) -> anyhow::Result<()> {
587        let client = self.create_client(false).await?;
588
589        let res = client
590            .query("select (select $1::SMALLINT)", &[&1024_i16])
591            .await
592            .unwrap();
593
594        assert_eq!(res[0].get::<usize, i16>(0), 1024_i16);
595
596        Ok(())
597    }
598}