risingwave_sqlsmith/sqlreduce/passes/
pullup.rs1use risingwave_sqlparser::ast::{Array, Expr, Query, SelectItem, SetExpr, Statement};
16
17use crate::sqlreduce::passes::{Ast, Transform, extract_query, extract_query_mut};
18
19pub struct BinaryOperatorPullup;
34
35impl Transform for BinaryOperatorPullup {
36 fn name(&self) -> String {
37 "binary_operator_pullup".to_owned()
38 }
39
40 fn get_reduction_points(&self, ast: Ast) -> Vec<usize> {
41 let mut reduction_points = Vec::new();
42 if let Some(query) = extract_query(&ast)
43 && let SetExpr::Select(select) = &query.body
44 {
45 for (i, item) in select.projection.iter().enumerate() {
46 if let SelectItem::UnnamedExpr(expr) = item
47 && let Expr::BinaryOp { .. } = expr
48 {
49 reduction_points.push(i);
50 }
51 }
52 }
53
54 reduction_points
55 }
56
57 fn apply_on(&self, ast: &mut Ast, reduction_points: Vec<usize>) -> Ast {
58 if let Some(query) = extract_query_mut(ast)
59 && let SetExpr::Select(select) = &mut query.body
60 {
61 for i in reduction_points {
62 if let SelectItem::UnnamedExpr(ref mut expr) = select.projection[i]
63 && let Expr::BinaryOp { right, .. } = expr
64 {
65 *expr = *right.clone();
66 }
67 }
68 }
69
70 ast.clone()
71 }
72}
73
74pub struct CasePullup;
93
94impl Transform for CasePullup {
95 fn name(&self) -> String {
96 "case_pullup".to_owned()
97 }
98
99 fn get_reduction_points(&self, ast: Ast) -> Vec<usize> {
100 let mut reduction_points = Vec::new();
101 if let Some(query) = extract_query(&ast)
102 && let SetExpr::Select(select) = &query.body
103 {
104 for (i, item) in select.projection.iter().enumerate() {
105 if let SelectItem::UnnamedExpr(expr) = item
106 && let Expr::Case { .. } = expr
107 {
108 reduction_points.push(i);
109 }
110 }
111 }
112 reduction_points
113 }
114
115 fn apply_on(&self, ast: &mut Ast, reduction_points: Vec<usize>) -> Ast {
116 if let Some(query) = extract_query_mut(ast)
117 && let SetExpr::Select(select) = &mut query.body
118 {
119 for i in reduction_points {
120 if let SelectItem::UnnamedExpr(ref mut expr) = select.projection[i]
121 && let Expr::Case { results, .. } = expr
122 {
123 *expr = results[0].clone();
124 }
125 }
126 }
127
128 ast.clone()
129 }
130}
131
132pub struct RowPullup;
147
148impl Transform for RowPullup {
149 fn name(&self) -> String {
150 "row_pullup".to_owned()
151 }
152
153 fn get_reduction_points(&self, ast: Ast) -> Vec<usize> {
154 let mut reduction_points = Vec::new();
155 if let Some(query) = extract_query(&ast)
156 && let SetExpr::Select(select) = &query.body
157 {
158 for (i, item) in select.projection.iter().enumerate() {
159 if let SelectItem::UnnamedExpr(expr) = item
160 && let Expr::Row { .. } = expr
161 {
162 reduction_points.push(i);
163 }
164 }
165 }
166 reduction_points
167 }
168
169 fn apply_on(&self, ast: &mut Ast, reduction_points: Vec<usize>) -> Ast {
170 if let Some(query) = extract_query_mut(ast)
171 && let SetExpr::Select(select) = &mut query.body
172 {
173 for i in reduction_points {
174 if let SelectItem::UnnamedExpr(ref mut expr) = select.projection[i]
175 && let Expr::Row(elements) = expr
176 {
177 *expr = Expr::Row(vec![elements[0].clone()]);
178 }
179 }
180 }
181 ast.clone()
182 }
183}
184
185pub struct ArrayPullup;
199
200impl Transform for ArrayPullup {
201 fn name(&self) -> String {
202 "array_pullup".to_owned()
203 }
204
205 fn get_reduction_points(&self, ast: Ast) -> Vec<usize> {
206 let mut reduction_points = Vec::new();
207 if let Some(query) = extract_query(&ast)
208 && let SetExpr::Select(select) = &query.body
209 {
210 for (i, item) in select.projection.iter().enumerate() {
211 if let SelectItem::UnnamedExpr(expr) = item
212 && let Expr::Array { .. } = expr
213 {
214 reduction_points.push(i);
215 }
216 }
217 }
218 reduction_points
219 }
220
221 fn apply_on(&self, ast: &mut Ast, reduction_points: Vec<usize>) -> Ast {
222 if let Some(query) = extract_query_mut(ast)
223 && let SetExpr::Select(select) = &mut query.body
224 {
225 for i in reduction_points {
226 if let SelectItem::UnnamedExpr(ref mut expr) = select.projection[i]
227 && let Expr::Array(array) = expr
228 && let Some(elem) = array.elem.first()
229 {
230 *expr = Expr::Array(Array {
231 elem: vec![elem.clone()],
232 named: array.named,
233 });
234 }
235 }
236 }
237 ast.clone()
238 }
239}
240
241pub struct SetOperationPullup;
259
260impl Transform for SetOperationPullup {
261 fn name(&self) -> String {
262 "set_operation_pullup".to_owned()
263 }
264
265 fn get_reduction_points(&self, ast: Ast) -> Vec<usize> {
266 let mut reduction_points = Vec::new();
267 if let Some(query) = extract_query(&ast)
268 && let SetExpr::SetOperation { .. } = &query.body
269 {
270 reduction_points.push(0); reduction_points.push(1); }
273 reduction_points
274 }
275
276 fn apply_on(&self, ast: &mut Ast, reduction_points: Vec<usize>) -> Ast {
277 let mut new_ast = ast.clone();
278 if let Some(query) = extract_query_mut(&mut new_ast)
279 && let SetExpr::SetOperation { left, right, .. } = &mut query.body
280 {
281 if reduction_points.contains(&0) {
282 new_ast = Statement::Query(Box::new(Query {
284 body: *left.clone(),
285 ..query.clone()
286 }));
287 } else if reduction_points.contains(&1) {
288 new_ast = Statement::Query(Box::new(Query {
290 body: *right.clone(),
291 ..query.clone()
292 }));
293 }
294 }
295 new_ast
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse_sql;

    /// Parses `sql` and asserts that `transform` reports exactly `expected_points`. It then
    /// applies the transform at those points and asserts the result equals the first statement
    /// parsed from `expected_sql`.
    fn check_pullup(
        transform: &impl Transform,
        sql: &str,
        expected_points: &[usize],
        expected_sql: &str,
    ) {
        let statements = parse_sql(sql);
        let points = transform.get_reduction_points(statements[0].clone());
        assert_eq!(points, expected_points);

        let mut target = statements[0].clone();
        let reduced = transform.apply_on(&mut target, points);
        assert_eq!(reduced, parse_sql(expected_sql)[0].clone());
    }

    #[test]
    fn test_binary_operator_pullup_with_single_binary() {
        check_pullup(&BinaryOperatorPullup, "SELECT 1 + 2 + 3;", &[0], "SELECT 3;");
    }

    #[test]
    fn test_binary_operator_pullup_with_multiple_binary() {
        check_pullup(
            &BinaryOperatorPullup,
            "SELECT 1 + 2 + 3, 4 + 5 + 6;",
            &[0, 1],
            "SELECT 3, 6;",
        );
    }

    #[test]
    fn test_case_pullup_with_single_when() {
        check_pullup(
            &CasePullup,
            "SELECT CASE WHEN 1 = 1 THEN 1 ELSE 2 END;",
            &[0],
            "SELECT 1;",
        );
    }

    #[test]
    fn test_row_pullup_with_single_row() {
        check_pullup(&RowPullup, "SELECT ROW(1, 2, 3);", &[0], "SELECT ROW(1);");
    }

    #[test]
    fn test_row_pullup_with_multiple_rows() {
        check_pullup(
            &RowPullup,
            "SELECT ROW(1, 2, 3), ROW(4, 5, 6);",
            &[0, 1],
            "SELECT ROW(1), ROW(4);",
        );
    }

    #[test]
    fn test_array_pullup_with_single_array() {
        check_pullup(&ArrayPullup, "SELECT ARRAY[1, 2, 3];", &[0], "SELECT ARRAY[1];");
    }

    #[test]
    fn test_array_pullup_with_multiple_arrays() {
        check_pullup(
            &ArrayPullup,
            "SELECT ARRAY[1, 2, 3], ARRAY[4, 5, 6];",
            &[0, 1],
            "SELECT ARRAY[1], ARRAY[4];",
        );
    }

    #[test]
    fn test_case_pullup_with_multiple_when() {
        check_pullup(
            &CasePullup,
            "SELECT CASE WHEN 1 = 1 THEN 1 WHEN 2 = 2 THEN 2 ELSE 3 END;",
            &[0],
            "SELECT 1;",
        );
    }

    #[test]
    fn test_set_operation_pullup_union() {
        check_pullup(&SetOperationPullup, "SELECT 1 UNION SELECT 2;", &[0, 1], "SELECT 1;");
    }

    #[test]
    fn test_set_operation_pullup_intersect() {
        check_pullup(
            &SetOperationPullup,
            "SELECT 1 INTERSECT SELECT 2;",
            &[0, 1],
            "SELECT 1;",
        );
    }

    #[test]
    fn test_set_operation_pullup_except() {
        check_pullup(&SetOperationPullup, "SELECT 1 EXCEPT SELECT 2;", &[0, 1], "SELECT 1;");
    }
}