risingwave_sqlsmith/
reducer.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Shrinks failing SQL query files down to the statements needed to reproduce the failure.
16use std::collections::HashSet;
17use std::fmt::Write;
18
19use anyhow::anyhow;
20use itertools::Itertools;
21use regex::Regex;
22use risingwave_sqlparser::ast::{
23    Cte, CteInner, Expr, FunctionArgExpr, Join, Query, Select, SetExpr, Statement, TableFactor,
24    TableWithJoins, With,
25};
26use tokio_postgres::Client;
27
28use crate::parse_sql;
29use crate::sqlreduce::checker::Checker;
30use crate::sqlreduce::reducer::Reducer;
31use crate::utils::{create_file, read_file_contents, write_to_file};
32
33type Result<A> = anyhow::Result<A>;
34
35/// Shrinks a given failing query file.
36pub async fn shrink_file(
37    input_file_path: &str,
38    output_file_path: &str,
39    client: Client,
40    restore_cmd: &str,
41) -> Result<()> {
42    // read failed sql
43    let file_contents = read_file_contents(input_file_path)?;
44
45    // reduce failed sql
46    let reduced_sql = shrink_statements(&file_contents)?;
47
48    // shrink the reduced sql
49    let reduced_sql = shrink_with_reducer(&reduced_sql, client, restore_cmd).await?;
50
51    // write reduced sql
52    let mut file = create_file(output_file_path).unwrap();
53    write_to_file(&mut file, reduced_sql)
54}
55
56fn shrink_statements(sql: &str) -> Result<String> {
57    let sql_statements = parse_sql(sql);
58
59    // Session variable before the failing query.
60    let session_variable = sql_statements
61        .get(sql_statements.len() - 2)
62        .filter(|statement| matches!(statement, Statement::SetVariable { .. }));
63
64    let failing_query = sql_statements
65        .last()
66        .ok_or_else(|| anyhow!("Could not get last sql statement"))?;
67
68    let ddl_references = find_ddl_references(&sql_statements);
69
70    tracing::info!("[DDL REFERENCES]: {}", ddl_references.iter().join(", "));
71
72    let mut ddl = sql_statements
73        .iter()
74        .filter(|s| {
75            matches!(*s,
76            Statement::CreateView { name, .. } | Statement::CreateTable { name, .. }
77                if ddl_references.contains(&name.real_value()))
78        })
79        .collect();
80
81    let mut dml = sql_statements
82        .iter()
83        .filter(|s| {
84            matches!(*s,
85            Statement::Insert { table_name, .. }
86                if ddl_references.contains(&table_name.real_value()))
87        })
88        .collect();
89
90    let mut reduced_statements = vec![];
91    reduced_statements.append(&mut ddl);
92    reduced_statements.append(&mut dml);
93    if let Some(session_variable) = session_variable {
94        reduced_statements.push(session_variable);
95    }
96    reduced_statements.push(failing_query);
97
98    let sql = reduced_statements
99        .iter()
100        .fold(String::new(), |mut output, s| {
101            let _ = writeln!(output, "{s};");
102            output
103        });
104
105    Ok(sql)
106}
107
108/// Shrink function using path-based reduction
109async fn shrink_with_reducer(sql: &str, client: Client, restore_cmd: &str) -> Result<String> {
110    let sql_statements = parse_sql(sql);
111    let proceeding_stmts = sql_statements.split_last().unwrap().1.to_vec();
112    let checker = Checker::new(client, proceeding_stmts, restore_cmd.to_owned());
113    let mut reducer = Reducer::new(checker);
114
115    let reduced_sql = reducer.reduce(sql).await?;
116
117    Ok(reduced_sql)
118}
119
120pub(crate) fn find_ddl_references(sql_statements: &[Statement]) -> HashSet<String> {
121    let mut ddl_references = HashSet::new();
122    let mut sql_statements = sql_statements.iter().rev();
123    let failing = sql_statements.next().unwrap();
124    match failing {
125        Statement::Query(query) | Statement::CreateView { query, .. } => {
126            find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
127        }
128        _ => {}
129    };
130    for sql_statement in sql_statements {
131        match sql_statement {
132            Statement::Query(query) => {
133                find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
134            }
135            Statement::CreateView { query, name, .. }
136                if ddl_references.contains(&name.real_value()) =>
137            {
138                find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
139            }
140            _ => {}
141        };
142    }
143    ddl_references
144}
145
146pub(crate) fn find_ddl_references_for_query(query: &Query, ddl_references: &mut HashSet<String>) {
147    let Query { with, body, .. } = query;
148    if let Some(With { cte_tables, .. }) = with {
149        for Cte { cte_inner, .. } in cte_tables {
150            if let CteInner::Query(query) = cte_inner {
151                find_ddl_references_for_query(query, ddl_references)
152            }
153        }
154    }
155    find_ddl_references_for_query_in_set_expr(body, ddl_references);
156}
157
158fn find_ddl_references_for_query_in_set_expr(
159    set_expr: &SetExpr,
160    ddl_references: &mut HashSet<String>,
161) {
162    match set_expr {
163        SetExpr::Select(box Select {
164            from,
165            selection,
166            projection,
167            group_by,
168            having,
169            ..
170        }) => {
171            // Scan FROM clause
172            for table_with_joins in from {
173                find_ddl_references_for_query_in_table_with_joins(table_with_joins, ddl_references);
174            }
175
176            // Scan WHERE clause (selection)
177            if let Some(where_expr) = selection {
178                find_ddl_references_in_expr(where_expr, ddl_references);
179            }
180
181            // Scan SELECT list (projection)
182            for select_item in projection {
183                match select_item {
184                    risingwave_sqlparser::ast::SelectItem::UnnamedExpr(expr) => {
185                        find_ddl_references_in_expr(expr, ddl_references);
186                    }
187                    risingwave_sqlparser::ast::SelectItem::ExprWithAlias { expr, .. } => {
188                        find_ddl_references_in_expr(expr, ddl_references);
189                    }
190                    _ => {}
191                }
192            }
193
194            // Scan GROUP BY clause
195            for group_by_expr in group_by {
196                find_ddl_references_in_expr(group_by_expr, ddl_references);
197            }
198
199            // Scan HAVING clause
200            if let Some(having_expr) = having {
201                find_ddl_references_in_expr(having_expr, ddl_references);
202            }
203        }
204        SetExpr::Query(q) => find_ddl_references_for_query(q, ddl_references),
205        SetExpr::SetOperation { left, right, .. } => {
206            find_ddl_references_for_query_in_set_expr(left, ddl_references);
207            find_ddl_references_for_query_in_set_expr(right, ddl_references);
208        }
209        SetExpr::Values(_) => {}
210    }
211}
212
213fn find_ddl_references_in_expr(expr: &Expr, ddl_references: &mut HashSet<String>) {
214    match expr {
215        // EXISTS subquery
216        Expr::Exists(subquery) => {
217            find_ddl_references_for_query(subquery, ddl_references);
218        }
219        // Scalar subquery
220        Expr::Subquery(subquery) => {
221            find_ddl_references_for_query(subquery, ddl_references);
222        }
223        // Binary operations (AND, OR, comparisons, etc.)
224        Expr::BinaryOp { left, right, .. } => {
225            find_ddl_references_in_expr(left, ddl_references);
226            find_ddl_references_in_expr(right, ddl_references);
227        }
228        // Unary operations
229        Expr::UnaryOp { expr, .. } => {
230            find_ddl_references_in_expr(expr, ddl_references);
231        }
232        // Function calls
233        Expr::Function(function) => {
234            for arg in &function.arg_list.args {
235                match arg {
236                    risingwave_sqlparser::ast::FunctionArg::Unnamed(func_arg_expr) => {
237                        if let risingwave_sqlparser::ast::FunctionArgExpr::Expr(expr) =
238                            func_arg_expr
239                        {
240                            find_ddl_references_in_expr(expr, ddl_references);
241                        }
242                    }
243                    risingwave_sqlparser::ast::FunctionArg::Named { arg, .. } => {
244                        if let risingwave_sqlparser::ast::FunctionArgExpr::Expr(expr) = arg {
245                            find_ddl_references_in_expr(expr, ddl_references);
246                        }
247                    }
248                }
249            }
250        }
251        // CASE expressions
252        Expr::Case {
253            operand,
254            conditions,
255            results,
256            else_result,
257            ..
258        } => {
259            if let Some(operand_expr) = operand {
260                find_ddl_references_in_expr(operand_expr, ddl_references);
261            }
262            for condition in conditions {
263                find_ddl_references_in_expr(condition, ddl_references);
264            }
265            for result in results {
266                find_ddl_references_in_expr(result, ddl_references);
267            }
268            if let Some(else_expr) = else_result {
269                find_ddl_references_in_expr(else_expr, ddl_references);
270            }
271        }
272        // Nested expressions
273        Expr::Nested(inner_expr) => {
274            find_ddl_references_in_expr(inner_expr, ddl_references);
275        }
276        // Array expressions
277        Expr::Array(array) => {
278            for expr in &array.elem {
279                find_ddl_references_in_expr(expr, ddl_references);
280            }
281        }
282        // Row expressions
283        Expr::Row(exprs) => {
284            for expr in exprs {
285                find_ddl_references_in_expr(expr, ddl_references);
286            }
287        }
288        // IN expressions
289        Expr::InList { expr, list, .. } => {
290            find_ddl_references_in_expr(expr, ddl_references);
291            for item in list {
292                find_ddl_references_in_expr(item, ddl_references);
293            }
294        }
295        Expr::InSubquery { expr, subquery, .. } => {
296            find_ddl_references_in_expr(expr, ddl_references);
297            find_ddl_references_for_query(subquery, ddl_references);
298        }
299        // BETWEEN expressions
300        Expr::Between {
301            expr, low, high, ..
302        } => {
303            find_ddl_references_in_expr(expr, ddl_references);
304            find_ddl_references_in_expr(low, ddl_references);
305            find_ddl_references_in_expr(high, ddl_references);
306        }
307        // CAST expressions
308        Expr::Cast { expr, .. } | Expr::TryCast { expr, .. } => {
309            find_ddl_references_in_expr(expr, ddl_references);
310        }
311        // String operations
312        Expr::Substring {
313            expr,
314            substring_from,
315            substring_for,
316            ..
317        } => {
318            find_ddl_references_in_expr(expr, ddl_references);
319            if let Some(from_expr) = substring_from {
320                find_ddl_references_in_expr(from_expr, ddl_references);
321            }
322            if let Some(for_expr) = substring_for {
323                find_ddl_references_in_expr(for_expr, ddl_references);
324            }
325        }
326        Expr::Trim {
327            expr, trim_where, ..
328        } => {
329            find_ddl_references_in_expr(expr, ddl_references);
330            if let Some(trim_where_field) = trim_where {
331                match trim_where_field {
332                    risingwave_sqlparser::ast::TrimWhereField::Both
333                    | risingwave_sqlparser::ast::TrimWhereField::Leading
334                    | risingwave_sqlparser::ast::TrimWhereField::Trailing => {
335                        // TrimWhereField variants don't contain expressions
336                    }
337                }
338            }
339        }
340        // Other expressions that don't contain subqueries
341        Expr::Identifier(_)
342        | Expr::Value(_)
343        | Expr::Parameter { .. }
344        | Expr::Position { .. }
345        | Expr::Extract { .. }
346        | Expr::IsNull(_)
347        | Expr::IsNotNull(_)
348        | Expr::IsDistinctFrom(_, _)
349        | Expr::IsNotDistinctFrom(_, _) => {
350            // These don't contain subqueries, so no need to recurse
351        }
352        // Handle any other expression types that might be added in the future
353        _ => {
354            // For unknown expression types, we don't recurse to avoid panics
355        }
356    }
357}
358
359fn find_ddl_references_for_query_in_table_with_joins(
360    TableWithJoins { relation, joins }: &TableWithJoins,
361    ddl_references: &mut HashSet<String>,
362) {
363    find_ddl_references_for_query_in_table_factor(relation, ddl_references);
364    for Join { relation, .. } in joins {
365        find_ddl_references_for_query_in_table_factor(relation, ddl_references);
366    }
367}
368
369fn find_ddl_references_for_query_in_table_factor(
370    table_factor: &TableFactor,
371    ddl_references: &mut HashSet<String>,
372) {
373    match table_factor {
374        TableFactor::Table { name, .. } => {
375            ddl_references.insert(name.real_value());
376        }
377        TableFactor::Derived { subquery, .. } => {
378            find_ddl_references_for_query(subquery, ddl_references)
379        }
380        TableFactor::TableFunction { name, args, .. } => {
381            let name = name.real_value();
382            // https://docs.rs/regex/latest/regex/#grouping-and-flags
383            let regex = Regex::new(r"(?i)(tumble|hop)").unwrap();
384            if regex.is_match(&name) && args.len() >= 3 {
385                let table_name = &args[0];
386                if let FunctionArgExpr::Expr(Expr::Identifier(table_name)) = table_name.get_expr() {
387                    ddl_references.insert(table_name.to_string().to_lowercase());
388                }
389            }
390        }
391        TableFactor::NestedJoin(table_with_joins) => {
392            find_ddl_references_for_query_in_table_with_joins(table_with_joins, ddl_references);
393        }
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture script: three tables, four materialized views, inserts into
    /// each table, and a trailing `SET`. Tests append their failing query.
    const DDL_AND_DML: &str = "
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T3 (V1 timestamp, V2 INT, V3 INT);
CREATE MATERIALIZED VIEW M1 AS SELECT * FROM T1;
CREATE MATERIALIZED VIEW M2 AS SELECT * FROM T2 LEFT JOIN T3 ON T2.V1 = T3.V2;
CREATE MATERIALIZED VIEW M3 AS SELECT * FROM T1 LEFT JOIN T2;
CREATE MATERIALIZED VIEW M4 AS SELECT * FROM M3;
INSERT INTO T1 VALUES(0, 0, 1);
INSERT INTO T1 VALUES(0, 0, 2);
INSERT INTO T2 VALUES(0, 0, 3);
INSERT INTO T2 VALUES(0, 0, 4);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 5);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 6);
SET RW_TWO_PHASE_AGG=TRUE;
    ";

    /// Parses `sql` and returns the query of its first statement.
    fn sql_to_query(sql: &str) -> Box<Query> {
        let sql_statement = parse_sql(sql)
            .into_iter()
            .next()
            .expect("expected at least one statement");
        match sql_statement {
            Statement::Query(query) | Statement::CreateView { query, .. } => query,
            _ => panic!("First statement was not a query, can't collect references"),
        }
    }

    /// References collected from the single query in `sql`.
    fn query_references(sql: &str) -> HashSet<String> {
        let query = sql_to_query(sql);
        let mut ddl_references = HashSet::new();
        find_ddl_references_for_query(&query, &mut ddl_references);
        ddl_references
    }

    /// References collected from the fixture script followed by `query`.
    fn fixture_references(query: &str) -> HashSet<String> {
        let sql_statements = parse_sql(DDL_AND_DML.to_owned() + query);
        find_ddl_references(&sql_statements)
    }

    /// Asserts that every name in `present` was collected and none in `absent`.
    #[track_caller]
    fn assert_references(ddl_references: &HashSet<String>, present: &[&str], absent: &[&str]) {
        for name in present {
            assert!(
                ddl_references.contains(*name),
                "expected `{name}` in {ddl_references:?}"
            );
        }
        for name in absent {
            assert!(
                !ddl_references.contains(*name),
                "unexpected `{name}` in {ddl_references:?}"
            );
        }
    }

    /// Asserts that shrinking the fixture script followed by `query` yields
    /// `expected_ddl_and_dml`, then the session variable, then `query`.
    #[track_caller]
    fn assert_shrinks_to(query: &str, expected_ddl_and_dml: &str) {
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!("{expected_ddl_and_dml}SET RW_TWO_PHASE_AGG = true;\n{query}\n");
        assert_eq!(expected, shrink_statements(&sql).unwrap());
    }

    #[test]
    fn test_find_ddl_references_for_query_simple() {
        let ddl_references = query_references("SELECT * FROM T1;");
        assert_references(&ddl_references, &["t1"], &[]);
    }

    #[test]
    fn test_find_ddl_references_for_tumble() {
        let ddl_references = query_references("SELECT * FROM TUMBLE(T3, V1, INTERVAL '3' DAY);");
        assert_references(&ddl_references, &["t3"], &[]);
    }

    #[test]
    fn test_find_ddl_references_for_query_with_cte() {
        let ddl_references =
            fixture_references("WITH WITH0 AS (SELECT * FROM M3) SELECT * FROM WITH0");
        assert_references(
            &ddl_references,
            &["m3", "t1", "t2"],
            &["m4", "t3", "m1", "m2"],
        );
    }

    #[test]
    fn test_find_ddl_references_for_query_with_mv_on_mv() {
        let ddl_references =
            fixture_references("WITH WITH0 AS (SELECT * FROM M4) SELECT * FROM WITH0");
        assert_references(
            &ddl_references,
            &["m4", "m3", "t1", "t2"],
            &["t3", "m1", "m2"],
        );
    }

    #[test]
    fn test_find_ddl_references_for_query_joins() {
        let ddl_references = fixture_references(
            "SELECT * FROM (T1 JOIN T2 ON T1.V1 = T2.V2) JOIN T3 ON T2.V1 = T3.V2",
        );
        assert_references(
            &ddl_references,
            &["t1", "t2", "t3"],
            &["m1", "m2", "m3", "m4"],
        );
    }

    #[test]
    fn test_shrink_values() {
        assert_shrinks_to("SELECT 1;", "");
    }

    #[test]
    fn test_shrink_simple_table() {
        assert_shrinks_to(
            "SELECT * FROM t1;",
            concat!(
                "CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);\n",
                "INSERT INTO T1 VALUES (0, 0, 1);\n",
                "INSERT INTO T1 VALUES (0, 0, 2);\n",
            ),
        );
    }

    #[test]
    fn test_shrink_simple_table_with_alias() {
        assert_shrinks_to(
            "SELECT * FROM t1 AS s1;",
            concat!(
                "CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);\n",
                "INSERT INTO T1 VALUES (0, 0, 1);\n",
                "INSERT INTO T1 VALUES (0, 0, 2);\n",
            ),
        );
    }

    #[test]
    fn test_shrink_join() {
        assert_shrinks_to(
            "SELECT * FROM (T1 JOIN T2 ON T1.V1 = T2.V2) JOIN T3 ON T2.V1 = T3.V2;",
            concat!(
                "CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);\n",
                "CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);\n",
                "CREATE TABLE T3 (V1 TIMESTAMP, V2 INT, V3 INT);\n",
                "INSERT INTO T1 VALUES (0, 0, 1);\n",
                "INSERT INTO T1 VALUES (0, 0, 2);\n",
                "INSERT INTO T2 VALUES (0, 0, 3);\n",
                "INSERT INTO T2 VALUES (0, 0, 4);\n",
                "INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 5);\n",
                "INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 6);\n",
            ),
        );
    }

    #[test]
    fn test_shrink_tumble() {
        assert_shrinks_to(
            "SELECT * FROM TUMBLE(T3, V1, INTERVAL '3' DAY);",
            concat!(
                "CREATE TABLE T3 (V1 TIMESTAMP, V2 INT, V3 INT);\n",
                "INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 5);\n",
                "INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 6);\n",
            ),
        );
    }

    #[test]
    fn test_shrink_subquery() {
        assert_shrinks_to(
            "SELECT * FROM (SELECT V1 AS K1 FROM T2);",
            concat!(
                "CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);\n",
                "INSERT INTO T2 VALUES (0, 0, 3);\n",
                "INSERT INTO T2 VALUES (0, 0, 4);\n",
            ),
        );
    }

    #[test]
    fn test_shrink_mview() {
        assert_shrinks_to(
            "CREATE MATERIALIZED VIEW m5 AS SELECT * FROM (SELECT V1 AS K1 FROM T2);",
            concat!(
                "CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);\n",
                "INSERT INTO T2 VALUES (0, 0, 3);\n",
                "INSERT INTO T2 VALUES (0, 0, 4);\n",
            ),
        );
    }

    #[test]
    fn test_find_ddl_references_for_exists_subquery() {
        let ddl_references = fixture_references(
            "SELECT * FROM T1 WHERE EXISTS (SELECT * FROM T2 WHERE T2.V1 = T1.V1);",
        );
        assert_references(&ddl_references, &["t1", "t2"], &["t3", "m1"]);
    }
}