1use std::collections::HashSet;
17use std::fmt::Write;
18
19use anyhow::anyhow;
20use itertools::Itertools;
21use regex::Regex;
22use risingwave_sqlparser::ast::{
23 Cte, CteInner, Expr, FunctionArgExpr, Join, Query, Select, SetExpr, Statement, TableFactor,
24 TableWithJoins, With,
25};
26use tokio_postgres::Client;
27
28use crate::parse_sql;
29use crate::sqlreduce::checker::Checker;
30use crate::sqlreduce::reducer::Reducer;
31use crate::utils::{create_file, read_file_contents, write_to_file};
32
/// Result type used by the functions in this file; an alias for [`anyhow::Result`].
type Result<A> = anyhow::Result<A>;
34
35pub async fn shrink_file(
37 input_file_path: &str,
38 output_file_path: &str,
39 client: Client,
40 restore_cmd: &str,
41) -> Result<()> {
42 let file_contents = read_file_contents(input_file_path)?;
44
45 let reduced_sql = shrink_statements(&file_contents)?;
47
48 let reduced_sql = shrink_with_reducer(&reduced_sql, client, restore_cmd).await?;
50
51 let mut file = create_file(output_file_path).unwrap();
53 write_to_file(&mut file, reduced_sql)
54}
55
56fn shrink_statements(sql: &str) -> Result<String> {
57 let sql_statements = parse_sql(sql);
58
59 let session_variable = sql_statements
61 .get(sql_statements.len() - 2)
62 .filter(|statement| matches!(statement, Statement::SetVariable { .. }));
63
64 let failing_query = sql_statements
65 .last()
66 .ok_or_else(|| anyhow!("Could not get last sql statement"))?;
67
68 let ddl_references = find_ddl_references(&sql_statements);
69
70 tracing::info!("[DDL REFERENCES]: {}", ddl_references.iter().join(", "));
71
72 let mut ddl = sql_statements
73 .iter()
74 .filter(|s| {
75 matches!(*s,
76 Statement::CreateView { name, .. } | Statement::CreateTable { name, .. }
77 if ddl_references.contains(&name.real_value()))
78 })
79 .collect();
80
81 let mut dml = sql_statements
82 .iter()
83 .filter(|s| {
84 matches!(*s,
85 Statement::Insert { table_name, .. }
86 if ddl_references.contains(&table_name.real_value()))
87 })
88 .collect();
89
90 let mut reduced_statements = vec![];
91 reduced_statements.append(&mut ddl);
92 reduced_statements.append(&mut dml);
93 if let Some(session_variable) = session_variable {
94 reduced_statements.push(session_variable);
95 }
96 reduced_statements.push(failing_query);
97
98 let sql = reduced_statements
99 .iter()
100 .fold(String::new(), |mut output, s| {
101 let _ = writeln!(output, "{s};");
102 output
103 });
104
105 Ok(sql)
106}
107
108async fn shrink_with_reducer(sql: &str, client: Client, restore_cmd: &str) -> Result<String> {
110 let sql_statements = parse_sql(sql);
111 let proceeding_stmts = sql_statements.split_last().unwrap().1.to_vec();
112 let checker = Checker::new(client, proceeding_stmts, restore_cmd.to_owned());
113 let mut reducer = Reducer::new(checker);
114
115 let reduced_sql = reducer.reduce(sql).await?;
116
117 Ok(reduced_sql)
118}
119
120pub(crate) fn find_ddl_references(sql_statements: &[Statement]) -> HashSet<String> {
121 let mut ddl_references = HashSet::new();
122 let mut sql_statements = sql_statements.iter().rev();
123 let failing = sql_statements.next().unwrap();
124 match failing {
125 Statement::Query(query) | Statement::CreateView { query, .. } => {
126 find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
127 }
128 _ => {}
129 };
130 for sql_statement in sql_statements {
131 match sql_statement {
132 Statement::Query(query) => {
133 find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
134 }
135 Statement::CreateView { query, name, .. }
136 if ddl_references.contains(&name.real_value()) =>
137 {
138 find_ddl_references_for_query(query.as_ref(), &mut ddl_references);
139 }
140 _ => {}
141 };
142 }
143 ddl_references
144}
145
146pub(crate) fn find_ddl_references_for_query(query: &Query, ddl_references: &mut HashSet<String>) {
147 let Query { with, body, .. } = query;
148 if let Some(With { cte_tables, .. }) = with {
149 for Cte { cte_inner, .. } in cte_tables {
150 if let CteInner::Query(query) = cte_inner {
151 find_ddl_references_for_query(query, ddl_references)
152 }
153 }
154 }
155 find_ddl_references_for_query_in_set_expr(body, ddl_references);
156}
157
158fn find_ddl_references_for_query_in_set_expr(
159 set_expr: &SetExpr,
160 ddl_references: &mut HashSet<String>,
161) {
162 match set_expr {
163 SetExpr::Select(box Select {
164 from,
165 selection,
166 projection,
167 group_by,
168 having,
169 ..
170 }) => {
171 for table_with_joins in from {
173 find_ddl_references_for_query_in_table_with_joins(table_with_joins, ddl_references);
174 }
175
176 if let Some(where_expr) = selection {
178 find_ddl_references_in_expr(where_expr, ddl_references);
179 }
180
181 for select_item in projection {
183 match select_item {
184 risingwave_sqlparser::ast::SelectItem::UnnamedExpr(expr) => {
185 find_ddl_references_in_expr(expr, ddl_references);
186 }
187 risingwave_sqlparser::ast::SelectItem::ExprWithAlias { expr, .. } => {
188 find_ddl_references_in_expr(expr, ddl_references);
189 }
190 _ => {}
191 }
192 }
193
194 for group_by_expr in group_by {
196 find_ddl_references_in_expr(group_by_expr, ddl_references);
197 }
198
199 if let Some(having_expr) = having {
201 find_ddl_references_in_expr(having_expr, ddl_references);
202 }
203 }
204 SetExpr::Query(q) => find_ddl_references_for_query(q, ddl_references),
205 SetExpr::SetOperation { left, right, .. } => {
206 find_ddl_references_for_query_in_set_expr(left, ddl_references);
207 find_ddl_references_for_query_in_set_expr(right, ddl_references);
208 }
209 SetExpr::Values(_) => {}
210 }
211}
212
213fn find_ddl_references_in_expr(expr: &Expr, ddl_references: &mut HashSet<String>) {
214 match expr {
215 Expr::Exists(subquery) => {
217 find_ddl_references_for_query(subquery, ddl_references);
218 }
219 Expr::Subquery(subquery) => {
221 find_ddl_references_for_query(subquery, ddl_references);
222 }
223 Expr::BinaryOp { left, right, .. } => {
225 find_ddl_references_in_expr(left, ddl_references);
226 find_ddl_references_in_expr(right, ddl_references);
227 }
228 Expr::UnaryOp { expr, .. } => {
230 find_ddl_references_in_expr(expr, ddl_references);
231 }
232 Expr::Function(function) => {
234 for arg in &function.arg_list.args {
235 match arg {
236 risingwave_sqlparser::ast::FunctionArg::Unnamed(func_arg_expr) => {
237 if let risingwave_sqlparser::ast::FunctionArgExpr::Expr(expr) =
238 func_arg_expr
239 {
240 find_ddl_references_in_expr(expr, ddl_references);
241 }
242 }
243 risingwave_sqlparser::ast::FunctionArg::Named { arg, .. } => {
244 if let risingwave_sqlparser::ast::FunctionArgExpr::Expr(expr) = arg {
245 find_ddl_references_in_expr(expr, ddl_references);
246 }
247 }
248 }
249 }
250 }
251 Expr::Case {
253 operand,
254 conditions,
255 results,
256 else_result,
257 ..
258 } => {
259 if let Some(operand_expr) = operand {
260 find_ddl_references_in_expr(operand_expr, ddl_references);
261 }
262 for condition in conditions {
263 find_ddl_references_in_expr(condition, ddl_references);
264 }
265 for result in results {
266 find_ddl_references_in_expr(result, ddl_references);
267 }
268 if let Some(else_expr) = else_result {
269 find_ddl_references_in_expr(else_expr, ddl_references);
270 }
271 }
272 Expr::Nested(inner_expr) => {
274 find_ddl_references_in_expr(inner_expr, ddl_references);
275 }
276 Expr::Array(array) => {
278 for expr in &array.elem {
279 find_ddl_references_in_expr(expr, ddl_references);
280 }
281 }
282 Expr::Row(exprs) => {
284 for expr in exprs {
285 find_ddl_references_in_expr(expr, ddl_references);
286 }
287 }
288 Expr::InList { expr, list, .. } => {
290 find_ddl_references_in_expr(expr, ddl_references);
291 for item in list {
292 find_ddl_references_in_expr(item, ddl_references);
293 }
294 }
295 Expr::InSubquery { expr, subquery, .. } => {
296 find_ddl_references_in_expr(expr, ddl_references);
297 find_ddl_references_for_query(subquery, ddl_references);
298 }
299 Expr::Between {
301 expr, low, high, ..
302 } => {
303 find_ddl_references_in_expr(expr, ddl_references);
304 find_ddl_references_in_expr(low, ddl_references);
305 find_ddl_references_in_expr(high, ddl_references);
306 }
307 Expr::Cast { expr, .. } | Expr::TryCast { expr, .. } => {
309 find_ddl_references_in_expr(expr, ddl_references);
310 }
311 Expr::Substring {
313 expr,
314 substring_from,
315 substring_for,
316 ..
317 } => {
318 find_ddl_references_in_expr(expr, ddl_references);
319 if let Some(from_expr) = substring_from {
320 find_ddl_references_in_expr(from_expr, ddl_references);
321 }
322 if let Some(for_expr) = substring_for {
323 find_ddl_references_in_expr(for_expr, ddl_references);
324 }
325 }
326 Expr::Trim {
327 expr, trim_where, ..
328 } => {
329 find_ddl_references_in_expr(expr, ddl_references);
330 if let Some(trim_where_field) = trim_where {
331 match trim_where_field {
332 risingwave_sqlparser::ast::TrimWhereField::Both
333 | risingwave_sqlparser::ast::TrimWhereField::Leading
334 | risingwave_sqlparser::ast::TrimWhereField::Trailing => {
335 }
337 }
338 }
339 }
340 Expr::Identifier(_)
342 | Expr::Value(_)
343 | Expr::Parameter { .. }
344 | Expr::Position { .. }
345 | Expr::Extract { .. }
346 | Expr::IsNull(_)
347 | Expr::IsNotNull(_)
348 | Expr::IsDistinctFrom(_, _)
349 | Expr::IsNotDistinctFrom(_, _) => {
350 }
352 _ => {
354 }
356 }
357}
358
359fn find_ddl_references_for_query_in_table_with_joins(
360 TableWithJoins { relation, joins }: &TableWithJoins,
361 ddl_references: &mut HashSet<String>,
362) {
363 find_ddl_references_for_query_in_table_factor(relation, ddl_references);
364 for Join { relation, .. } in joins {
365 find_ddl_references_for_query_in_table_factor(relation, ddl_references);
366 }
367}
368
369fn find_ddl_references_for_query_in_table_factor(
370 table_factor: &TableFactor,
371 ddl_references: &mut HashSet<String>,
372) {
373 match table_factor {
374 TableFactor::Table { name, .. } => {
375 ddl_references.insert(name.real_value());
376 }
377 TableFactor::Derived { subquery, .. } => {
378 find_ddl_references_for_query(subquery, ddl_references)
379 }
380 TableFactor::TableFunction { name, args, .. } => {
381 let name = name.real_value();
382 let regex = Regex::new(r"(?i)(tumble|hop)").unwrap();
384 if regex.is_match(&name) && args.len() >= 3 {
385 let table_name = &args[0];
386 if let FunctionArgExpr::Expr(Expr::Identifier(table_name)) = table_name.get_expr() {
387 ddl_references.insert(table_name.to_string().to_lowercase());
388 }
389 }
390 }
391 TableFactor::NestedJoin(table_with_joins) => {
392 find_ddl_references_for_query_in_table_with_joins(table_with_joins, ddl_references);
393 }
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// Schema, data and session setup shared by the tests. The statement under test is
    /// appended to this script.
    const DDL_AND_DML: &str = "
CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);
CREATE TABLE T3 (V1 timestamp, V2 INT, V3 INT);
CREATE MATERIALIZED VIEW M1 AS SELECT * FROM T1;
CREATE MATERIALIZED VIEW M2 AS SELECT * FROM T2 LEFT JOIN T3 ON T2.V1 = T3.V2;
CREATE MATERIALIZED VIEW M3 AS SELECT * FROM T1 LEFT JOIN T2;
CREATE MATERIALIZED VIEW M4 AS SELECT * FROM M3;
INSERT INTO T1 VALUES(0, 0, 1);
INSERT INTO T1 VALUES(0, 0, 2);
INSERT INTO T2 VALUES(0, 0, 3);
INSERT INTO T2 VALUES(0, 0, 4);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 5);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 6);
SET RW_TWO_PHASE_AGG=TRUE;
 ";

    // Expected renderings of the setup statements after shrinking.
    const CREATE_T1: &str = "CREATE TABLE T1 (V1 INT, V2 INT, V3 INT);\n";
    const CREATE_T2: &str = "CREATE TABLE T2 (V1 INT, V2 INT, V3 INT);\n";
    const CREATE_T3: &str = "CREATE TABLE T3 (V1 TIMESTAMP, V2 INT, V3 INT);\n";
    const INSERT_T1: &str = "\
INSERT INTO T1 VALUES (0, 0, 1);
INSERT INTO T1 VALUES (0, 0, 2);
";
    const INSERT_T2: &str = "\
INSERT INTO T2 VALUES (0, 0, 3);
INSERT INTO T2 VALUES (0, 0, 4);
";
    const INSERT_T3: &str = "\
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 5);
INSERT INTO T3 VALUES (TIMESTAMP '00:00:00', 0, 6);
";

    /// Parses `sql` and returns the query carried by its first statement.
    fn sql_to_query(sql: &str) -> Box<Query> {
        let first_statement = parse_sql(sql).into_iter().next().unwrap();
        let (Statement::Query(query) | Statement::CreateView { query, .. }) = first_statement
        else {
            panic!("Last statement was not a query, can't shrink");
        };
        query
    }

    /// Returns the references found in the single query `sql`, without any setup script.
    fn references_in_query(sql: &str) -> HashSet<String> {
        let query = sql_to_query(sql);
        let mut ddl_references = HashSet::new();
        find_ddl_references_for_query(&query, &mut ddl_references);
        println!("{:#?}", ddl_references);
        ddl_references
    }

    /// Returns the references found for `query` when it follows the shared setup script.
    fn references_with_setup(query: &str) -> HashSet<String> {
        let sql_statements = parse_sql(DDL_AND_DML.to_owned() + query);
        find_ddl_references(&sql_statements)
    }

    /// Asserts that `actual` contains every name in `present` and none in `absent`.
    fn assert_references(actual: &HashSet<String>, present: &[&str], absent: &[&str]) {
        for name in present {
            assert!(actual.contains(*name), "`{name}` should be referenced");
        }
        for name in absent {
            assert!(!actual.contains(*name), "`{name}` should not be referenced");
        }
    }

    /// Asserts that shrinking the setup script followed by `query` keeps exactly
    /// `kept_setup`, the session variable and the query itself.
    fn assert_shrinks_to(query: &str, kept_setup: &[&str]) {
        let sql = DDL_AND_DML.to_owned() + query;
        let expected = format!(
            "{}SET RW_TWO_PHASE_AGG = true;\n{query}\n",
            kept_setup.concat()
        );
        assert_eq!(expected, shrink_statements(&sql).unwrap());
    }

    #[test]
    fn test_find_ddl_references_for_query_simple() {
        let ddl_references = references_in_query("SELECT * FROM T1;");
        assert_references(&ddl_references, &["t1"], &[]);
    }

    #[test]
    fn test_find_ddl_references_for_tumble() {
        let ddl_references = references_in_query("SELECT * FROM TUMBLE(T3, V1, INTERVAL '3' DAY);");
        assert_references(&ddl_references, &["t3"], &[]);
    }

    #[test]
    fn test_find_ddl_references_for_query_with_cte() {
        let ddl_references =
            references_with_setup("WITH WITH0 AS (SELECT * FROM M3) SELECT * FROM WITH0");
        assert_references(
            &ddl_references,
            &["m3", "t1", "t2"],
            &["m4", "t3", "m1", "m2"],
        );
    }

    #[test]
    fn test_find_ddl_references_for_query_with_mv_on_mv() {
        let ddl_references =
            references_with_setup("WITH WITH0 AS (SELECT * FROM M4) SELECT * FROM WITH0");
        assert_references(
            &ddl_references,
            &["m4", "m3", "t1", "t2"],
            &["t3", "m1", "m2"],
        );
    }

    #[test]
    fn test_find_ddl_references_for_query_joins() {
        let ddl_references = references_with_setup(
            "SELECT * FROM (T1 JOIN T2 ON T1.V1 = T2.V2) JOIN T3 ON T2.V1 = T3.V2",
        );
        assert_references(
            &ddl_references,
            &["t1", "t2", "t3"],
            &["m1", "m2", "m3", "m4"],
        );
    }

    #[test]
    fn test_shrink_values() {
        assert_shrinks_to("SELECT 1;", &[]);
    }

    #[test]
    fn test_shrink_simple_table() {
        assert_shrinks_to("SELECT * FROM t1;", &[CREATE_T1, INSERT_T1]);
    }

    #[test]
    fn test_shrink_simple_table_with_alias() {
        assert_shrinks_to("SELECT * FROM t1 AS s1;", &[CREATE_T1, INSERT_T1]);
    }

    #[test]
    fn test_shrink_join() {
        assert_shrinks_to(
            "SELECT * FROM (T1 JOIN T2 ON T1.V1 = T2.V2) JOIN T3 ON T2.V1 = T3.V2;",
            &[
                CREATE_T1, CREATE_T2, CREATE_T3, INSERT_T1, INSERT_T2, INSERT_T3,
            ],
        );
    }

    #[test]
    fn test_shrink_tumble() {
        assert_shrinks_to(
            "SELECT * FROM TUMBLE(T3, V1, INTERVAL '3' DAY);",
            &[CREATE_T3, INSERT_T3],
        );
    }

    #[test]
    fn test_shrink_subquery() {
        assert_shrinks_to(
            "SELECT * FROM (SELECT V1 AS K1 FROM T2);",
            &[CREATE_T2, INSERT_T2],
        );
    }

    #[test]
    fn test_shrink_mview() {
        assert_shrinks_to(
            "CREATE MATERIALIZED VIEW m5 AS SELECT * FROM (SELECT V1 AS K1 FROM T2);",
            &[CREATE_T2, INSERT_T2],
        );
    }

    #[test]
    fn test_find_ddl_references_for_exists_subquery() {
        let ddl_references = references_with_setup(
            "SELECT * FROM T1 WHERE EXISTS (SELECT * FROM T2 WHERE T2.V1 = T1.V1);",
        );
        assert_references(&ddl_references, &["t1", "t2"], &["t3", "m1"]);
    }
}