risingwave_sqlsmith/
reducer.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Provides E2E Test runner functionality.
16use std::collections::HashSet;
17use std::fmt::Write;
18
19use anyhow::anyhow;
20use itertools::Itertools;
21use regex::Regex;
22use risingwave_sqlparser::ast::{
23    Cte, CteInner, Expr, FunctionArgExpr, Join, Query, Select, SetExpr, Statement, TableFactor,
24    TableWithJoins, With,
25};
26use tokio_postgres::Client;
27
28use crate::parse_sql;
29use crate::sqlreduce::Strategy;
30use crate::sqlreduce::checker::Checker;
31use crate::sqlreduce::reducer::Reducer;
32use crate::utils::{create_file, read_file_contents, write_to_file};
33
34type Result<A> = anyhow::Result<A>;
35
36/// Shrinks a given failing query file.
37pub async fn shrink_file(
38    input_file_path: &str,
39    output_file_path: &str,
40    strategy: Strategy,
41    client: Client,
42    restore_cmd: &str,
43) -> Result<()> {
44    // read failed sql
45    let file_contents = read_file_contents(input_file_path)?;
46
47    // reduce failed sql
48    let reduced_sql = shrink_statements(&file_contents)?;
49
50    // shrink the reduced sql
51    let reduced_sql = shrink(&reduced_sql, strategy, client, restore_cmd).await?;
52
53    // write reduced sql
54    let mut file = create_file(output_file_path).unwrap();
55    write_to_file(&mut file, reduced_sql)
56}
57
58fn shrink_statements(sql: &str) -> Result<String> {
59    let sql_statements = parse_sql(sql);
60
61    // Session variable before the failing query.
62    let session_variable = sql_statements
63        .get(sql_statements.len() - 2)
64        .filter(|statement| matches!(statement, Statement::SetVariable { .. }));
65
66    let failing_query = sql_statements
67        .last()
68        .ok_or_else(|| anyhow!("Could not get last sql statement"))?;
69
70    let ddl_references = find_ddl_references(&sql_statements);
71
72    tracing::info!("[DDL REFERENCES]: {}", ddl_references.iter().join(", "));
73
74    let mut ddl = sql_statements
75        .iter()
76        .filter(|s| {
77            matches!(*s,
78            Statement::CreateView { name, .. } | Statement::CreateTable { name, .. }
79                if ddl_references.contains(&name.real_value()))
80        })
81        .collect();
82
83    let mut dml = sql_statements
84        .iter()
85        .filter(|s| {
86            matches!(*s,
87            Statement::Insert { table_name, .. }
88                if ddl_references.contains(&table_name.real_value()))
89        })
90        .collect();
91
92    let mut reduced_statements = vec![];
93    reduced_statements.append(&mut ddl);
94    reduced_statements.append(&mut dml);
95    if let Some(session_variable) = session_variable {
96        reduced_statements.push(session_variable);
97    }
98    reduced_statements.push(failing_query);
99
100    let sql = reduced_statements
101        .iter()
102        .fold(String::new(), |mut output, s| {
103            let _ = writeln!(output, "{s};");
104            output
105        });
106
107    Ok(sql)
108}
109
110async fn shrink(
111    sql: &str,
112    strategy: Strategy,
113    client: Client,
114    restore_cmd: &str,
115) -> Result<String> {
116    let sql_statements = parse_sql(sql);
117    let proceeding_stmts = sql_statements.split_last().unwrap().1.to_vec();
118    let checker = Checker::new(client, proceeding_stmts, restore_cmd.to_owned());
119    let mut reducer = Reducer::new(checker, strategy);
120
121    let reduced_sql = reducer.reduce(sql).await?;
122
123    Ok(reduced_sql)
124}
125
126pub(crate) fn find_ddl_references(sql_statements: &[Statement]) -> HashSet<String> {
127    let mut ddl_references = HashSet::new();
128    let mut sql_statements = sql_statements.iter().rev();
129    let failing = sql_statements.next().unwrap();
130    match failing {
131        Statement::Query(query) | Statement::CreateView { query, .. } => {
132            find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
133        }
134        _ => {}
135    };
136    for sql_statement in sql_statements {
137        match sql_statement {
138            Statement::Query(query) => {
139                find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
140            }
141            Statement::CreateView { query, name, .. }
142                if ddl_references.contains(&name.real_value()) =>
143            {
144                find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
145            }
146            _ => {}
147        };
148    }
149    ddl_references
150}
151
152pub(crate) fn find_ddl_references_for_query(query: &Query, ddl_references: &mut HashSet<String>) {
153    let Query { with, body, .. } = query;
154    if let Some(With { cte_tables, .. }) = with {
155        for Cte { cte_inner, .. } in cte_tables {
156            if let CteInner::Query(query) = cte_inner {
157                find_ddl_references_for_query(query, ddl_references)
158            }
159        }
160    }
161    find_ddl_references_for_query_in_set_expr(body, ddl_references);
162}
163
164fn find_ddl_references_for_query_in_set_expr(
165    set_expr: &SetExpr,
166    ddl_references: &mut HashSet<String>,
167) {
168    match set_expr {
169        SetExpr::Select(box Select { from, .. }) => {
170            for table_with_joins in from {
171                find_ddl_references_for_query_in_table_with_joins(table_with_joins, ddl_references);
172            }
173        }
174        SetExpr::Query(q) => find_ddl_references_for_query(q, ddl_references),
175        SetExpr::SetOperation { left, right, .. } => {
176            find_ddl_references_for_query_in_set_expr(left, ddl_references);
177            find_ddl_references_for_query_in_set_expr(right, ddl_references);
178        }
179        SetExpr::Values(_) => {}
180    }
181}
182
183fn find_ddl_references_for_query_in_table_with_joins(
184    TableWithJoins { relation, joins }: &TableWithJoins,
185    ddl_references: &mut HashSet<String>,
186) {
187    find_ddl_references_for_query_in_table_factor(relation, ddl_references);
188    for Join { relation, .. } in joins {
189        find_ddl_references_for_query_in_table_factor(relation, ddl_references);
190    }
191}
192
193fn find_ddl_references_for_query_in_table_factor(
194    table_factor: &TableFactor,
195    ddl_references: &mut HashSet<String>,
196) {
197    match table_factor {
198        TableFactor::Table { name, .. } => {
199            ddl_references.insert(name.real_value());
200        }
201        TableFactor::Derived { subquery, .. } => {
202            find_ddl_references_for_query(subquery, ddl_references)
203        }
204        TableFactor::TableFunction { name, args, .. } => {
205            let name = name.real_value();
206            // https://docs.rs/regex/latest/regex/#grouping-and-flags
207            let regex = Regex::new(r"(?i)(tumble|hop)").unwrap();
208            if regex.is_match(&name) && args.len() >= 3 {
209                let table_name = &args[0];
210                if let FunctionArgExpr::Expr(Expr::Identifier(table_name)) = table_name.get_expr() {
211                    ddl_references.insert(table_name.to_string().to_lowercase());
212                }
213            }
214        }
215        TableFactor::NestedJoin(table_with_joins) => {
216            find_ddl_references_for_query_in_table_with_joins(table_with_joins, ddl_references);
217        }
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture shared by the tests: three tables, four materialized views,
    /// two inserts per table and a trailing session variable. A test appends
    /// its own failing query to it.
    const DDL_AND_DML: &str = "
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T3 (V1 timestamp, V2 INT, V3 INT);
CREATE MATERIALIZED VIEW M1 AS SELECT * FROM T1;
CREATE MATERIALIZED VIEW M2 AS SELECT * FROM T2 LEFT JOIN T3 ON T2.V1 = T3.V2;
CREATE MATERIALIZED VIEW M3 AS SELECT * FROM T1 LEFT JOIN T2;
CREATE MATERIALIZED VIEW M4 AS SELECT * FROM M3;
INSERT INTO T1 VALUES(0, 0, 1);
INSERT INTO T1 VALUES(0, 0, 2);
INSERT INTO T2 VALUES(0, 0, 3);
INSERT INTO T2 VALUES(0, 0, 4);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 5);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 6);
SET RW_TWO_PHASE_AGG=TRUE;
    ";

    /// Parses `sql` and returns the query of its first statement.
    fn sql_to_query(sql: &str) -> Box<Query> {
        let sql_statement = parse_sql(sql)
            .into_iter()
            .next()
            .expect("sql should contain at least one statement");
        match sql_statement {
            Statement::Query(query) | Statement::CreateView { query, .. } => query,
            _ => panic!("First statement was not a query, can't extract it"),
        }
    }

    /// References collected from a standalone query, without the fixture.
    fn references_of_query(sql: &str) -> HashSet<String> {
        let query = sql_to_query(sql);
        let mut ddl_references = HashSet::new();
        find_ddl_references_for_query(&query, &mut ddl_references);
        ddl_references
    }

    /// References collected when `failing_query` is appended to the fixture.
    fn references_with_fixture(failing_query: &str) -> HashSet<String> {
        let sql_statements = parse_sql(DDL_AND_DML.to_owned() + failing_query);
        find_ddl_references(&sql_statements)
    }

    /// Asserts that every name in `present` is in `actual` and that no name
    /// in `absent` is.
    fn assert_references(actual: &HashSet<String>, present: &[&str], absent: &[&str]) {
        for name in present {
            assert!(
                actual.contains(*name),
                "expected `{name}` to be referenced, got {actual:?}"
            );
        }
        for name in absent {
            assert!(
                !actual.contains(*name),
                "expected `{name}` not to be referenced, got {actual:?}"
            );
        }
    }

    /// Asserts that shrinking the fixture followed by `query` yields
    /// `expected_setup`, then the session variable, then `query`.
    fn assert_shrinks_to(query: &str, expected_setup: &str) {
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!("{expected_setup}SET RW_TWO_PHASE_AGG = true;\n{query}\n");
        assert_eq!(expected, shrink_statements(&sql).unwrap());
    }

    #[test]
    fn test_find_ddl_references_for_query_simple() {
        let ddl_references = references_of_query("SELECT * FROM T1;");
        assert_references(&ddl_references, &["t1"], &[]);
    }

    #[test]
    fn test_find_ddl_references_for_tumble() {
        let ddl_references = references_of_query("SELECT * FROM TUMBLE(T3, V1, INTERVAL '3' DAY);");
        assert_references(&ddl_references, &["t3"], &[]);
    }

    #[test]
    fn test_find_ddl_references_for_query_with_cte() {
        let ddl_references =
            references_with_fixture("WITH WITH0 AS (SELECT * FROM M3) SELECT * FROM WITH0");
        assert_references(
            &ddl_references,
            &["m3", "t1", "t2"],
            &["m4", "t3", "m1", "m2"],
        );
    }

    #[test]
    fn test_find_ddl_references_for_query_with_mv_on_mv() {
        let ddl_references =
            references_with_fixture("WITH WITH0 AS (SELECT * FROM M4) SELECT * FROM WITH0");
        assert_references(
            &ddl_references,
            &["m4", "m3", "t1", "t2"],
            &["t3", "m1", "m2"],
        );
    }

    #[test]
    fn test_find_ddl_references_for_query_joins() {
        let ddl_references = references_with_fixture(
            "SELECT * FROM (T1 JOIN T2 ON T1.V1 = T2.V2) JOIN T3 ON T2.V1 = T3.V2",
        );
        assert_references(
            &ddl_references,
            &["t1", "t2", "t3"],
            &["m1", "m2", "m3", "m4"],
        );
    }

    #[test]
    fn test_shrink_values() {
        assert_shrinks_to("SELECT 1;", "");
    }

    #[test]
    fn test_shrink_simple_table() {
        assert_shrinks_to(
            "SELECT * FROM t1;",
            "\
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
INSERT INTO T1 VALUES (0, 0, 1);
INSERT INTO T1 VALUES (0, 0, 2);
",
        );
    }

    #[test]
    fn test_shrink_simple_table_with_alias() {
        assert_shrinks_to(
            "SELECT * FROM t1 AS s1;",
            "\
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
INSERT INTO T1 VALUES (0, 0, 1);
INSERT INTO T1 VALUES (0, 0, 2);
",
        );
    }

    #[test]
    fn test_shrink_join() {
        assert_shrinks_to(
            "SELECT * FROM (T1 JOIN T2 ON T1.V1 = T2.V2) JOIN T3 ON T2.V1 = T3.V2;",
            "\
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T3 (V1 TIMESTAMP, V2 INT, V3 INT);
INSERT INTO T1 VALUES (0, 0, 1);
INSERT INTO T1 VALUES (0, 0, 2);
INSERT INTO T2 VALUES (0, 0, 3);
INSERT INTO T2 VALUES (0, 0, 4);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 5);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 6);
",
        );
    }

    #[test]
    fn test_shrink_tumble() {
        assert_shrinks_to(
            "SELECT * FROM TUMBLE(T3, V1, INTERVAL '3' DAY);",
            "\
CREATE TABLE T3 (V1 TIMESTAMP, V2 INT, V3 INT);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 5);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 6);
",
        );
    }

    #[test]
    fn test_shrink_subquery() {
        assert_shrinks_to(
            "SELECT * FROM (SELECT V1 AS K1 FROM T2);",
            "\
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
INSERT INTO T2 VALUES (0, 0, 3);
INSERT INTO T2 VALUES (0, 0, 4);
",
        );
    }

    #[test]
    fn test_shrink_mview() {
        assert_shrinks_to(
            "CREATE MATERIALIZED VIEW m5 AS SELECT * FROM (SELECT V1 AS K1 FROM T2);",
            "\
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
INSERT INTO T2 VALUES (0, 0, 3);
INSERT INTO T2 VALUES (0, 0, 4);
",
        );
    }
}