1use std::collections::HashSet;
17use std::fmt::Write;
18
19use anyhow::anyhow;
20use itertools::Itertools;
21use regex::Regex;
22use risingwave_sqlparser::ast::{
23 Cte, CteInner, Expr, FunctionArgExpr, Join, Query, Select, SetExpr, Statement, TableFactor,
24 TableWithJoins, With,
25};
26
27use crate::parse_sql;
28use crate::utils::{create_file, read_file_contents, write_to_file};
29
30type Result<A> = anyhow::Result<A>;
31
32pub fn shrink_file(input_file_path: &str, output_file_path: &str) -> Result<()> {
34 let file_contents = read_file_contents(input_file_path)?;
36
37 let reduced_sql = shrink(&file_contents)?;
39
40 let mut file = create_file(output_file_path).unwrap();
42 write_to_file(&mut file, reduced_sql)
43}
44
45fn shrink(sql: &str) -> Result<String> {
46 let sql_statements = parse_sql(sql);
47
48 let session_variable = sql_statements
50 .get(sql_statements.len() - 2)
51 .filter(|statement| matches!(statement, Statement::SetVariable { .. }));
52
53 let failing_query = sql_statements
54 .last()
55 .ok_or_else(|| anyhow!("Could not get last sql statement"))?;
56
57 let ddl_references = find_ddl_references(&sql_statements);
58
59 tracing::info!("[DDL REFERENCES]: {}", ddl_references.iter().join(", "));
60
61 let mut ddl = sql_statements
62 .iter()
63 .filter(|s| {
64 matches!(*s,
65 Statement::CreateView { name, .. } | Statement::CreateTable { name, .. }
66 if ddl_references.contains(&name.real_value()))
67 })
68 .collect();
69
70 let mut dml = sql_statements
71 .iter()
72 .filter(|s| {
73 matches!(*s,
74 Statement::Insert { table_name, .. }
75 if ddl_references.contains(&table_name.real_value()))
76 })
77 .collect();
78
79 let mut reduced_statements = vec![];
80 reduced_statements.append(&mut ddl);
81 reduced_statements.append(&mut dml);
82 if let Some(session_variable) = session_variable {
83 reduced_statements.push(session_variable);
84 }
85 reduced_statements.push(failing_query);
86
87 let sql = reduced_statements
88 .iter()
89 .fold(String::new(), |mut output, s| {
90 let _ = writeln!(output, "{s};");
91 output
92 });
93
94 Ok(sql)
95}
96
97pub(crate) fn find_ddl_references(sql_statements: &[Statement]) -> HashSet<String> {
98 let mut ddl_references = HashSet::new();
99 let mut sql_statements = sql_statements.iter().rev();
100 let failing = sql_statements.next().unwrap();
101 match failing {
102 Statement::Query(query) | Statement::CreateView { query, .. } => {
103 find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
104 }
105 _ => {}
106 };
107 for sql_statement in sql_statements {
108 match sql_statement {
109 Statement::Query(query) => {
110 find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
111 }
112 Statement::CreateView { query, name, .. }
113 if ddl_references.contains(&name.real_value()) =>
114 {
115 find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
116 }
117 _ => {}
118 };
119 }
120 ddl_references
121}
122
123pub(crate) fn find_ddl_references_for_query(query: &Query, ddl_references: &mut HashSet<String>) {
124 let Query { with, body, .. } = query;
125 if let Some(With { cte_tables, .. }) = with {
126 for Cte { cte_inner, .. } in cte_tables {
127 if let CteInner::Query(query) = cte_inner {
128 find_ddl_references_for_query(query, ddl_references)
129 }
130 }
131 }
132 find_ddl_references_for_query_in_set_expr(body, ddl_references);
133}
134
135fn find_ddl_references_for_query_in_set_expr(
136 set_expr: &SetExpr,
137 ddl_references: &mut HashSet<String>,
138) {
139 match set_expr {
140 SetExpr::Select(box Select { from, .. }) => {
141 for table_with_joins in from {
142 find_ddl_references_for_query_in_table_with_joins(table_with_joins, ddl_references);
143 }
144 }
145 SetExpr::Query(q) => find_ddl_references_for_query(q, ddl_references),
146 SetExpr::SetOperation { left, right, .. } => {
147 find_ddl_references_for_query_in_set_expr(left, ddl_references);
148 find_ddl_references_for_query_in_set_expr(right, ddl_references);
149 }
150 SetExpr::Values(_) => {}
151 }
152}
153
154fn find_ddl_references_for_query_in_table_with_joins(
155 TableWithJoins { relation, joins }: &TableWithJoins,
156 ddl_references: &mut HashSet<String>,
157) {
158 find_ddl_references_for_query_in_table_factor(relation, ddl_references);
159 for Join { relation, .. } in joins {
160 find_ddl_references_for_query_in_table_factor(relation, ddl_references);
161 }
162}
163
164fn find_ddl_references_for_query_in_table_factor(
165 table_factor: &TableFactor,
166 ddl_references: &mut HashSet<String>,
167) {
168 match table_factor {
169 TableFactor::Table { name, .. } => {
170 ddl_references.insert(name.real_value());
171 }
172 TableFactor::Derived { subquery, .. } => {
173 find_ddl_references_for_query(subquery, ddl_references)
174 }
175 TableFactor::TableFunction { name, args, .. } => {
176 let name = name.real_value();
177 let regex = Regex::new(r"(?i)(tumble|hop)").unwrap();
179 if regex.is_match(&name) && args.len() >= 3 {
180 let table_name = &args[0];
181 if let FunctionArgExpr::Expr(Expr::Identifier(table_name)) = table_name.get_expr() {
182 ddl_references.insert(table_name.to_string().to_lowercase());
183 }
184 }
185 }
186 TableFactor::NestedJoin(table_with_joins) => {
187 find_ddl_references_for_query_in_table_with_joins(table_with_joins, ddl_references);
188 }
189 }
190}
191
// Unit tests for the DDL-reference walk (`find_ddl_references*`) and the end-to-end
// `shrink` reduction.
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture prepended to every end-to-end test query:
    // - three base tables T1, T2, T3;
    // - views M1 (on T1), M2 (on T2, T3), M3 (on T1, T2) and M4 (on M3, i.e. a view on a view);
    // - two INSERTs per table;
    // - a trailing session-variable SET, which ends up immediately before the query under test.
    const DDL_AND_DML: &str = "
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T3 (V1 timestamp, V2 INT, V3 INT);
CREATE MATERIALIZED VIEW M1 AS SELECT * FROM T1;
CREATE MATERIALIZED VIEW M2 AS SELECT * FROM T2 LEFT JOIN T3 ON T2.V1 = T3.V2;
CREATE MATERIALIZED VIEW M3 AS SELECT * FROM T1 LEFT JOIN T2;
CREATE MATERIALIZED VIEW M4 AS SELECT * FROM M3;
INSERT INTO T1 VALUES(0, 0, 1);
INSERT INTO T1 VALUES(0, 0, 2);
INSERT INTO T2 VALUES(0, 0, 3);
INSERT INTO T2 VALUES(0, 0, 4);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 5);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 6);
SET RW_TWO_PHASE_AGG=TRUE;
    ";

    /// Parses `sql` and returns the query of its first statement, which must be either a
    /// plain query or a `CREATE ... VIEW`; panics otherwise.
    fn sql_to_query(sql: &str) -> Box<Query> {
        let sql_statement = parse_sql(sql).into_iter().next().unwrap();
        match sql_statement {
            Statement::Query(query) | Statement::CreateView { query, .. } => query,
            _ => panic!("Last statement was not a query, can't shrink"),
        }
    }

    // An unquoted table name in FROM is recorded in its `real_value()` form ("t1").
    #[test]
    fn test_find_ddl_references_for_query_simple() {
        let sql = "SELECT * FROM T1;";
        let query = sql_to_query(sql);
        let mut ddl_references = HashSet::new();
        find_ddl_references_for_query(&query, &mut ddl_references);
        println!("{:#?}", ddl_references);
        assert!(ddl_references.contains("t1"));
    }

    // The first argument of a TUMBLE table function names the windowed relation.
    #[test]
    fn test_find_ddl_references_for_tumble() {
        let sql = "SELECT * FROM TUMBLE(T3, V1, INTERVAL '3' DAY);";
        let query = sql_to_query(sql);
        let mut ddl_references = HashSet::new();
        find_ddl_references_for_query(&query, &mut ddl_references);
        println!("{:#?}", ddl_references);
        assert!(ddl_references.contains("t3"));
    }

    // A CTE reading M3 pulls in M3 and, via M3's definition, T1 and T2, and nothing else.
    #[test]
    fn test_find_ddl_references_for_query_with_cte() {
        let sql = "WITH WITH0 AS (SELECT * FROM M3) SELECT * FROM WITH0";
        let sql_statements = DDL_AND_DML.to_owned() + sql;
        let sql_statements = parse_sql(sql_statements);
        let ddl_references = find_ddl_references(&sql_statements);
        assert!(ddl_references.contains("m3"));
        assert!(ddl_references.contains("t1"));
        assert!(ddl_references.contains("t2"));

        assert!(!ddl_references.contains("m4"));
        assert!(!ddl_references.contains("t3"));
        assert!(!ddl_references.contains("m1"));
        assert!(!ddl_references.contains("m2"));
    }

    // References are followed transitively through views: M4 -> M3 -> {T1, T2}.
    #[test]
    fn test_find_ddl_references_for_query_with_mv_on_mv() {
        let sql = "WITH WITH0 AS (SELECT * FROM M4) SELECT * FROM WITH0";
        let sql_statements = DDL_AND_DML.to_owned() + sql;
        let sql_statements = parse_sql(sql_statements);
        let ddl_references = find_ddl_references(&sql_statements);
        assert!(ddl_references.contains("m4"));
        assert!(ddl_references.contains("m3"));
        assert!(ddl_references.contains("t1"));
        assert!(ddl_references.contains("t2"));

        assert!(!ddl_references.contains("t3"));
        assert!(!ddl_references.contains("m1"));
        assert!(!ddl_references.contains("m2"));
    }

    // Both the parenthesized (nested) join and the outer join contribute their relations.
    #[test]
    fn test_find_ddl_references_for_query_joins() {
        let sql = "SELECT * FROM (T1 JOIN T2 ON T1.V1 = T2.V2) JOIN T3 ON T2.V1 = T3.V2";
        let sql_statements = DDL_AND_DML.to_owned() + sql;
        let sql_statements = parse_sql(sql_statements);
        let ddl_references = find_ddl_references(&sql_statements);
        assert!(ddl_references.contains("t1"));
        assert!(ddl_references.contains("t2"));
        assert!(ddl_references.contains("t3"));

        assert!(!ddl_references.contains("m1"));
        assert!(!ddl_references.contains("m2"));
        assert!(!ddl_references.contains("m3"));
        assert!(!ddl_references.contains("m4"));
    }

    // The `shrink` tests below compare against statements re-printed through the parser's
    // Display output. That is why the expected text differs cosmetically from the fixture
    // (e.g. `VALUES(` -> `VALUES (`, `timestamp` -> `TIMESTAMP`, `TRUE` -> `true`).

    // A query that reads no relation keeps only the session variable and the query itself.
    #[test]
    fn test_shrink_values() {
        let query = "SELECT 1;";
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!(
            "\
SET RW_TWO_PHASE_AGG = true;
{query}
"
        );
        assert_eq!(expected, shrink(&sql).unwrap());
    }

    #[test]
    fn test_shrink_simple_table() {
        let query = "SELECT * FROM t1;";
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!(
            "\
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
INSERT INTO T1 VALUES (0, 0, 1);
INSERT INTO T1 VALUES (0, 0, 2);
SET RW_TWO_PHASE_AGG = true;
{query}
"
        );
        assert_eq!(expected, shrink(&sql).unwrap());
    }

    // A table alias must not hide the underlying table name.
    #[test]
    fn test_shrink_simple_table_with_alias() {
        let query = "SELECT * FROM t1 AS s1;";
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!(
            "\
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
INSERT INTO T1 VALUES (0, 0, 1);
INSERT INTO T1 VALUES (0, 0, 2);
SET RW_TWO_PHASE_AGG = true;
{query}
"
        );
        assert_eq!(expected, shrink(&sql).unwrap());
    }

    // All DDL is emitted before all DML, each group in its original order.
    #[test]
    fn test_shrink_join() {
        let query = "SELECT * FROM (T1 JOIN T2 ON T1.V1 = T2.V2) JOIN T3 ON T2.V1 = T3.V2;";
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!(
            "\
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T3 (V1 TIMESTAMP, V2 INT, V3 INT);
INSERT INTO T1 VALUES (0, 0, 1);
INSERT INTO T1 VALUES (0, 0, 2);
INSERT INTO T2 VALUES (0, 0, 3);
INSERT INTO T2 VALUES (0, 0, 4);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 5);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 6);
SET RW_TWO_PHASE_AGG = true;
{query}
"
        );
        assert_eq!(expected, shrink(&sql).unwrap());
    }

    #[test]
    fn test_shrink_tumble() {
        let query = "SELECT * FROM TUMBLE(T3, V1, INTERVAL '3' DAY);";
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!(
            "\
CREATE TABLE T3 (V1 TIMESTAMP, V2 INT, V3 INT);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 5);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 6);
SET RW_TWO_PHASE_AGG = true;
{query}
"
        );
        assert_eq!(expected, shrink(&sql).unwrap());
    }

    // Relations read inside a derived table (subquery in FROM) are found.
    #[test]
    fn test_shrink_subquery() {
        let query = "SELECT * FROM (SELECT V1 AS K1 FROM T2);";
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!(
            "\
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
INSERT INTO T2 VALUES (0, 0, 3);
INSERT INTO T2 VALUES (0, 0, 4);
SET RW_TWO_PHASE_AGG = true;
{query}
"
        );
        assert_eq!(expected, shrink(&sql).unwrap());
    }

    // A failing `CREATE MATERIALIZED VIEW` is reduced by the relations its query reads.
    #[test]
    fn test_shrink_mview() {
        let query = "CREATE MATERIALIZED VIEW m5 AS SELECT * FROM (SELECT V1 AS K1 FROM T2);";
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!(
            "\
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
INSERT INTO T2 VALUES (0, 0, 3);
INSERT INTO T2 VALUES (0, 0, 4);
SET RW_TWO_PHASE_AGG = true;
{query}
"
        );
        assert_eq!(expected, shrink(&sql).unwrap());
    }
}