risingwave_sqlsmith/sql_gen/
dml.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::{Result, bail};
16use itertools::Itertools;
17use rand::Rng;
18use risingwave_common::types::DataType;
19use risingwave_sqlparser::ast::Expr::BinaryOp;
20use risingwave_sqlparser::ast::{
21    Assignment, AssignmentValue, BinaryOperator, Expr, ObjectName, Query, SetExpr, Statement,
22    Values,
23};
24
25use crate::Table;
26use crate::sql_gen::SqlGenerator;
27
28impl<'a, R: Rng + 'a> SqlGenerator<'a, R> {
29    pub(crate) fn generate_insert_statement(
30        &mut self,
31        table: &Table,
32        row_count: usize,
33    ) -> Statement {
34        let table_name = ObjectName(vec![table.name.as_str().into()]);
35        let data_types = table
36            .columns
37            .iter()
38            .cloned()
39            .map(|c| c.data_type)
40            .collect_vec();
41        let values = self.gen_values(&data_types, row_count);
42        let source = Query {
43            with: None,
44            body: SetExpr::Values(Values(values)),
45            order_by: vec![],
46            limit: None,
47            offset: None,
48            fetch: None,
49        };
50        Statement::Insert {
51            table_name,
52            columns: vec![],
53            source: Box::new(source),
54            returning: vec![],
55        }
56    }
57
58    pub(crate) fn generate_update_statements(
59        &mut self,
60        tables: &[Table],
61        inserts: &[Statement],
62    ) -> Result<Vec<Statement>> {
63        let mut updates = vec![];
64        for insert in inserts {
65            if self.rng.random_bool(0.2) {
66                match insert {
67                    Statement::Insert {
68                        table_name, source, ..
69                    } => {
70                        let values = Self::extract_insert_values(source)?;
71                        let table = tables
72                            .iter()
73                            .find(|table| table.name == table_name.real_value())
74                            .expect("Inserted values should always have an existing table");
75                        let pk_indices = &table.pk_indices;
76                        let mut updates_for_insert =
77                            self.generate_update_statements_inner(table, values, pk_indices);
78                        updates.append(&mut updates_for_insert);
79                    }
80                    _ => bail!("Should only have insert statements"),
81                }
82            }
83        }
84        Ok(updates)
85    }
86
87    pub(crate) fn generate_update_statements_inner(
88        &mut self,
89        table: &Table,
90        values: &[Vec<Expr>],
91        pk_indices: &[usize],
92    ) -> Vec<Statement> {
93        let data_types = table
94            .columns
95            .iter()
96            .cloned()
97            .map(|c| c.data_type)
98            .collect_vec();
99        if pk_indices.is_empty() {
100            // do delete for a random subset of rows.
101            let delete_statements = self.generate_delete_statements(table, values);
102            // then insert back some number of rows.
103            let insert_statements = if delete_statements.is_empty() {
104                vec![]
105            } else {
106                let insert_statement =
107                    self.generate_insert_statement(table, delete_statements.len());
108                vec![insert_statement]
109            };
110            delete_statements
111                .into_iter()
112                .chain(insert_statements)
113                .collect()
114        } else {
115            let value_indices = (0..table.columns.len())
116                .filter(|i| !pk_indices.contains(i))
117                .collect_vec();
118            let update_values = values
119                .iter()
120                .filter_map(|row| {
121                    if self.rng.random_bool(0.1) {
122                        let mut updated_row = row.clone();
123                        for value_index in &value_indices {
124                            let data_type = &data_types[*value_index];
125                            updated_row[*value_index] = self.gen_simple_scalar(data_type)
126                        }
127                        Some(updated_row)
128                    } else {
129                        None
130                    }
131                })
132                .collect_vec();
133            update_values
134                .iter()
135                .map(|row| Self::row_to_update_statement(table, pk_indices, &value_indices, row))
136                .collect_vec()
137        }
138    }
139
140    fn row_to_update_statement(
141        table: &Table,
142        pk_indices: &[usize],
143        value_indices: &[usize],
144        row: &[Expr],
145    ) -> Statement {
146        let assignments = value_indices
147            .iter()
148            .copied()
149            .map(|i| {
150                let name = table.columns[i].name.as_str();
151                let id = vec![name.into()];
152                let value = AssignmentValue::Expr(row[i].clone());
153                Assignment { id, value }
154            })
155            .collect_vec();
156        assert!(!assignments.is_empty());
157        Statement::Update {
158            table_name: ObjectName::from_test_str(&table.name),
159            assignments,
160            selection: Some(Self::create_selection_expr(table, pk_indices, row)),
161            returning: vec![],
162        }
163    }
164
165    fn create_selection_expr(table: &Table, selected_indices: &[usize], row: &[Expr]) -> Expr {
166        assert!(!selected_indices.is_empty());
167        let match_exprs = selected_indices
168            .iter()
169            .copied()
170            .map(|i| {
171                let match_val = row[i].clone();
172                let match_col = Expr::Identifier(table.columns[i].name.as_str().into());
173
174                Expr::BinaryOp {
175                    left: Box::new(match_col),
176                    op: BinaryOperator::Eq,
177                    right: Box::new(match_val),
178                }
179            })
180            .collect_vec();
181        match_exprs
182            .into_iter()
183            .reduce(|l, r| BinaryOp {
184                left: Box::new(l),
185                op: BinaryOperator::And,
186                right: Box::new(r),
187            })
188            .expect("pk should be non empty")
189    }
190
191    fn generate_delete_statements(
192        &mut self,
193        table: &Table,
194        values: &[Vec<Expr>],
195    ) -> Vec<Statement> {
196        let selected = (0..table.columns.len()).collect_vec();
197        values
198            .iter()
199            .filter_map(|row| {
200                if self.rng.random_bool(0.1) {
201                    let selection = Some(Self::create_selection_expr(table, &selected, row));
202                    Some(Statement::Delete {
203                        table_name: ObjectName::from_test_str(&table.name),
204                        selection,
205                        returning: vec![],
206                    })
207                } else {
208                    None
209                }
210            })
211            .collect()
212    }
213
214    fn extract_insert_values(source: &Query) -> Result<&[Vec<Expr>]> {
215        let body = &source.body;
216        match body {
217            SetExpr::Values(values) => Ok(&values.0),
218            _ => bail!("Should not have insert values"),
219        }
220    }
221
222    fn gen_values(&mut self, data_types: &[DataType], row_count: usize) -> Vec<Vec<Expr>> {
223        (0..row_count).map(|_| self.gen_row(data_types)).collect()
224    }
225
226    fn gen_row(&mut self, data_types: &[DataType]) -> Vec<Expr> {
227        data_types
228            .iter()
229            .map(|typ| self.gen_simple_scalar(typ))
230            .collect()
231    }
232}