1use std::collections::HashSet;
17use std::fmt::Write;
18
19use anyhow::anyhow;
20use itertools::Itertools;
21use regex::Regex;
22use risingwave_sqlparser::ast::{
23 Cte, CteInner, Expr, FunctionArgExpr, Join, Query, Select, SetExpr, Statement, TableFactor,
24 TableWithJoins, With,
25};
26use tokio_postgres::Client;
27
28use crate::parse_sql;
29use crate::sqlreduce::Strategy;
30use crate::sqlreduce::checker::Checker;
31use crate::sqlreduce::reducer::Reducer;
32use crate::utils::{create_file, read_file_contents, write_to_file};
33
/// Shorthand for [`anyhow::Result`], used as the return type of the fallible functions below.
type Result<A> = anyhow::Result<A>;
35
36pub async fn shrink_file(
38 input_file_path: &str,
39 output_file_path: &str,
40 strategy: Strategy,
41 client: Client,
42 restore_cmd: &str,
43) -> Result<()> {
44 let file_contents = read_file_contents(input_file_path)?;
46
47 let reduced_sql = shrink_statements(&file_contents)?;
49
50 let reduced_sql = shrink(&reduced_sql, strategy, client, restore_cmd).await?;
52
53 let mut file = create_file(output_file_path).unwrap();
55 write_to_file(&mut file, reduced_sql)
56}
57
58fn shrink_statements(sql: &str) -> Result<String> {
59 let sql_statements = parse_sql(sql);
60
61 let session_variable = sql_statements
63 .get(sql_statements.len() - 2)
64 .filter(|statement| matches!(statement, Statement::SetVariable { .. }));
65
66 let failing_query = sql_statements
67 .last()
68 .ok_or_else(|| anyhow!("Could not get last sql statement"))?;
69
70 let ddl_references = find_ddl_references(&sql_statements);
71
72 tracing::info!("[DDL REFERENCES]: {}", ddl_references.iter().join(", "));
73
74 let mut ddl = sql_statements
75 .iter()
76 .filter(|s| {
77 matches!(*s,
78 Statement::CreateView { name, .. } | Statement::CreateTable { name, .. }
79 if ddl_references.contains(&name.real_value()))
80 })
81 .collect();
82
83 let mut dml = sql_statements
84 .iter()
85 .filter(|s| {
86 matches!(*s,
87 Statement::Insert { table_name, .. }
88 if ddl_references.contains(&table_name.real_value()))
89 })
90 .collect();
91
92 let mut reduced_statements = vec![];
93 reduced_statements.append(&mut ddl);
94 reduced_statements.append(&mut dml);
95 if let Some(session_variable) = session_variable {
96 reduced_statements.push(session_variable);
97 }
98 reduced_statements.push(failing_query);
99
100 let sql = reduced_statements
101 .iter()
102 .fold(String::new(), |mut output, s| {
103 let _ = writeln!(output, "{s};");
104 output
105 });
106
107 Ok(sql)
108}
109
110async fn shrink(
111 sql: &str,
112 strategy: Strategy,
113 client: Client,
114 restore_cmd: &str,
115) -> Result<String> {
116 let sql_statements = parse_sql(sql);
117 let proceeding_stmts = sql_statements.split_last().unwrap().1.to_vec();
118 let checker = Checker::new(client, proceeding_stmts, restore_cmd.to_owned());
119 let mut reducer = Reducer::new(checker, strategy);
120
121 let reduced_sql = reducer.reduce(sql).await?;
122
123 Ok(reduced_sql)
124}
125
126pub(crate) fn find_ddl_references(sql_statements: &[Statement]) -> HashSet<String> {
127 let mut ddl_references = HashSet::new();
128 let mut sql_statements = sql_statements.iter().rev();
129 let failing = sql_statements.next().unwrap();
130 match failing {
131 Statement::Query(query) | Statement::CreateView { query, .. } => {
132 find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
133 }
134 _ => {}
135 };
136 for sql_statement in sql_statements {
137 match sql_statement {
138 Statement::Query(query) => {
139 find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
140 }
141 Statement::CreateView { query, name, .. }
142 if ddl_references.contains(&name.real_value()) =>
143 {
144 find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
145 }
146 _ => {}
147 };
148 }
149 ddl_references
150}
151
152pub(crate) fn find_ddl_references_for_query(query: &Query, ddl_references: &mut HashSet<String>) {
153 let Query { with, body, .. } = query;
154 if let Some(With { cte_tables, .. }) = with {
155 for Cte { cte_inner, .. } in cte_tables {
156 if let CteInner::Query(query) = cte_inner {
157 find_ddl_references_for_query(query, ddl_references)
158 }
159 }
160 }
161 find_ddl_references_for_query_in_set_expr(body, ddl_references);
162}
163
164fn find_ddl_references_for_query_in_set_expr(
165 set_expr: &SetExpr,
166 ddl_references: &mut HashSet<String>,
167) {
168 match set_expr {
169 SetExpr::Select(box Select { from, .. }) => {
170 for table_with_joins in from {
171 find_ddl_references_for_query_in_table_with_joins(table_with_joins, ddl_references);
172 }
173 }
174 SetExpr::Query(q) => find_ddl_references_for_query(q, ddl_references),
175 SetExpr::SetOperation { left, right, .. } => {
176 find_ddl_references_for_query_in_set_expr(left, ddl_references);
177 find_ddl_references_for_query_in_set_expr(right, ddl_references);
178 }
179 SetExpr::Values(_) => {}
180 }
181}
182
183fn find_ddl_references_for_query_in_table_with_joins(
184 TableWithJoins { relation, joins }: &TableWithJoins,
185 ddl_references: &mut HashSet<String>,
186) {
187 find_ddl_references_for_query_in_table_factor(relation, ddl_references);
188 for Join { relation, .. } in joins {
189 find_ddl_references_for_query_in_table_factor(relation, ddl_references);
190 }
191}
192
193fn find_ddl_references_for_query_in_table_factor(
194 table_factor: &TableFactor,
195 ddl_references: &mut HashSet<String>,
196) {
197 match table_factor {
198 TableFactor::Table { name, .. } => {
199 ddl_references.insert(name.real_value());
200 }
201 TableFactor::Derived { subquery, .. } => {
202 find_ddl_references_for_query(subquery, ddl_references)
203 }
204 TableFactor::TableFunction { name, args, .. } => {
205 let name = name.real_value();
206 let regex = Regex::new(r"(?i)(tumble|hop)").unwrap();
208 if regex.is_match(&name) && args.len() >= 3 {
209 let table_name = &args[0];
210 if let FunctionArgExpr::Expr(Expr::Identifier(table_name)) = table_name.get_expr() {
211 ddl_references.insert(table_name.to_string().to_lowercase());
212 }
213 }
214 }
215 TableFactor::NestedJoin(table_with_joins) => {
216 find_ddl_references_for_query_in_table_with_joins(table_with_joins, ddl_references);
217 }
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: three tables, four materialized views (M4 is built on M3), inserts into
    // each table and a trailing session variable. Each test appends its failing query to it.
    const DDL_AND_DML: &str = "
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T3 (V1 timestamp, V2 INT, V3 INT);
CREATE MATERIALIZED VIEW M1 AS SELECT * FROM T1;
CREATE MATERIALIZED VIEW M2 AS SELECT * FROM T2 LEFT JOIN T3 ON T2.V1 = T3.V2;
CREATE MATERIALIZED VIEW M3 AS SELECT * FROM T1 LEFT JOIN T2;
CREATE MATERIALIZED VIEW M4 AS SELECT * FROM M3;
INSERT INTO T1 VALUES(0, 0, 1);
INSERT INTO T1 VALUES(0, 0, 2);
INSERT INTO T2 VALUES(0, 0, 3);
INSERT INTO T2 VALUES(0, 0, 4);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 5);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 6);
SET RW_TWO_PHASE_AGG=TRUE;
    ";

    /// Parses `sql` and returns the query of its first statement.
    ///
    /// Panics if that statement is neither a query nor a `CREATE VIEW`.
    // NOTE(review): the panic message says "Last statement" although the first statement is
    // the one inspected here.
    fn sql_to_query(sql: &str) -> Box<Query> {
        let sql_statement = parse_sql(sql).into_iter().next().unwrap();
        match sql_statement {
            Statement::Query(query) | Statement::CreateView { query, .. } => query,
            _ => panic!("Last statement was not a query, can't shrink"),
        }
    }

    /// A plain `FROM` table is recorded under its lower-cased name.
    #[test]
    fn test_find_ddl_references_for_query_simple() {
        let sql = "SELECT * FROM T1;";
        let query = sql_to_query(sql);
        let mut ddl_references = HashSet::new();
        find_ddl_references_for_query(&query, &mut ddl_references);
        println!("{:#?}", ddl_references);
        assert!(ddl_references.contains("t1"));
    }

    /// The first argument of `TUMBLE(...)` is recorded as a referenced table.
    #[test]
    fn test_find_ddl_references_for_tumble() {
        let sql = "SELECT * FROM TUMBLE(T3, V1, INTERVAL '3' DAY);";
        let query = sql_to_query(sql);
        let mut ddl_references = HashSet::new();
        find_ddl_references_for_query(&query, &mut ddl_references);
        println!("{:#?}", ddl_references);
        assert!(ddl_references.contains("t3"));
    }

    /// A view referenced inside a CTE pulls in the tables it is defined on, and nothing else.
    #[test]
    fn test_find_ddl_references_for_query_with_cte() {
        let sql = "WITH WITH0 AS (SELECT * FROM M3) SELECT * FROM WITH0";
        let sql_statements = DDL_AND_DML.to_owned() + sql;
        let sql_statements = parse_sql(sql_statements);
        let ddl_references = find_ddl_references(&sql_statements);
        assert!(ddl_references.contains("m3"));
        assert!(ddl_references.contains("t1"));
        assert!(ddl_references.contains("t2"));

        assert!(!ddl_references.contains("m4"));
        assert!(!ddl_references.contains("t3"));
        assert!(!ddl_references.contains("m1"));
        assert!(!ddl_references.contains("m2"));
    }

    /// A view built on another view (M4 on M3) is followed transitively down to the tables.
    #[test]
    fn test_find_ddl_references_for_query_with_mv_on_mv() {
        let sql = "WITH WITH0 AS (SELECT * FROM M4) SELECT * FROM WITH0";
        let sql_statements = DDL_AND_DML.to_owned() + sql;
        let sql_statements = parse_sql(sql_statements);
        let ddl_references = find_ddl_references(&sql_statements);
        assert!(ddl_references.contains("m4"));
        assert!(ddl_references.contains("m3"));
        assert!(ddl_references.contains("t1"));
        assert!(ddl_references.contains("t2"));

        assert!(!ddl_references.contains("t3"));
        assert!(!ddl_references.contains("m1"));
        assert!(!ddl_references.contains("m2"));
    }

    /// Tables inside a parenthesized (nested) join and the outer join are all recorded.
    #[test]
    fn test_find_ddl_references_for_query_joins() {
        let sql = "SELECT * FROM (T1 JOIN T2 ON T1.V1 = T2.V2) JOIN T3 ON T2.V1 = T3.V2";
        let sql_statements = DDL_AND_DML.to_owned() + sql;
        let sql_statements = parse_sql(sql_statements);
        let ddl_references = find_ddl_references(&sql_statements);
        assert!(ddl_references.contains("t1"));
        assert!(ddl_references.contains("t2"));
        assert!(ddl_references.contains("t3"));

        assert!(!ddl_references.contains("m1"));
        assert!(!ddl_references.contains("m2"));
        assert!(!ddl_references.contains("m3"));
        assert!(!ddl_references.contains("m4"));
    }

    /// A query without table references keeps only the session variable and the query itself.
    #[test]
    fn test_shrink_values() {
        let query = "SELECT 1;";
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!(
            "\
SET RW_TWO_PHASE_AGG = true;
{query}
"
        );
        assert_eq!(expected, shrink_statements(&sql).unwrap());
    }

    /// Only the DDL and inserts of the single referenced table are kept.
    #[test]
    fn test_shrink_simple_table() {
        let query = "SELECT * FROM t1;";
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!(
            "\
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
INSERT INTO T1 VALUES (0, 0, 1);
INSERT INTO T1 VALUES (0, 0, 2);
SET RW_TWO_PHASE_AGG = true;
{query}
"
        );
        assert_eq!(expected, shrink_statements(&sql).unwrap());
    }

    /// An alias on the table does not change which table is considered referenced.
    #[test]
    fn test_shrink_simple_table_with_alias() {
        let query = "SELECT * FROM t1 AS s1;";
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!(
            "\
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
INSERT INTO T1 VALUES (0, 0, 1);
INSERT INTO T1 VALUES (0, 0, 2);
SET RW_TWO_PHASE_AGG = true;
{query}
"
        );
        assert_eq!(expected, shrink_statements(&sql).unwrap());
    }

    /// A join over all three tables keeps every table and insert but none of the views.
    #[test]
    fn test_shrink_join() {
        let query = "SELECT * FROM (T1 JOIN T2 ON T1.V1 = T2.V2) JOIN T3 ON T2.V1 = T3.V2;";
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!(
            "\
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T3 (V1 TIMESTAMP, V2 INT, V3 INT);
INSERT INTO T1 VALUES (0, 0, 1);
INSERT INTO T1 VALUES (0, 0, 2);
INSERT INTO T2 VALUES (0, 0, 3);
INSERT INTO T2 VALUES (0, 0, 4);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 5);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 6);
SET RW_TWO_PHASE_AGG = true;
{query}
"
        );
        assert_eq!(expected, shrink_statements(&sql).unwrap());
    }

    /// The table passed to `TUMBLE(...)` is kept together with its inserts.
    #[test]
    fn test_shrink_tumble() {
        let query = "SELECT * FROM TUMBLE(T3, V1, INTERVAL '3' DAY);";
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!(
            "\
CREATE TABLE T3 (V1 TIMESTAMP, V2 INT, V3 INT);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 5);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 6);
SET RW_TWO_PHASE_AGG = true;
{query}
"
        );
        assert_eq!(expected, shrink_statements(&sql).unwrap());
    }

    /// A table referenced only inside a derived subquery is still kept.
    #[test]
    fn test_shrink_subquery() {
        let query = "SELECT * FROM (SELECT V1 AS K1 FROM T2);";
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!(
            "\
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
INSERT INTO T2 VALUES (0, 0, 3);
INSERT INTO T2 VALUES (0, 0, 4);
SET RW_TWO_PHASE_AGG = true;
{query}
"
        );
        assert_eq!(expected, shrink_statements(&sql).unwrap());
    }

    /// A failing `CREATE MATERIALIZED VIEW` is analysed through its defining query.
    #[test]
    fn test_shrink_mview() {
        let query = "CREATE MATERIALIZED VIEW m5 AS SELECT * FROM (SELECT V1 AS K1 FROM T2);";
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!(
            "\
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
INSERT INTO T2 VALUES (0, 0, 3);
INSERT INTO T2 VALUES (0, 0, 4);
SET RW_TWO_PHASE_AGG = true;
{query}
"
        );
        assert_eq!(expected, shrink_statements(&sql).unwrap());
    }
}