risingwave_sqlsmith/test_runners/
utils.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::{anyhow, bail};
16use itertools::Itertools;
17use rand::rngs::SmallRng;
18use rand::{Rng, SeedableRng};
19#[cfg(madsim)]
20use rand_chacha::ChaChaRng;
21use risingwave_sqlparser::ast::Statement;
22use tokio::time::{Duration, sleep, timeout};
23use tokio_postgres::error::Error as PgError;
24use tokio_postgres::{Client, SimpleQueryMessage};
25
26use crate::config::Configuration;
27use crate::utils::read_file_contents;
28use crate::validation::{is_permissible_error, is_recovery_in_progress_error};
29use crate::{
30    Table, generate_update_statements, insert_sql_gen, mview_sql_gen,
31    parse_create_table_statements, parse_sql, session_sql_gen, sql_gen,
32};
33
34pub(super) type PgResult<A> = std::result::Result<A, PgError>;
35pub(super) type Result<A> = anyhow::Result<A>;
36
37pub(super) async fn update_base_tables<R: Rng>(
38    client: &Client,
39    rng: &mut R,
40    base_tables: &[Table],
41    inserts: &[Statement],
42    config: &Configuration,
43) {
44    let update_statements = generate_update_statements(rng, base_tables, inserts, config).unwrap();
45    for update_statement in update_statements {
46        let sql = update_statement.to_string();
47        tracing::info!("[EXECUTING UPDATES]: {}", &sql);
48        client.simple_query(&sql).await.unwrap();
49    }
50}
51
52pub(super) async fn populate_tables<R: Rng>(
53    client: &Client,
54    rng: &mut R,
55    base_tables: Vec<Table>,
56    row_count: usize,
57    config: &Configuration,
58) -> Vec<Statement> {
59    let inserts = insert_sql_gen(rng, base_tables, row_count, config);
60    for insert in &inserts {
61        tracing::info!("[EXECUTING INSERT]: {}", insert);
62        client.simple_query(insert).await.unwrap();
63    }
64    inserts
65        .iter()
66        .map(|s| parse_sql(s).into_iter().next().unwrap())
67        .collect_vec()
68}
69
70pub(super) async fn set_variable(client: &Client, variable: &str, value: &str) -> String {
71    let s = format!("SET {variable} TO {value}");
72    tracing::info!("[EXECUTING SET_VAR]: {}", s);
73    client.simple_query(&s).await.unwrap();
74    s
75}
76
77/// Sanity checks for sqlsmith
78pub(super) async fn test_sqlsmith<R: Rng>(
79    client: &Client,
80    rng: &mut R,
81    tables: Vec<Table>,
82    base_tables: Vec<Table>,
83    row_count: usize,
84    config: &Configuration,
85) {
86    // Test inserted rows should be at least 50% population count,
87    // otherwise we don't have sufficient data in our system.
88    // ENABLE: https://github.com/risingwavelabs/risingwave/issues/3844
89    test_population_count(client, base_tables, row_count).await;
90    tracing::info!("passed population count test");
91
92    let threshold = 0.50; // permit at most 50% of queries to be skipped.
93    let sample_size = 20;
94
95    let skipped_percentage = test_batch_queries(client, rng, tables.clone(), sample_size, config)
96        .await
97        .unwrap();
98    tracing::info!(
99        "percentage of skipped batch queries = {}, threshold: {}",
100        skipped_percentage,
101        threshold
102    );
103    if skipped_percentage > threshold {
104        panic!("skipped batch queries exceeded threshold.");
105    }
106
107    let skipped_percentage = test_stream_queries(client, rng, tables.clone(), sample_size, config)
108        .await
109        .unwrap();
110    tracing::info!(
111        "percentage of skipped stream queries = {}, threshold: {}",
112        skipped_percentage,
113        threshold
114    );
115    if skipped_percentage > threshold {
116        panic!("skipped stream queries exceeded threshold.");
117    }
118}
119
120pub(super) async fn test_session_variable<R: Rng>(client: &Client, rng: &mut R) -> String {
121    let session_sql = session_sql_gen(rng);
122    tracing::info!("[EXECUTING TEST SESSION_VAR]: {}", session_sql);
123    client.simple_query(session_sql.as_str()).await.unwrap();
124    session_sql
125}
126
127/// Expects at least 50% of inserted rows included.
128pub(super) async fn test_population_count(
129    client: &Client,
130    base_tables: Vec<Table>,
131    expected_count: usize,
132) {
133    let mut actual_count = 0;
134    for t in base_tables {
135        let q = format!("select * from {};", t.name);
136        let rows = client.simple_query(&q).await.unwrap();
137        actual_count += rows.len();
138    }
139    if actual_count < expected_count / 2 {
140        panic!(
141            "expected at least 50% rows included.\
142             Total {} rows, only had {} rows",
143            expected_count, actual_count,
144        )
145    }
146}
147
148/// Test batch queries, returns skipped query statistics
149/// Runs in distributed mode, since queries can be complex and cause overflow in local execution
150/// mode.
151pub(super) async fn test_batch_queries<R: Rng>(
152    client: &Client,
153    rng: &mut R,
154    tables: Vec<Table>,
155    sample_size: usize,
156    config: &Configuration,
157) -> Result<f64> {
158    let mut skipped = 0;
159    for _ in 0..sample_size {
160        test_session_variable(client, rng).await;
161        let sql = sql_gen(rng, tables.clone(), config);
162        tracing::info!("[TEST BATCH]: {}", sql);
163        skipped += run_query(30, client, &sql).await?;
164    }
165    Ok(skipped as f64 / sample_size as f64)
166}
167
168/// Test stream queries, returns skipped query statistics
169pub(super) async fn test_stream_queries<R: Rng>(
170    client: &Client,
171    rng: &mut R,
172    tables: Vec<Table>,
173    sample_size: usize,
174    config: &Configuration,
175) -> Result<f64> {
176    let mut skipped = 0;
177
178    for _ in 0..sample_size {
179        test_session_variable(client, rng).await;
180        let (sql, table) = mview_sql_gen(rng, tables.clone(), "stream_query", config);
181        tracing::info!("[TEST STREAM]: {}", sql);
182        skipped += run_query(12, client, &sql).await?;
183        tracing::info!("[TEST DROP MVIEW]: {}", &format_drop_mview(&table));
184        drop_mview_table(&table, client).await;
185    }
186    Ok(skipped as f64 / sample_size as f64)
187}
188
189pub(super) fn get_seed_table_sql(testdata: &str) -> String {
190    let seed_files = ["tpch.sql", "nexmark.sql", "alltypes.sql"];
191    seed_files
192        .iter()
193        .map(|filename| read_file_contents(format!("{}/{}", testdata, filename)).unwrap())
194        .collect::<String>()
195}
196
197/// Create the tables defined in testdata, along with some mviews.
198/// TODO: Generate indexes and sinks.
199pub(super) async fn create_base_tables(testdata: &str, client: &Client) -> Result<Vec<Table>> {
200    tracing::info!("Preparing tables...");
201
202    let sql = get_seed_table_sql(testdata);
203    let (base_tables, statements) = parse_create_table_statements(sql);
204    let mut mvs_and_base_tables = vec![];
205    mvs_and_base_tables.extend_from_slice(&base_tables);
206
207    for stmt in &statements {
208        let create_sql = stmt.to_string();
209        tracing::info!("[EXECUTING CREATE TABLE]: {}", &create_sql);
210        client.simple_query(&create_sql).await.unwrap();
211    }
212
213    Ok(base_tables)
214}
215
216/// Create the tables defined in testdata, along with some mviews.
217/// TODO: Generate indexes and sinks.
218pub(super) async fn create_mviews(
219    rng: &mut impl Rng,
220    mvs_and_base_tables: Vec<Table>,
221    client: &Client,
222    config: &Configuration,
223) -> Result<(Vec<Table>, Vec<Table>)> {
224    let mut mvs_and_base_tables = mvs_and_base_tables;
225    let mut mviews = vec![];
226    // Generate some mviews
227    for i in 0..20 {
228        let (create_sql, table) =
229            mview_sql_gen(rng, mvs_and_base_tables.clone(), &format!("m{}", i), config);
230        tracing::info!("[EXECUTING CREATE MVIEW]: {}", &create_sql);
231        let skip_count = run_query(6, client, &create_sql).await?;
232        if skip_count == 0 {
233            mvs_and_base_tables.push(table.clone());
234            mviews.push(table);
235        }
236    }
237    Ok((mvs_and_base_tables, mviews))
238}
239
240pub(super) fn format_drop_mview(mview: &Table) -> String {
241    format!("DROP MATERIALIZED VIEW IF EXISTS {}", mview.name)
242}
243
244/// Drops mview tables.
245pub(super) async fn drop_mview_table(mview: &Table, client: &Client) {
246    client
247        .simple_query(&format_drop_mview(mview))
248        .await
249        .unwrap();
250}
251
252/// Drops mview tables and seed tables
253pub(super) async fn drop_tables(mviews: &[Table], testdata: &str, client: &Client) {
254    tracing::info!("Cleaning tables...");
255
256    for mview in mviews.iter().rev() {
257        drop_mview_table(mview, client).await;
258    }
259
260    let seed_files = ["drop_tpch.sql", "drop_nexmark.sql", "drop_alltypes.sql"];
261    let sql = seed_files
262        .iter()
263        .map(|filename| read_file_contents(format!("{}/{}", testdata, filename)).unwrap())
264        .collect::<String>();
265
266    for stmt in sql.lines() {
267        client.simple_query(stmt).await.unwrap();
268    }
269}
270
271/// Validate client responses, returning a count of skipped queries, number of result rows.
272pub(super) fn validate_response(
273    response: PgResult<Vec<SimpleQueryMessage>>,
274) -> Result<(i64, Vec<SimpleQueryMessage>)> {
275    match response {
276        Ok(rows) => Ok((0, rows)),
277        Err(e) => {
278            // Permit runtime errors conservatively.
279            if let Some(e) = e.as_db_error()
280                && is_permissible_error(&e.to_string())
281            {
282                tracing::info!("[SKIPPED ERROR]: {:#?}", e);
283                return Ok((1, vec![]));
284            }
285            // consolidate error reason for deterministic test
286            tracing::info!("[UNEXPECTED ERROR]: {:#?}", e);
287            Err(anyhow!("Encountered unexpected error: {e}"))
288        }
289    }
290}
291
292pub(super) async fn run_query(timeout_duration: u64, client: &Client, query: &str) -> Result<i64> {
293    let (skipped_count, _) = run_query_inner(timeout_duration, client, query, true).await?;
294    Ok(skipped_count)
295}
296
/// Run query, handle permissible errors
/// For recovery error, just do bounded retry.
/// For other errors, validate them accordingly, skipping if they are permitted.
/// Otherwise just return success.
/// If takes too long return the query which timed out + execution time + timeout error
/// Returns: Number of skipped queries, number of rows returned.
pub(super) async fn run_query_inner(
    timeout_duration: u64,
    client: &Client,
    query: &str,
    skip_timeout: bool,
) -> Result<(i64, Vec<SimpleQueryMessage>)> {
    // Race the query against a wall-clock timeout of `timeout_duration` seconds.
    let query_task = client.simple_query(query);
    let result = timeout(Duration::from_secs(timeout_duration), query_task).await;
    let response = match result {
        Ok(r) => r,
        Err(_) => {
            // Timed out: either count it as one skipped query, or fail hard,
            // depending on the caller's `skip_timeout` choice.
            if skip_timeout {
                return Ok((1, vec![]));
            } else {
                bail!(
                    "[UNEXPECTED ERROR] Query timeout after {timeout_duration}s:\n{:?}",
                    query
                )
            }
        }
    };
    if let Err(e) = &response
        && let Some(e) = e.as_db_error()
    {
        if is_recovery_in_progress_error(&e.to_string()) {
            // Cluster is recovering: re-run the query (with the same timeout)
            // a bounded number of times before giving up.
            let tries = 5;
            let interval = 1;
            for _ in 0..tries {
                // retry 5 times
                sleep(Duration::from_secs(interval)).await;
                let query_task = client.simple_query(query);
                let response = timeout(Duration::from_secs(timeout_duration), query_task).await;
                match response {
                    Ok(Ok(r)) => {
                        return Ok((0, r));
                    }
                    // A timeout during retry is always fatal, regardless of
                    // `skip_timeout`.
                    Err(_) => bail!(
                        "[UNEXPECTED ERROR] Query timeout after {timeout_duration}s:\n{:?}",
                        query
                    ),
                    // Ok(Err(_)): still erroring (e.g. recovery ongoing) — keep retrying.
                    _ => {}
                }
            }
            bail!(
                "[UNEXPECTED ERROR] Failed to recover after {tries} tries with interval {interval}s"
            )
        } else {
            // Non-recovery DB error: classify as skipped (permissible) or
            // propagate as unexpected.
            return validate_response(response);
        }
    }
    // Success, or a non-DB-level error — the latter propagates via `?`.
    let rows = response?;
    Ok((0, rows))
}
356
/// Constructs the RNG used for query generation: deterministic when `seed`
/// is `Some` (reproducible runs), otherwise seeded from the OS.
///
/// Under `cfg(madsim)` a `ChaChaRng` is used; otherwise `SmallRng`.
pub(super) fn generate_rng(seed: Option<u64>) -> impl Rng {
    // NOTE: exactly one of the two cfg'd `if` expressions survives
    // compilation, making it the function's tail expression — do not add
    // statements after them.
    #[cfg(madsim)]
    if let Some(seed) = seed {
        ChaChaRng::seed_from_u64(seed)
    } else {
        ChaChaRng::from_rng(&mut SmallRng::from_os_rng())
    }
    #[cfg(not(madsim))]
    if let Some(seed) = seed {
        SmallRng::seed_from_u64(seed)
    } else {
        SmallRng::from_os_rng()
    }
}