// risingwave_sqlsmith/test_runners/utils.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::{anyhow, bail};
16use itertools::Itertools;
17use rand::rngs::SmallRng;
18use rand::{Rng, SeedableRng};
19#[cfg(madsim)]
20use rand_chacha::ChaChaRng;
21use risingwave_sqlparser::ast::Statement;
22use tokio::time::{Duration, sleep, timeout};
23use tokio_postgres::error::Error as PgError;
24use tokio_postgres::{Client, SimpleQueryMessage};
25
26use crate::config::Configuration;
27use crate::utils::read_file_contents;
28use crate::validation::{is_permissible_error, is_recovery_in_progress_error};
29use crate::{
30    Table, generate_update_statements, insert_sql_gen, mview_sql_gen,
31    parse_create_table_statements, parse_sql, session_sql_gen, sql_gen,
32};
33
34pub(super) type PgResult<A> = std::result::Result<A, PgError>;
35pub(super) type Result<A> = anyhow::Result<A>;
36
37pub(super) async fn update_base_tables<R: Rng>(
38    client: &Client,
39    rng: &mut R,
40    base_tables: &[Table],
41    inserts: &[Statement],
42    config: &Configuration,
43) {
44    let update_statements = generate_update_statements(rng, base_tables, inserts, config).unwrap();
45    for update_statement in update_statements {
46        let sql = update_statement.to_string();
47        tracing::info!("[EXECUTING UPDATES]: {}", &sql);
48        client.simple_query(&sql).await.unwrap();
49    }
50}
51
52pub(super) async fn populate_tables<R: Rng>(
53    client: &Client,
54    rng: &mut R,
55    base_tables: Vec<Table>,
56    row_count: usize,
57    config: &Configuration,
58) -> Vec<Statement> {
59    let inserts = insert_sql_gen(rng, base_tables, row_count, config);
60    for insert in &inserts {
61        tracing::info!("[EXECUTING INSERT]: {}", insert);
62        client.simple_query(insert).await.unwrap();
63    }
64    inserts
65        .iter()
66        .map(|s| parse_sql(s).into_iter().next().unwrap())
67        .collect_vec()
68}
69
70pub(super) async fn set_variable(client: &Client, variable: &str, value: &str) -> String {
71    let s = format!("SET {variable} TO {value}");
72    tracing::info!("[EXECUTING SET_VAR]: {}", s);
73    client.simple_query(&s).await.unwrap();
74    s
75}
76
77/// Sanity checks for sqlsmith
78pub(super) async fn test_sqlsmith<R: Rng>(
79    client: &Client,
80    rng: &mut R,
81    tables: Vec<Table>,
82    base_tables: Vec<Table>,
83    row_count: usize,
84    config: &Configuration,
85) {
86    // Test inserted rows should be at least 50% population count,
87    // otherwise we don't have sufficient data in our system.
88    // ENABLE: https://github.com/risingwavelabs/risingwave/issues/3844
89    test_population_count(client, base_tables, row_count).await;
90    tracing::info!("passed population count test");
91
92    let threshold = 0.50; // permit at most 50% of queries to be skipped.
93    let sample_size = 20;
94
95    let skipped_percentage = test_batch_queries(client, rng, tables.clone(), sample_size, config)
96        .await
97        .unwrap();
98    tracing::info!(
99        "percentage of skipped batch queries = {}, threshold: {}",
100        skipped_percentage,
101        threshold
102    );
103    if skipped_percentage > threshold {
104        panic!("skipped batch queries exceeded threshold.");
105    }
106
107    let skipped_percentage = test_stream_queries(client, rng, tables.clone(), sample_size, config)
108        .await
109        .unwrap();
110    tracing::info!(
111        "percentage of skipped stream queries = {}, threshold: {}",
112        skipped_percentage,
113        threshold
114    );
115    if skipped_percentage > threshold {
116        panic!("skipped stream queries exceeded threshold.");
117    }
118}
119
120pub(super) async fn test_session_variable<R: Rng>(client: &Client, rng: &mut R) -> String {
121    let session_sql = session_sql_gen(rng);
122    tracing::info!("[EXECUTING TEST SESSION_VAR]: {}", session_sql);
123    client.simple_query(session_sql.as_str()).await.unwrap();
124    session_sql
125}
126
127/// Expects at least 50% of inserted rows included.
128pub(super) async fn test_population_count(
129    client: &Client,
130    base_tables: Vec<Table>,
131    expected_count: usize,
132) {
133    let mut actual_count = 0;
134    for t in base_tables {
135        let q = format!("select * from {};", t.name);
136        let rows = client.simple_query(&q).await.unwrap();
137        actual_count += rows.len();
138    }
139    if actual_count < expected_count / 2 {
140        panic!(
141            "expected at least 50% rows included.\
142             Total {} rows, only had {} rows",
143            expected_count, actual_count,
144        )
145    }
146}
147
148/// Test batch queries, returns skipped query statistics
149/// Runs in distributed mode, since queries can be complex and cause overflow in local execution
150/// mode.
151pub(super) async fn test_batch_queries<R: Rng>(
152    client: &Client,
153    rng: &mut R,
154    tables: Vec<Table>,
155    sample_size: usize,
156    config: &Configuration,
157) -> Result<f64> {
158    let mut skipped = 0;
159    for _ in 0..sample_size {
160        test_session_variable(client, rng).await;
161        let sql = sql_gen(rng, tables.clone(), config);
162        tracing::info!("[TEST BATCH]: {}", sql);
163        skipped += run_query(30, client, &sql).await?;
164    }
165    Ok(skipped as f64 / sample_size as f64)
166}
167
168/// Test stream queries, returns skipped query statistics
169pub(super) async fn test_stream_queries<R: Rng>(
170    client: &Client,
171    rng: &mut R,
172    tables: Vec<Table>,
173    sample_size: usize,
174    config: &Configuration,
175) -> Result<f64> {
176    let mut skipped = 0;
177
178    for _ in 0..sample_size {
179        test_session_variable(client, rng).await;
180        let (sql, table) = mview_sql_gen(rng, tables.clone(), "stream_query", config);
181        tracing::info!("[TEST STREAM]: {}", sql);
182        skipped += run_query(12, client, &sql).await?;
183        tracing::info!("[TEST DROP MVIEW]: {}", &format_drop_mview(&table));
184        drop_mview_table(&table, client).await;
185    }
186    Ok(skipped as f64 / sample_size as f64)
187}
188
189pub(super) fn get_seed_table_sql(testdata: &str) -> String {
190    let seed_files = ["tpch.sql", "nexmark.sql", "alltypes.sql"];
191    seed_files
192        .iter()
193        .map(|filename| read_file_contents(format!("{}/{}", testdata, filename)).unwrap())
194        .collect::<String>()
195}
196
197/// Create the tables defined in testdata, along with some mviews.
198/// TODO: Generate indexes and sinks.
199pub(super) async fn create_base_tables(testdata: &str, client: &Client) -> Result<Vec<Table>> {
200    tracing::info!("Preparing tables");
201
202    let sql = get_seed_table_sql(testdata);
203    let (base_tables, statements) = parse_create_table_statements(sql);
204    let mut mvs_and_base_tables = vec![];
205    mvs_and_base_tables.extend_from_slice(&base_tables);
206
207    for stmt in &statements {
208        let create_sql = stmt.to_string();
209        tracing::info!("[EXECUTING CREATE TABLE]: {}", &create_sql);
210        client.simple_query(&create_sql).await.unwrap();
211    }
212
213    Ok(base_tables)
214}
215
216/// Create the tables defined in testdata, along with some mviews.
217/// TODO: Generate indexes and sinks.
218pub(super) async fn create_mviews(
219    rng: &mut impl Rng,
220    mvs_and_base_tables: Vec<Table>,
221    client: &Client,
222    config: &Configuration,
223) -> Result<(Vec<Table>, Vec<Table>)> {
224    let mut mvs_and_base_tables = mvs_and_base_tables;
225    let mut mviews = vec![];
226    // Generate some mviews
227    for i in 0..20 {
228        let (create_sql, table) =
229            mview_sql_gen(rng, mvs_and_base_tables.clone(), &format!("m{}", i), config);
230        tracing::info!("[EXECUTING CREATE MVIEW]: {}", &create_sql);
231        let skip_count = run_query(6, client, &create_sql).await?;
232        if skip_count == 0 {
233            mvs_and_base_tables.push(table.clone());
234            mviews.push(table);
235        }
236    }
237    Ok((mvs_and_base_tables, mviews))
238}
239
240pub(super) fn format_drop_mview(mview: &Table) -> String {
241    format!("DROP MATERIALIZED VIEW IF EXISTS {}", mview.name)
242}
243
244/// Drops mview tables.
245pub(super) async fn drop_mview_table(mview: &Table, client: &Client) {
246    client
247        .simple_query(&format_drop_mview(mview))
248        .await
249        .unwrap();
250}
251
252/// Drops mview tables and seed tables.
253/// Queries the system catalog to find all existing MVs (catching any that were
254/// created server-side after a client-side timeout). Drops leaf MVs (those with
255/// no dependents) each iteration without CASCADE, so genuine dependency bugs
256/// surface as real errors rather than being silently swallowed.
257pub(super) async fn drop_tables(testdata: &str, client: &Client) {
258    tracing::info!("Cleaning tables");
259
260    // Drop MVs in the public schema using leaf-node iteration: each pass drops
261    // only MVs that no other object currently depends on. This handles inter-MV
262    // dependencies in the correct order without needing CASCADE.
263    loop {
264        let rows = client
265            .simple_query(
266                "SELECT mv.name \
267                 FROM rw_catalog.rw_materialized_views mv \
268                 JOIN rw_catalog.rw_schemas s ON mv.schema_id = s.id \
269                 WHERE s.name = 'public' \
270                   AND NOT EXISTS ( \
271                         SELECT 1 FROM rw_catalog.rw_depend d WHERE d.refobjid = mv.id \
272                       ) \
273                 ORDER BY mv.name",
274            )
275            .await
276            .unwrap();
277        let names: Vec<String> = rows
278            .into_iter()
279            .filter_map(|msg| {
280                if let SimpleQueryMessage::Row(row) = msg {
281                    row.get(0).map(|s| s.to_owned())
282                } else {
283                    None
284                }
285            })
286            .collect();
287
288        if names.is_empty() {
289            // No leaf MVs remain — verify no MVs at all are left.
290            let remaining = client
291                .simple_query(
292                    "SELECT mv.name \
293                     FROM rw_catalog.rw_materialized_views mv \
294                     JOIN rw_catalog.rw_schemas s ON mv.schema_id = s.id \
295                     WHERE s.name = 'public'",
296                )
297                .await
298                .unwrap();
299            let remaining_names: Vec<String> = remaining
300                .into_iter()
301                .filter_map(|msg| {
302                    if let SimpleQueryMessage::Row(row) = msg {
303                        row.get(0).map(|s| s.to_owned())
304                    } else {
305                        None
306                    }
307                })
308                .collect();
309            if !remaining_names.is_empty() {
310                panic!(
311                    "MV cleanup stalled: no leaf MVs but these still exist: {:?}. \
312                     This indicates an unexpected dependency on a non-MV object.",
313                    remaining_names
314                );
315            }
316            break;
317        }
318
319        for name in &names {
320            tracing::info!("Dropping materialized view: {}", name);
321            client
322                .simple_query(&format!("DROP MATERIALIZED VIEW {name}"))
323                .await
324                .unwrap();
325        }
326    }
327
328    let seed_files = ["drop_tpch.sql", "drop_nexmark.sql", "drop_alltypes.sql"];
329    let sql = seed_files
330        .iter()
331        .map(|filename| read_file_contents(format!("{}/{}", testdata, filename)).unwrap())
332        .collect::<String>();
333
334    for stmt in sql.lines() {
335        client.simple_query(stmt).await.unwrap();
336    }
337}
338
339/// Validate client responses, returning a count of skipped queries, number of result rows.
340pub(super) fn validate_response(
341    response: PgResult<Vec<SimpleQueryMessage>>,
342) -> Result<(i64, Vec<SimpleQueryMessage>)> {
343    match response {
344        Ok(rows) => Ok((0, rows)),
345        Err(e) => {
346            // Permit runtime errors conservatively.
347            if let Some(e) = e.as_db_error()
348                && is_permissible_error(&e.to_string())
349            {
350                tracing::info!("[SKIPPED ERROR]: {:#?}", e);
351                return Ok((1, vec![]));
352            }
353            // consolidate error reason for deterministic test
354            tracing::info!("[UNEXPECTED ERROR]: {:#?}", e);
355            Err(anyhow!("Encountered unexpected error: {e}"))
356        }
357    }
358}
359
360pub(super) async fn run_query(timeout_duration: u64, client: &Client, query: &str) -> Result<i64> {
361    let (skipped_count, _) = run_query_inner(timeout_duration, client, query, true).await?;
362    Ok(skipped_count)
363}
364
365/// Run query, handle permissible errors
366/// For recovery error, just do bounded retry.
367/// For other errors, validate them accordingly, skipping if they are permitted.
368/// Otherwise just return success.
369/// If takes too long return the query which timed out + execution time + timeout error
370/// Returns: Number of skipped queries, number of rows returned.
371pub(super) async fn run_query_inner(
372    timeout_duration: u64,
373    client: &Client,
374    query: &str,
375    skip_timeout: bool,
376) -> Result<(i64, Vec<SimpleQueryMessage>)> {
377    let query_task = client.simple_query(query);
378    let result = timeout(Duration::from_secs(timeout_duration), query_task).await;
379    let response = match result {
380        Ok(r) => r,
381        Err(_) => {
382            if skip_timeout {
383                return Ok((1, vec![]));
384            } else {
385                bail!(
386                    "[UNEXPECTED ERROR] Query timeout after {timeout_duration}s:\n{:?}",
387                    query
388                )
389            }
390        }
391    };
392    if let Err(e) = &response
393        && let Some(e) = e.as_db_error()
394    {
395        if is_recovery_in_progress_error(&e.to_string()) {
396            let tries = 5;
397            let interval = 1;
398            for _ in 0..tries {
399                // retry 5 times
400                sleep(Duration::from_secs(interval)).await;
401                let query_task = client.simple_query(query);
402                let response = timeout(Duration::from_secs(timeout_duration), query_task).await;
403                match response {
404                    Ok(Ok(r)) => {
405                        return Ok((0, r));
406                    }
407                    Err(_) => bail!(
408                        "[UNEXPECTED ERROR] Query timeout after {timeout_duration}s:\n{:?}",
409                        query
410                    ),
411                    _ => {}
412                }
413            }
414            bail!(
415                "[UNEXPECTED ERROR] Failed to recover after {tries} tries with interval {interval}s"
416            )
417        } else {
418            return validate_response(response);
419        }
420    }
421    let rows = response?;
422    Ok((0, rows))
423}
424
425pub(super) fn generate_rng(seed: Option<u64>) -> impl Rng {
426    #[cfg(madsim)]
427    if let Some(seed) = seed {
428        ChaChaRng::seed_from_u64(seed)
429    } else {
430        ChaChaRng::from_rng(&mut SmallRng::from_os_rng())
431    }
432    #[cfg(not(madsim))]
433    if let Some(seed) = seed {
434        SmallRng::seed_from_u64(seed)
435    } else {
436        SmallRng::from_os_rng()
437    }
438}