risingwave_sqlsmith/test_runners/
utils.rs1use anyhow::{anyhow, bail};
16use itertools::Itertools;
17use rand::rngs::SmallRng;
18use rand::{Rng, SeedableRng};
19#[cfg(madsim)]
20use rand_chacha::ChaChaRng;
21use risingwave_sqlparser::ast::Statement;
22use tokio::time::{Duration, sleep, timeout};
23use tokio_postgres::error::Error as PgError;
24use tokio_postgres::{Client, SimpleQueryMessage};
25
26use crate::utils::read_file_contents;
27use crate::validation::{is_permissible_error, is_recovery_in_progress_error};
28use crate::{
29 Table, generate_update_statements, insert_sql_gen, mview_sql_gen,
30 parse_create_table_statements, parse_sql, session_sql_gen, sql_gen,
31};
32
33pub(super) type PgResult<A> = std::result::Result<A, PgError>;
34pub(super) type Result<A> = anyhow::Result<A>;
35
36pub(super) async fn update_base_tables<R: Rng>(
37 client: &Client,
38 rng: &mut R,
39 base_tables: &[Table],
40 inserts: &[Statement],
41) {
42 let update_statements = generate_update_statements(rng, base_tables, inserts).unwrap();
43 for update_statement in update_statements {
44 let sql = update_statement.to_string();
45 tracing::info!("[EXECUTING UPDATES]: {}", &sql);
46 client.simple_query(&sql).await.unwrap();
47 }
48}
49
50pub(super) async fn populate_tables<R: Rng>(
51 client: &Client,
52 rng: &mut R,
53 base_tables: Vec<Table>,
54 row_count: usize,
55) -> Vec<Statement> {
56 let inserts = insert_sql_gen(rng, base_tables, row_count);
57 for insert in &inserts {
58 tracing::info!("[EXECUTING INSERT]: {}", insert);
59 client.simple_query(insert).await.unwrap();
60 }
61 inserts
62 .iter()
63 .map(|s| parse_sql(s).into_iter().next().unwrap())
64 .collect_vec()
65}
66
67pub(super) async fn set_variable(client: &Client, variable: &str, value: &str) -> String {
68 let s = format!("SET {variable} TO {value}");
69 tracing::info!("[EXECUTING SET_VAR]: {}", s);
70 client.simple_query(&s).await.unwrap();
71 s
72}
73
74pub(super) async fn test_sqlsmith<R: Rng>(
76 client: &Client,
77 rng: &mut R,
78 tables: Vec<Table>,
79 base_tables: Vec<Table>,
80 row_count: usize,
81) {
82 test_population_count(client, base_tables, row_count).await;
86 tracing::info!("passed population count test");
87
88 let threshold = 0.50; let sample_size = 20;
90
91 let skipped_percentage = test_batch_queries(client, rng, tables.clone(), sample_size)
92 .await
93 .unwrap();
94 tracing::info!(
95 "percentage of skipped batch queries = {}, threshold: {}",
96 skipped_percentage,
97 threshold
98 );
99 if skipped_percentage > threshold {
100 panic!("skipped batch queries exceeded threshold.");
101 }
102
103 let skipped_percentage = test_stream_queries(client, rng, tables.clone(), sample_size)
104 .await
105 .unwrap();
106 tracing::info!(
107 "percentage of skipped stream queries = {}, threshold: {}",
108 skipped_percentage,
109 threshold
110 );
111 if skipped_percentage > threshold {
112 panic!("skipped stream queries exceeded threshold.");
113 }
114}
115
116pub(super) async fn test_session_variable<R: Rng>(client: &Client, rng: &mut R) -> String {
117 let session_sql = session_sql_gen(rng);
118 tracing::info!("[EXECUTING TEST SESSION_VAR]: {}", session_sql);
119 client.simple_query(session_sql.as_str()).await.unwrap();
120 session_sql
121}
122
123pub(super) async fn test_population_count(
125 client: &Client,
126 base_tables: Vec<Table>,
127 expected_count: usize,
128) {
129 let mut actual_count = 0;
130 for t in base_tables {
131 let q = format!("select * from {};", t.name);
132 let rows = client.simple_query(&q).await.unwrap();
133 actual_count += rows.len();
134 }
135 if actual_count < expected_count / 2 {
136 panic!(
137 "expected at least 50% rows included.\
138 Total {} rows, only had {} rows",
139 expected_count, actual_count,
140 )
141 }
142}
143
144pub(super) async fn test_batch_queries<R: Rng>(
148 client: &Client,
149 rng: &mut R,
150 tables: Vec<Table>,
151 sample_size: usize,
152) -> Result<f64> {
153 let mut skipped = 0;
154 for _ in 0..sample_size {
155 test_session_variable(client, rng).await;
156 let sql = sql_gen(rng, tables.clone());
157 tracing::info!("[TEST BATCH]: {}", sql);
158 skipped += run_query(30, client, &sql).await?;
159 }
160 Ok(skipped as f64 / sample_size as f64)
161}
162
163pub(super) async fn test_stream_queries<R: Rng>(
165 client: &Client,
166 rng: &mut R,
167 tables: Vec<Table>,
168 sample_size: usize,
169) -> Result<f64> {
170 let mut skipped = 0;
171
172 for _ in 0..sample_size {
173 test_session_variable(client, rng).await;
174 let (sql, table) = mview_sql_gen(rng, tables.clone(), "stream_query");
175 tracing::info!("[TEST STREAM]: {}", sql);
176 skipped += run_query(12, client, &sql).await?;
177 tracing::info!("[TEST DROP MVIEW]: {}", &format_drop_mview(&table));
178 drop_mview_table(&table, client).await;
179 }
180 Ok(skipped as f64 / sample_size as f64)
181}
182
183pub(super) fn get_seed_table_sql(testdata: &str) -> String {
184 let seed_files = ["tpch.sql", "nexmark.sql", "alltypes.sql"];
185 seed_files
186 .iter()
187 .map(|filename| read_file_contents(format!("{}/{}", testdata, filename)).unwrap())
188 .collect::<String>()
189}
190
191pub(super) async fn create_base_tables(testdata: &str, client: &Client) -> Result<Vec<Table>> {
194 tracing::info!("Preparing tables...");
195
196 let sql = get_seed_table_sql(testdata);
197 let (base_tables, statements) = parse_create_table_statements(sql);
198 let mut mvs_and_base_tables = vec![];
199 mvs_and_base_tables.extend_from_slice(&base_tables);
200
201 for stmt in &statements {
202 let create_sql = stmt.to_string();
203 tracing::info!("[EXECUTING CREATE TABLE]: {}", &create_sql);
204 client.simple_query(&create_sql).await.unwrap();
205 }
206
207 Ok(base_tables)
208}
209
210pub(super) async fn create_mviews(
213 rng: &mut impl Rng,
214 mvs_and_base_tables: Vec<Table>,
215 client: &Client,
216) -> Result<(Vec<Table>, Vec<Table>)> {
217 let mut mvs_and_base_tables = mvs_and_base_tables;
218 let mut mviews = vec![];
219 for i in 0..20 {
221 let (create_sql, table) =
222 mview_sql_gen(rng, mvs_and_base_tables.clone(), &format!("m{}", i));
223 tracing::info!("[EXECUTING CREATE MVIEW]: {}", &create_sql);
224 let skip_count = run_query(6, client, &create_sql).await?;
225 if skip_count == 0 {
226 mvs_and_base_tables.push(table.clone());
227 mviews.push(table);
228 }
229 }
230 Ok((mvs_and_base_tables, mviews))
231}
232
233pub(super) fn format_drop_mview(mview: &Table) -> String {
234 format!("DROP MATERIALIZED VIEW IF EXISTS {}", mview.name)
235}
236
237pub(super) async fn drop_mview_table(mview: &Table, client: &Client) {
239 client
240 .simple_query(&format_drop_mview(mview))
241 .await
242 .unwrap();
243}
244
245pub(super) async fn drop_tables(mviews: &[Table], testdata: &str, client: &Client) {
247 tracing::info!("Cleaning tables...");
248
249 for mview in mviews.iter().rev() {
250 drop_mview_table(mview, client).await;
251 }
252
253 let seed_files = ["drop_tpch.sql", "drop_nexmark.sql", "drop_alltypes.sql"];
254 let sql = seed_files
255 .iter()
256 .map(|filename| read_file_contents(format!("{}/{}", testdata, filename)).unwrap())
257 .collect::<String>();
258
259 for stmt in sql.lines() {
260 client.simple_query(stmt).await.unwrap();
261 }
262}
263
264pub(super) fn validate_response(
266 response: PgResult<Vec<SimpleQueryMessage>>,
267) -> Result<(i64, Vec<SimpleQueryMessage>)> {
268 match response {
269 Ok(rows) => Ok((0, rows)),
270 Err(e) => {
271 if let Some(e) = e.as_db_error()
273 && is_permissible_error(&e.to_string())
274 {
275 tracing::info!("[SKIPPED ERROR]: {:#?}", e);
276 return Ok((1, vec![]));
277 }
278 tracing::info!("[UNEXPECTED ERROR]: {:#?}", e);
280 Err(anyhow!("Encountered unexpected error: {e}"))
281 }
282 }
283}
284
285pub(super) async fn run_query(timeout_duration: u64, client: &Client, query: &str) -> Result<i64> {
286 let (skipped_count, _) = run_query_inner(timeout_duration, client, query, true).await?;
287 Ok(skipped_count)
288}
289
290pub(super) async fn run_query_inner(
297 timeout_duration: u64,
298 client: &Client,
299 query: &str,
300 skip_timeout: bool,
301) -> Result<(i64, Vec<SimpleQueryMessage>)> {
302 let query_task = client.simple_query(query);
303 let result = timeout(Duration::from_secs(timeout_duration), query_task).await;
304 let response = match result {
305 Ok(r) => r,
306 Err(_) => {
307 if skip_timeout {
308 return Ok((1, vec![]));
309 } else {
310 bail!(
311 "[UNEXPECTED ERROR] Query timeout after {timeout_duration}s:\n{:?}",
312 query
313 )
314 }
315 }
316 };
317 if let Err(e) = &response
318 && let Some(e) = e.as_db_error()
319 {
320 if is_recovery_in_progress_error(&e.to_string()) {
321 let tries = 5;
322 let interval = 1;
323 for _ in 0..tries {
324 sleep(Duration::from_secs(interval)).await;
326 let query_task = client.simple_query(query);
327 let response = timeout(Duration::from_secs(timeout_duration), query_task).await;
328 match response {
329 Ok(Ok(r)) => {
330 return Ok((0, r));
331 }
332 Err(_) => bail!(
333 "[UNEXPECTED ERROR] Query timeout after {timeout_duration}s:\n{:?}",
334 query
335 ),
336 _ => {}
337 }
338 }
339 bail!(
340 "[UNEXPECTED ERROR] Failed to recover after {tries} tries with interval {interval}s"
341 )
342 } else {
343 return validate_response(response);
344 }
345 }
346 let rows = response?;
347 Ok((0, rows))
348}
349
/// Builds the random number generator that drives query generation.
///
/// With `Some(seed)`, the generator is created via `seed_from_u64(seed)`. With `None`, it is
/// seeded from OS randomness.
///
/// Under `cfg(madsim)` the result is a `ChaChaRng`; when no seed is given, that generator is
/// itself seeded from an OS-seeded `SmallRng`. Otherwise a `SmallRng` is returned directly.
///
/// NOTE(review): presumably `ChaChaRng` is used under madsim because its output stream is
/// reproducible across platforms/versions (unlike `SmallRng`), which deterministic simulation
/// needs — confirm.
pub(super) fn generate_rng(seed: Option<u64>) -> impl Rng {
    // Exactly one of the two cfg-gated `if` expressions below survives compilation and becomes
    // the function's tail expression.
    #[cfg(madsim)]
    if let Some(seed) = seed {
        ChaChaRng::seed_from_u64(seed)
    } else {
        ChaChaRng::from_rng(&mut SmallRng::from_os_rng())
    }
    #[cfg(not(madsim))]
    if let Some(seed) = seed {
        SmallRng::seed_from_u64(seed)
    } else {
        SmallRng::from_os_rng()
    }
}