// risingwave_sqlsmith/sql_gen/dml.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use anyhow::{Result, bail};
16use itertools::Itertools;
17use rand::Rng;
18use risingwave_common::types::DataType;
19use risingwave_sqlparser::ast::Expr::BinaryOp;
20use risingwave_sqlparser::ast::{
21    Assignment, AssignmentValue, BinaryOperator, Expr, ObjectName, Query, SetExpr, Statement,
22    Values,
23};
24
25use crate::Table;
26use crate::sql_gen::SqlGenerator;
27
28impl<'a, R: Rng + 'a> SqlGenerator<'a, R> {
29    pub(crate) fn generate_insert_statement(
30        &mut self,
31        table: &Table,
32        row_count: usize,
33    ) -> Statement {
34        let table_name = ObjectName(vec![table.name.as_str().into()]);
35        let data_types = table
36            .columns
37            .iter()
38            .cloned()
39            .map(|c| c.data_type)
40            .collect_vec();
41        let values = self.gen_values(&data_types, row_count);
42        let source = Query {
43            with: None,
44            body: SetExpr::Values(Values(values)),
45            order_by: vec![],
46            limit: None,
47            offset: None,
48            fetch: None,
49        };
50        Statement::Insert {
51            table_name,
52            columns: vec![],
53            source: Box::new(source),
54            returning: vec![],
55        }
56    }
57
58    pub(crate) fn generate_update_statements(
59        &mut self,
60        tables: &[Table],
61        inserts: &[Statement],
62    ) -> Result<Vec<Statement>> {
63        let mut updates = vec![];
64        for insert in inserts {
65            if self.rng.random_bool(0.2) {
66                match insert {
67                    Statement::Insert {
68                        table_name, source, ..
69                    } => {
70                        let values = Self::extract_insert_values(source)?;
71                        let table = tables
72                            .iter()
73                            .find(|table| table.name == table_name.real_value())
74                            .expect("Inserted values should always have an existing table");
75                        let pk_indices = &table.pk_indices;
76                        let mut updates_for_insert =
77                            self.generate_update_statements_inner(table, values, pk_indices);
78                        updates.append(&mut updates_for_insert);
79                    }
80                    _ => bail!("Should only have insert statements"),
81                }
82            }
83        }
84        Ok(updates)
85    }
86
87    pub(crate) fn generate_update_statements_inner(
88        &mut self,
89        table: &Table,
90        values: &[Vec<Expr>],
91        pk_indices: &[usize],
92    ) -> Vec<Statement> {
93        let data_types = table
94            .columns
95            .iter()
96            .cloned()
97            .map(|c| c.data_type)
98            .collect_vec();
99        if pk_indices.is_empty() {
100            // do delete for a random subset of rows.
101            let delete_statements = self.generate_delete_statements(table, values);
102            // then insert back some number of rows.
103            let insert_statements = if delete_statements.is_empty() {
104                vec![]
105            } else {
106                let insert_statement =
107                    self.generate_insert_statement(table, delete_statements.len());
108                vec![insert_statement]
109            };
110            delete_statements
111                .into_iter()
112                .chain(insert_statements)
113                .collect()
114        } else {
115            let value_indices = (0..table.columns.len())
116                .filter(|i| !pk_indices.contains(i))
117                .collect_vec();
118            let update_values = values
119                .iter()
120                .filter_map(|row| {
121                    if self.rng.random_bool(0.1) {
122                        let mut updated_row = row.clone();
123                        for value_index in &value_indices {
124                            let data_type = &data_types[*value_index];
125                            updated_row[*value_index] = self.gen_simple_scalar(data_type)
126                        }
127                        Some(updated_row)
128                    } else {
129                        None
130                    }
131                })
132                .collect_vec();
133            update_values
134                .iter()
135                .map(|row| Self::row_to_update_statement(table, pk_indices, &value_indices, row))
136                .collect_vec()
137        }
138    }
139
140    fn row_to_update_statement(
141        table: &Table,
142        pk_indices: &[usize],
143        value_indices: &[usize],
144        row: &[Expr],
145    ) -> Statement {
146        let assignments = value_indices
147            .iter()
148            .copied()
149            .map(|i| {
150                let id = vec![table.columns[i].base_name()];
151                let value = AssignmentValue::Expr(row[i].clone());
152                Assignment { id, value }
153            })
154            .collect_vec();
155        assert!(!assignments.is_empty());
156        Statement::Update {
157            table_name: ObjectName::from_test_str(&table.name),
158            assignments,
159            selection: Some(Self::create_selection_expr(table, pk_indices, row)),
160            returning: vec![],
161        }
162    }
163
164    fn create_selection_expr(table: &Table, selected_indices: &[usize], row: &[Expr]) -> Expr {
165        assert!(!selected_indices.is_empty());
166        let match_exprs = selected_indices
167            .iter()
168            .copied()
169            .map(|i| {
170                let match_val = row[i].clone();
171                let match_col = table.columns[i].name_expr();
172
173                Expr::BinaryOp {
174                    left: Box::new(match_col),
175                    op: BinaryOperator::Eq,
176                    right: Box::new(match_val),
177                }
178            })
179            .collect_vec();
180        match_exprs
181            .into_iter()
182            .reduce(|l, r| BinaryOp {
183                left: Box::new(l),
184                op: BinaryOperator::And,
185                right: Box::new(r),
186            })
187            .expect("pk should be non empty")
188    }
189
190    fn generate_delete_statements(
191        &mut self,
192        table: &Table,
193        values: &[Vec<Expr>],
194    ) -> Vec<Statement> {
195        let selected = (0..table.columns.len()).collect_vec();
196        values
197            .iter()
198            .filter_map(|row| {
199                if self.rng.random_bool(0.1) {
200                    let selection = Some(Self::create_selection_expr(table, &selected, row));
201                    Some(Statement::Delete {
202                        table_name: ObjectName::from_test_str(&table.name),
203                        selection,
204                        returning: vec![],
205                    })
206                } else {
207                    None
208                }
209            })
210            .collect()
211    }
212
213    fn extract_insert_values(source: &Query) -> Result<&[Vec<Expr>]> {
214        let body = &source.body;
215        match body {
216            SetExpr::Values(values) => Ok(&values.0),
217            _ => bail!("Should not have insert values"),
218        }
219    }
220
221    fn gen_values(&mut self, data_types: &[DataType], row_count: usize) -> Vec<Vec<Expr>> {
222        (0..row_count).map(|_| self.gen_row(data_types)).collect()
223    }
224
225    fn gen_row(&mut self, data_types: &[DataType]) -> Vec<Expr> {
226        data_types
227            .iter()
228            .map(|typ| self.gen_simple_scalar(typ))
229            .collect()
230    }
231}