risingwave_sqlsmith/
reducer.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Reduces a failing SQL file to the statements needed to reproduce its failing query.
16use std::collections::HashSet;
17use std::fmt::Write;
18
19use anyhow::anyhow;
20use itertools::Itertools;
21use regex::Regex;
22use risingwave_sqlparser::ast::{
23    Cte, CteInner, Expr, FunctionArgExpr, Join, Query, Select, SetExpr, Statement, TableFactor,
24    TableWithJoins, With,
25};
26
27use crate::parse_sql;
28use crate::utils::{create_file, read_file_contents, write_to_file};
29
30type Result<A> = anyhow::Result<A>;
31
32/// Shrinks a given failing query file.
33pub fn shrink_file(input_file_path: &str, output_file_path: &str) -> Result<()> {
34    // read failed sql
35    let file_contents = read_file_contents(input_file_path)?;
36
37    // reduce failed sql
38    let reduced_sql = shrink(&file_contents)?;
39
40    // write reduced sql
41    let mut file = create_file(output_file_path).unwrap();
42    write_to_file(&mut file, reduced_sql)
43}
44
45fn shrink(sql: &str) -> Result<String> {
46    let sql_statements = parse_sql(sql);
47
48    // Session variable before the failing query.
49    let session_variable = sql_statements
50        .get(sql_statements.len() - 2)
51        .filter(|statement| matches!(statement, Statement::SetVariable { .. }));
52
53    let failing_query = sql_statements
54        .last()
55        .ok_or_else(|| anyhow!("Could not get last sql statement"))?;
56
57    let ddl_references = find_ddl_references(&sql_statements);
58
59    tracing::info!("[DDL REFERENCES]: {}", ddl_references.iter().join(", "));
60
61    let mut ddl = sql_statements
62        .iter()
63        .filter(|s| {
64            matches!(*s,
65            Statement::CreateView { name, .. } | Statement::CreateTable { name, .. }
66                if ddl_references.contains(&name.real_value()))
67        })
68        .collect();
69
70    let mut dml = sql_statements
71        .iter()
72        .filter(|s| {
73            matches!(*s,
74            Statement::Insert { table_name, .. }
75                if ddl_references.contains(&table_name.real_value()))
76        })
77        .collect();
78
79    let mut reduced_statements = vec![];
80    reduced_statements.append(&mut ddl);
81    reduced_statements.append(&mut dml);
82    if let Some(session_variable) = session_variable {
83        reduced_statements.push(session_variable);
84    }
85    reduced_statements.push(failing_query);
86
87    let sql = reduced_statements
88        .iter()
89        .fold(String::new(), |mut output, s| {
90            let _ = writeln!(output, "{s};");
91            output
92        });
93
94    Ok(sql)
95}
96
97pub(crate) fn find_ddl_references(sql_statements: &[Statement]) -> HashSet<String> {
98    let mut ddl_references = HashSet::new();
99    let mut sql_statements = sql_statements.iter().rev();
100    let failing = sql_statements.next().unwrap();
101    match failing {
102        Statement::Query(query) | Statement::CreateView { query, .. } => {
103            find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
104        }
105        _ => {}
106    };
107    for sql_statement in sql_statements {
108        match sql_statement {
109            Statement::Query(query) => {
110                find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
111            }
112            Statement::CreateView { query, name, .. }
113                if ddl_references.contains(&name.real_value()) =>
114            {
115                find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
116            }
117            _ => {}
118        };
119    }
120    ddl_references
121}
122
123pub(crate) fn find_ddl_references_for_query(query: &Query, ddl_references: &mut HashSet<String>) {
124    let Query { with, body, .. } = query;
125    if let Some(With { cte_tables, .. }) = with {
126        for Cte { cte_inner, .. } in cte_tables {
127            if let CteInner::Query(query) = cte_inner {
128                find_ddl_references_for_query(query, ddl_references)
129            }
130        }
131    }
132    find_ddl_references_for_query_in_set_expr(body, ddl_references);
133}
134
135fn find_ddl_references_for_query_in_set_expr(
136    set_expr: &SetExpr,
137    ddl_references: &mut HashSet<String>,
138) {
139    match set_expr {
140        SetExpr::Select(box Select { from, .. }) => {
141            for table_with_joins in from {
142                find_ddl_references_for_query_in_table_with_joins(table_with_joins, ddl_references);
143            }
144        }
145        SetExpr::Query(q) => find_ddl_references_for_query(q, ddl_references),
146        SetExpr::SetOperation { left, right, .. } => {
147            find_ddl_references_for_query_in_set_expr(left, ddl_references);
148            find_ddl_references_for_query_in_set_expr(right, ddl_references);
149        }
150        SetExpr::Values(_) => {}
151    }
152}
153
154fn find_ddl_references_for_query_in_table_with_joins(
155    TableWithJoins { relation, joins }: &TableWithJoins,
156    ddl_references: &mut HashSet<String>,
157) {
158    find_ddl_references_for_query_in_table_factor(relation, ddl_references);
159    for Join { relation, .. } in joins {
160        find_ddl_references_for_query_in_table_factor(relation, ddl_references);
161    }
162}
163
164fn find_ddl_references_for_query_in_table_factor(
165    table_factor: &TableFactor,
166    ddl_references: &mut HashSet<String>,
167) {
168    match table_factor {
169        TableFactor::Table { name, .. } => {
170            ddl_references.insert(name.real_value());
171        }
172        TableFactor::Derived { subquery, .. } => {
173            find_ddl_references_for_query(subquery, ddl_references)
174        }
175        TableFactor::TableFunction { name, args, .. } => {
176            let name = name.real_value();
177            // https://docs.rs/regex/latest/regex/#grouping-and-flags
178            let regex = Regex::new(r"(?i)(tumble|hop)").unwrap();
179            if regex.is_match(&name) && args.len() >= 3 {
180                let table_name = &args[0];
181                if let FunctionArgExpr::Expr(Expr::Identifier(table_name)) = table_name.get_expr() {
182                    ddl_references.insert(table_name.to_string().to_lowercase());
183                }
184            }
185        }
186        TableFactor::NestedJoin(table_with_joins) => {
187            find_ddl_references_for_query_in_table_with_joins(table_with_joins, ddl_references);
188        }
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    const DDL_AND_DML: &str = "
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T3 (V1 timestamp, V2 INT, V3 INT);
CREATE MATERIALIZED VIEW M1 AS SELECT * FROM T1;
CREATE MATERIALIZED VIEW M2 AS SELECT * FROM T2 LEFT JOIN T3 ON T2.V1 = T3.V2;
CREATE MATERIALIZED VIEW M3 AS SELECT * FROM T1 LEFT JOIN T2;
CREATE MATERIALIZED VIEW M4 AS SELECT * FROM M3;
INSERT INTO T1 VALUES(0, 0, 1);
INSERT INTO T1 VALUES(0, 0, 2);
INSERT INTO T2 VALUES(0, 0, 3);
INSERT INTO T2 VALUES(0, 0, 4);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 5);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 6);
SET RW_TWO_PHASE_AGG=TRUE;
    ";

    /// Parses `sql` and returns the query of its first statement, which must be a query or
    /// a `CREATE [MATERIALIZED] VIEW`.
    fn sql_to_query(sql: &str) -> Box<Query> {
        let sql_statement = parse_sql(sql).into_iter().next().unwrap();
        match sql_statement {
            Statement::Query(query) | Statement::CreateView { query, .. } => query,
            _ => panic!("Last statement was not a query, can't shrink"),
        }
    }

    /// References of a standalone query, without any surrounding DDL.
    fn query_ddl_references(sql: &str) -> HashSet<String> {
        let query = sql_to_query(sql);
        let mut ddl_references = HashSet::new();
        find_ddl_references_for_query(&query, &mut ddl_references);
        ddl_references
    }

    /// References of `query` when it is the failing statement appended to `DDL_AND_DML`.
    fn fixture_ddl_references(query: &str) -> HashSet<String> {
        let sql_statements = parse_sql(DDL_AND_DML.to_owned() + query);
        find_ddl_references(&sql_statements)
    }

    /// Asserts that every name in `present` is in `ddl_references` and none in `absent` is.
    fn assert_references(ddl_references: &HashSet<String>, present: &[&str], absent: &[&str]) {
        for name in present {
            assert!(ddl_references.contains(*name), "missing reference {name}: {ddl_references:?}");
        }
        for name in absent {
            assert!(!ddl_references.contains(*name), "unexpected reference {name}: {ddl_references:?}");
        }
    }

    /// Asserts that shrinking `DDL_AND_DML` followed by `query` yields `kept` (the fixture
    /// statements expected to survive, one per line), then the fixture's session variable,
    /// then `query`.
    fn assert_shrinks_to(query: &str, kept: &str) {
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!("{kept}SET RW_TWO_PHASE_AGG = true;\n{query}\n");
        assert_eq!(expected, shrink(&sql).unwrap());
    }

    #[test]
    fn test_find_ddl_references_for_query_simple() {
        let ddl_references = query_ddl_references("SELECT * FROM T1;");
        assert_references(&ddl_references, &["t1"], &[]);
    }

    #[test]
    fn test_find_ddl_references_for_tumble() {
        let ddl_references = query_ddl_references("SELECT * FROM TUMBLE(T3, V1, INTERVAL '3' DAY);");
        assert_references(&ddl_references, &["t3"], &[]);
    }

    #[test]
    fn test_find_ddl_references_for_set_operation() {
        let ddl_references = query_ddl_references("SELECT * FROM T1 UNION ALL SELECT * FROM T2;");
        assert_references(&ddl_references, &["t1", "t2"], &[]);
    }

    #[test]
    fn test_find_ddl_references_for_query_with_cte() {
        let ddl_references =
            fixture_ddl_references("WITH WITH0 AS (SELECT * FROM M3) SELECT * FROM WITH0");
        assert_references(
            &ddl_references,
            &["m3", "t1", "t2"],
            &["m4", "t3", "m1", "m2"],
        );
    }

    #[test]
    fn test_find_ddl_references_for_query_with_mv_on_mv() {
        let ddl_references =
            fixture_ddl_references("WITH WITH0 AS (SELECT * FROM M4) SELECT * FROM WITH0");
        assert_references(
            &ddl_references,
            &["m4", "m3", "t1", "t2"],
            &["t3", "m1", "m2"],
        );
    }

    #[test]
    fn test_find_ddl_references_for_query_joins() {
        let ddl_references = fixture_ddl_references(
            "SELECT * FROM (T1 JOIN T2 ON T1.V1 = T2.V2) JOIN T3 ON T2.V1 = T3.V2",
        );
        assert_references(
            &ddl_references,
            &["t1", "t2", "t3"],
            &["m1", "m2", "m3", "m4"],
        );
    }

    #[test]
    fn test_shrink_values() {
        assert_shrinks_to("SELECT 1;", "");
    }

    #[test]
    fn test_shrink_simple_table() {
        assert_shrinks_to(
            "SELECT * FROM t1;",
            "\
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
INSERT INTO T1 VALUES (0, 0, 1);
INSERT INTO T1 VALUES (0, 0, 2);
",
        );
    }

    #[test]
    fn test_shrink_simple_table_with_alias() {
        assert_shrinks_to(
            "SELECT * FROM t1 AS s1;",
            "\
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
INSERT INTO T1 VALUES (0, 0, 1);
INSERT INTO T1 VALUES (0, 0, 2);
",
        );
    }

    #[test]
    fn test_shrink_join() {
        assert_shrinks_to(
            "SELECT * FROM (T1 JOIN T2 ON T1.V1 = T2.V2) JOIN T3 ON T2.V1 = T3.V2;",
            "\
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T3 (V1 TIMESTAMP, V2 INT, V3 INT);
INSERT INTO T1 VALUES (0, 0, 1);
INSERT INTO T1 VALUES (0, 0, 2);
INSERT INTO T2 VALUES (0, 0, 3);
INSERT INTO T2 VALUES (0, 0, 4);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 5);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 6);
",
        );
    }

    #[test]
    fn test_shrink_tumble() {
        assert_shrinks_to(
            "SELECT * FROM TUMBLE(T3, V1, INTERVAL '3' DAY);",
            "\
CREATE TABLE T3 (V1 TIMESTAMP, V2 INT, V3 INT);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 5);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 6);
",
        );
    }

    #[test]
    fn test_shrink_subquery() {
        assert_shrinks_to(
            "SELECT * FROM (SELECT V1 AS K1 FROM T2);",
            "\
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
INSERT INTO T2 VALUES (0, 0, 3);
INSERT INTO T2 VALUES (0, 0, 4);
",
        );
    }

    #[test]
    fn test_shrink_mview() {
        assert_shrinks_to(
            "CREATE MATERIALIZED VIEW m5 AS SELECT * FROM (SELECT V1 AS K1 FROM T2);",
            "\
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
INSERT INTO T2 VALUES (0, 0, 3);
INSERT INTO T2 VALUES (0, 0, 4);
",
        );
    }
}