risingwave_sqlsmith/test_runners/utils.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use anyhow::{anyhow, bail};
use itertools::Itertools;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
#[cfg(madsim)]
use rand_chacha::ChaChaRng;
use risingwave_sqlparser::ast::Statement;
use tokio::time::{Duration, sleep, timeout};
use tokio_postgres::error::Error as PgError;
use tokio_postgres::{Client, SimpleQueryMessage};

use crate::utils::read_file_contents;
use crate::validation::{is_permissible_error, is_recovery_in_progress_error};
use crate::{
    Table, generate_update_statements, insert_sql_gen, mview_sql_gen,
    parse_create_table_statements, parse_sql, session_sql_gen, sql_gen,
};

pub(super) type PgResult<A> = std::result::Result<A, PgError>;
pub(super) type Result<A> = anyhow::Result<A>;

/// Generate UPDATE statements derived from the earlier inserts and run them
/// against the base tables.
pub(super) async fn update_base_tables<R: Rng>(
    client: &Client,
    rng: &mut R,
    base_tables: &[Table],
    inserts: &[Statement],
) {
    let update_statements = generate_update_statements(rng, base_tables, inserts).unwrap();
    for update_statement in update_statements {
        let sql = update_statement.to_string();
        tracing::info!("[EXECUTING UPDATES]: {}", &sql);
        client.simple_query(&sql).await.unwrap();
    }
}

/// Generate and run INSERT statements for each base table, returning the
/// parsed insert statements for later reuse.
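/// A minimal flow sketch (hypothetical `client`, `rng`, and `base_tables`; the
/// returned statements can later feed `update_base_tables`):
/// ```ignore
/// let inserts = populate_tables(&client, &mut rng, base_tables.clone(), 100).await;
/// update_base_tables(&client, &mut rng, &base_tables, &inserts).await;
/// ```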
pub(super) async fn populate_tables<R: Rng>(
    client: &Client,
    rng: &mut R,
    base_tables: Vec<Table>,
    row_count: usize,
) -> Vec<Statement> {
    let inserts = insert_sql_gen(rng, base_tables, row_count);
    for insert in &inserts {
        tracing::info!("[EXECUTING INSERT]: {}", insert);
        client.simple_query(insert).await.unwrap();
    }
    inserts
        .iter()
        .map(|s| parse_sql(s).into_iter().next().unwrap())
        .collect_vec()
}

/// Run `SET <variable> TO <value>` and return the executed statement.
pub(super) async fn set_variable(client: &Client, variable: &str, value: &str) -> String {
    let s = format!("SET {variable} TO {value}");
    tracing::info!("[EXECUTING SET_VAR]: {}", s);
    client.simple_query(&s).await.unwrap();
    s
}

/// Sanity checks for sqlsmith.
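/// A minimal driver sketch (hypothetical setup; `client`, `rng`, and the table
/// lists are assumed to come from the surrounding test harness):
/// ```ignore
/// test_sqlsmith(&client, &mut rng, tables, base_tables, 100).await;
/// ```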
pub(super) async fn test_sqlsmith<R: Rng>(
    client: &Client,
    rng: &mut R,
    tables: Vec<Table>,
    base_tables: Vec<Table>,
    row_count: usize,
) {
    // The number of inserted rows should be at least 50% of the expected population
    // count; otherwise we don't have sufficient data in the system.
    // ENABLE: https://github.com/risingwavelabs/risingwave/issues/3844
    test_population_count(client, base_tables, row_count).await;
    tracing::info!("passed population count test");

    let threshold = 0.50; // permit at most 50% of queries to be skipped.
    let sample_size = 20;

    let skipped_percentage = test_batch_queries(client, rng, tables.clone(), sample_size)
        .await
        .unwrap();
    tracing::info!(
        "percentage of skipped batch queries = {}, threshold: {}",
        skipped_percentage,
        threshold
    );
    if skipped_percentage > threshold {
        panic!("skipped batch queries exceeded threshold.");
    }

    let skipped_percentage = test_stream_queries(client, rng, tables.clone(), sample_size)
        .await
        .unwrap();
    tracing::info!(
        "percentage of skipped stream queries = {}, threshold: {}",
        skipped_percentage,
        threshold
    );
    if skipped_percentage > threshold {
        panic!("skipped stream queries exceeded threshold.");
    }
}

/// Generate and run a random session-variable `SET` statement, returning the SQL
/// that was executed.
pub(super) async fn test_session_variable<R: Rng>(client: &Client, rng: &mut R) -> String {
    let session_sql = session_sql_gen(rng);
    tracing::info!("[EXECUTING TEST SESSION_VAR]: {}", session_sql);
    client.simple_query(session_sql.as_str()).await.unwrap();
    session_sql
}

/// Expects at least 50% of the inserted rows to be present across the base tables.
pub(super) async fn test_population_count(
    client: &Client,
    base_tables: Vec<Table>,
    expected_count: usize,
) {
    let mut actual_count = 0;
    for t in base_tables {
        let q = format!("select * from {};", t.name);
        let rows = client.simple_query(&q).await.unwrap();
        actual_count += rows.len();
    }
    if actual_count < expected_count / 2 {
        panic!(
            "expected at least 50% of rows to be present. \
             Expected {} rows in total, but found only {} rows",
            expected_count, actual_count,
        )
    }
}

/// Test batch queries, returning the fraction of queries that were skipped.
/// Runs in distributed mode, since generated queries can be complex and overflow
/// the local execution mode.
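/// A usage sketch (hypothetical sample size; the return value is the skipped
/// fraction in `[0.0, 1.0]`):
/// ```ignore
/// let skipped = test_batch_queries(&client, &mut rng, tables.clone(), 20).await?;
/// assert!(skipped <= 0.5, "too many batch queries were skipped");
/// ```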
pub(super) async fn test_batch_queries<R: Rng>(
    client: &Client,
    rng: &mut R,
    tables: Vec<Table>,
    sample_size: usize,
) -> Result<f64> {
    let mut skipped = 0;
    for _ in 0..sample_size {
        test_session_variable(client, rng).await;
        let sql = sql_gen(rng, tables.clone());
        tracing::info!("[TEST BATCH]: {}", sql);
        skipped += run_query(30, client, &sql).await?;
    }
    Ok(skipped as f64 / sample_size as f64)
}

/// Test stream queries, returning the fraction of queries that were skipped.
pub(super) async fn test_stream_queries<R: Rng>(
    client: &Client,
    rng: &mut R,
    tables: Vec<Table>,
    sample_size: usize,
) -> Result<f64> {
    let mut skipped = 0;

    for _ in 0..sample_size {
        test_session_variable(client, rng).await;
        let (sql, table) = mview_sql_gen(rng, tables.clone(), "stream_query");
        tracing::info!("[TEST STREAM]: {}", sql);
        skipped += run_query(12, client, &sql).await?;
        tracing::info!("[TEST DROP MVIEW]: {}", &format_drop_mview(&table));
        drop_mview_table(&table, client).await;
    }
    Ok(skipped as f64 / sample_size as f64)
}

/// Concatenate the seed table DDL files from the testdata directory.
pub(super) fn get_seed_table_sql(testdata: &str) -> String {
    let seed_files = ["tpch.sql", "nexmark.sql", "alltypes.sql"];
    seed_files
        .iter()
        .map(|filename| read_file_contents(format!("{}/{}", testdata, filename)).unwrap())
        .collect::<String>()
}

/// Create the base tables defined in testdata.
/// TODO: Generate indexes and sinks.
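/// A setup sketch (hypothetical `testdata` path, `client`, and `rng`):
/// ```ignore
/// let base_tables = create_base_tables(testdata, &client).await?;
/// let (tables, mviews) = create_mviews(&mut rng, base_tables.clone(), &client).await?;
/// ```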
pub(super) async fn create_base_tables(testdata: &str, client: &Client) -> Result<Vec<Table>> {
    tracing::info!("Preparing tables...");

    let sql = get_seed_table_sql(testdata);
    let (base_tables, statements) = parse_create_table_statements(sql);

    for stmt in &statements {
        let create_sql = stmt.to_string();
        tracing::info!("[EXECUTING CREATE TABLE]: {}", &create_sql);
        client.simple_query(&create_sql).await.unwrap();
    }

    Ok(base_tables)
}

/// Create some materialized views on top of the given tables.
/// TODO: Generate indexes and sinks.
pub(super) async fn create_mviews(
    rng: &mut impl Rng,
    mut mvs_and_base_tables: Vec<Table>,
    client: &Client,
) -> Result<(Vec<Table>, Vec<Table>)> {
    let mut mviews = vec![];
    // Generate some mviews over the current set of tables and mviews.
    for i in 0..20 {
        let (create_sql, table) =
            mview_sql_gen(rng, mvs_and_base_tables.clone(), &format!("m{}", i));
        tracing::info!("[EXECUTING CREATE MVIEW]: {}", &create_sql);
        let skip_count = run_query(6, client, &create_sql).await?;
        if skip_count == 0 {
            mvs_and_base_tables.push(table.clone());
            mviews.push(table);
        }
    }
    Ok((mvs_and_base_tables, mviews))
}

/// Format a `DROP MATERIALIZED VIEW IF EXISTS` statement for the given mview.
pub(super) fn format_drop_mview(mview: &Table) -> String {
    format!("DROP MATERIALIZED VIEW IF EXISTS {}", mview.name)
}

/// Drops a single mview table.
pub(super) async fn drop_mview_table(mview: &Table, client: &Client) {
    client
        .simple_query(&format_drop_mview(mview))
        .await
        .unwrap();
}

/// Drops the mview tables and the seed tables.
pub(super) async fn drop_tables(mviews: &[Table], testdata: &str, client: &Client) {
    tracing::info!("Cleaning tables...");

    for mview in mviews.iter().rev() {
        drop_mview_table(mview, client).await;
    }

    let seed_files = ["drop_tpch.sql", "drop_nexmark.sql", "drop_alltypes.sql"];
    let sql = seed_files
        .iter()
        .map(|filename| read_file_contents(format!("{}/{}", testdata, filename)).unwrap())
        .collect::<String>();

    for stmt in sql.lines() {
        client.simple_query(stmt).await.unwrap();
    }
}

/// Validate a client response, returning the skipped-query count (0 or 1) together
/// with the result rows.
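/// A usage sketch (hypothetical `sql`; permissible database errors count as skipped,
/// anything else surfaces as an error):
/// ```ignore
/// let (skipped, rows) = validate_response(client.simple_query(&sql).await)?;
/// ```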
pub(super) fn validate_response(
    response: PgResult<Vec<SimpleQueryMessage>>,
) -> Result<(i64, Vec<SimpleQueryMessage>)> {
    match response {
        Ok(rows) => Ok((0, rows)),
        Err(e) => {
            // Permit runtime errors conservatively.
            if let Some(e) = e.as_db_error()
                && is_permissible_error(&e.to_string())
            {
                tracing::info!("[SKIPPED ERROR]: {:#?}", e);
                return Ok((1, vec![]));
            }
            // consolidate error reason for deterministic test
            tracing::info!("[UNEXPECTED ERROR]: {:#?}", e);
            Err(anyhow!("Encountered unexpected error: {e}"))
        }
    }
}

/// Run a query with a timeout, returning 1 if the query was skipped and 0 otherwise.
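/// A usage sketch (hypothetical `sql`; the first argument is the timeout in seconds):
/// ```ignore
/// let skipped = run_query(30, &client, &sql).await?;
/// ```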
pub(super) async fn run_query(timeout_duration: u64, client: &Client, query: &str) -> Result<i64> {
    let (skipped_count, _) = run_query_inner(timeout_duration, client, query, true).await?;
    Ok(skipped_count)
}

/// Run a query and handle permissible errors:
/// - For recovery errors, do a bounded retry.
/// - For other errors, validate them, skipping the query if the error is permitted.
/// - If the query takes too long, either skip it (`skip_timeout = true`) or bail with
///   the timed-out query and the timeout duration.
///
/// Returns the number of skipped queries (0 or 1) and the rows returned.
pub(super) async fn run_query_inner(
    timeout_duration: u64,
    client: &Client,
    query: &str,
    skip_timeout: bool,
) -> Result<(i64, Vec<SimpleQueryMessage>)> {
    let query_task = client.simple_query(query);
    let result = timeout(Duration::from_secs(timeout_duration), query_task).await;
    let response = match result {
        Ok(r) => r,
        Err(_) => {
            if skip_timeout {
                return Ok((1, vec![]));
            } else {
                bail!(
                    "[UNEXPECTED ERROR] Query timeout after {timeout_duration}s:\n{:?}",
                    query
                )
            }
        }
    };
    if let Err(e) = &response
        && let Some(e) = e.as_db_error()
    {
        if is_recovery_in_progress_error(&e.to_string()) {
            let tries = 5;
            let interval = 1;
            for _ in 0..tries {
                // retry 5 times
                sleep(Duration::from_secs(interval)).await;
                let query_task = client.simple_query(query);
                let response = timeout(Duration::from_secs(timeout_duration), query_task).await;
                match response {
                    Ok(Ok(r)) => {
                        return Ok((0, r));
                    }
                    Err(_) => bail!(
                        "[UNEXPECTED ERROR] Query timeout after {timeout_duration}s:\n{:?}",
                        query
                    ),
                    _ => {}
                }
            }
            bail!(
                "[UNEXPECTED ERROR] Failed to recover after {tries} tries with interval {interval}s"
            )
        } else {
            return validate_response(response);
        }
    }
    let rows = response?;
    Ok((0, rows))
}

/// Build the RNG used for query generation: deterministic when a seed is given,
/// otherwise seeded from OS entropy.
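/// A seeding sketch (hypothetical seed value; `None` falls back to OS entropy):
/// ```ignore
/// let mut seeded = generate_rng(Some(42));
/// let mut unseeded = generate_rng(None);
/// ```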
pub(super) fn generate_rng(seed: Option<u64>) -> impl Rng {
    #[cfg(madsim)]
    if let Some(seed) = seed {
        ChaChaRng::seed_from_u64(seed)
    } else {
        ChaChaRng::from_rng(&mut SmallRng::from_os_rng())
    }
    #[cfg(not(madsim))]
    if let Some(seed) = seed {
        SmallRng::seed_from_u64(seed)
    } else {
        SmallRng::from_os_rng()
    }
}