risingwave_sqlsmith/sql_gen/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Provides data structures for query generation
//! and the interface for generating stream (`CREATE MATERIALIZED VIEW`)
//! and batch query statements.
18
19use std::vec;
20
21use rand::Rng;
22use risingwave_common::types::DataType;
23use risingwave_frontend::bind_data_type;
24use risingwave_sqlparser::ast::{ColumnDef, Expr, Ident, ObjectName, Statement};
25
26mod agg;
27mod cast;
28mod expr;
29pub use expr::print_function_table;
30
31mod dml;
32mod functions;
33mod query;
34mod relation;
35mod scalar;
36mod time_window;
37mod types;
38mod utils;
39
40#[derive(Clone, Debug)]
41pub struct Table {
42    pub name: String,
43    pub columns: Vec<Column>,
44    pub pk_indices: Vec<usize>,
45    pub is_base_table: bool,
46}
47
48impl Table {
49    pub fn new(name: String, columns: Vec<Column>) -> Self {
50        Self {
51            name,
52            columns,
53            pk_indices: vec![],
54            is_base_table: false,
55        }
56    }
57
58    pub fn new_for_base_table(name: String, columns: Vec<Column>, pk_indices: Vec<usize>) -> Self {
59        Self {
60            name,
61            columns,
62            pk_indices,
63            is_base_table: true,
64        }
65    }
66
67    pub fn get_qualified_columns(&self) -> Vec<Column> {
68        self.columns
69            .iter()
70            .map(|c| Column {
71                name: format!("{}.{}", self.name, c.name),
72                data_type: c.data_type.clone(),
73            })
74            .collect()
75    }
76}
77
78/// Sqlsmith Column definition
79#[derive(Clone, Debug)]
80pub struct Column {
81    pub(crate) name: String,
82    pub(crate) data_type: DataType,
83}
84
85impl From<ColumnDef> for Column {
86    fn from(c: ColumnDef) -> Self {
87        Self {
88            name: c.name.real_value(),
89            data_type: bind_data_type(&c.data_type.expect("data type should not be none")).unwrap(),
90        }
91    }
92}
93
/// Flags controlling whether aggregate expressions may be generated at the
/// current point of generation.
#[derive(Copy, Clone)]
pub(crate) struct SqlGeneratorContext {
    /// Whether aggregate expressions may be generated at all.
    /// Set to `false` at the top level when we want to test queries
    /// without aggregates.
    can_agg: bool,
    /// Whether generation is currently inside an aggregate call; while set,
    /// no further aggregates are generated (see [`Self::can_gen_agg`]).
    inside_agg: bool,
}

impl SqlGeneratorContext {
    /// Default context: aggregates allowed, not inside an aggregate.
    pub fn new() -> Self {
        Self::new_with_can_agg(true)
    }

    /// Context with aggregates enabled or disabled by `can_agg`,
    /// not inside an aggregate.
    pub fn new_with_can_agg(can_agg: bool) -> Self {
        Self {
            can_agg,
            inside_agg: false,
        }
    }

    /// Returns a copy of this context marked as being inside an aggregate call.
    pub fn set_inside_agg(self) -> Self {
        Self {
            inside_agg: true,
            ..self
        }
    }

    /// Whether an aggregate expression may be generated here: aggregates must
    /// be enabled and we must not already be inside one.
    pub fn can_gen_agg(self) -> bool {
        self.can_agg && !self.inside_agg
    }

    /// Whether generation is currently inside an aggregate call.
    pub fn is_inside_agg(self) -> bool {
        self.inside_agg
    }
}
132
133pub(crate) struct SqlGenerator<'a, R: Rng> {
134    tables: Vec<Table>,
135    rng: &'a mut R,
136
137    /// Relation ID used to generate table names and aliases
138    relation_id: u32,
139
140    /// Relations bound in generated query.
141    /// We might not read from all tables.
142    bound_relations: Vec<Table>,
143
144    /// Columns bound in generated query.
145    /// May not contain all columns from `Self::bound_relations`.
146    /// e.g. GROUP BY clause will constrain `bound_columns`.
147    bound_columns: Vec<Column>,
148
149    /// `SqlGenerator` can be used in two execution modes:
150    /// 1. Generating Query Statements.
151    /// 2. Generating queries for CREATE MATERIALIZED VIEW.
152    ///    Under this mode certain restrictions and workarounds are applied
153    ///    for unsupported stream executors.
154    is_mview: bool,
155
156    recursion_weight: f64,
157    // /// Count number of subquery.
158    // /// We don't want too many per query otherwise it is hard to debug.
159    // with_statements: u64,
160}
161
162/// Generators
163impl<'a, R: Rng> SqlGenerator<'a, R> {
164    pub(crate) fn new(rng: &'a mut R, tables: Vec<Table>) -> Self {
165        SqlGenerator {
166            tables,
167            rng,
168            relation_id: 0,
169            bound_relations: vec![],
170            bound_columns: vec![],
171            is_mview: false,
172            recursion_weight: 0.3,
173        }
174    }
175
176    pub(crate) fn new_for_mview(rng: &'a mut R, tables: Vec<Table>) -> Self {
177        // distinct aggregate is not allowed for MV
178        SqlGenerator {
179            tables,
180            rng,
181            relation_id: 0,
182            bound_relations: vec![],
183            bound_columns: vec![],
184            is_mview: true,
185            recursion_weight: 0.3,
186        }
187    }
188
189    pub(crate) fn gen_batch_query_stmt(&mut self) -> Statement {
190        let (query, _) = self.gen_query();
191        Statement::Query(Box::new(query))
192    }
193
194    pub(crate) fn gen_mview_stmt(&mut self, name: &str) -> (Statement, Table) {
195        let (query, schema) = self.gen_query();
196        let query = Box::new(query);
197        let table = Table::new(name.to_owned(), schema);
198        let name = ObjectName(vec![Ident::new_unchecked(name)]);
199        let mview = Statement::CreateView {
200            or_replace: false,
201            materialized: true,
202            if_not_exists: false,
203            name,
204            columns: vec![],
205            query,
206            with_options: vec![],
207            emit_mode: None,
208        };
209        (mview, table)
210    }
211
212    /// 50/50 chance to be true/false.
213    fn flip_coin(&mut self) -> bool {
214        self.rng.random_bool(0.5)
215    }
216
217    /// Provide recursion bounds.
218    pub(crate) fn can_recurse(&mut self) -> bool {
219        if self.recursion_weight <= 0.0 {
220            return false;
221        }
222        let can_recurse = self.rng.random_bool(self.recursion_weight);
223        if can_recurse {
224            self.recursion_weight *= 0.9;
225            if self.recursion_weight < 0.05 {
226                self.recursion_weight = 0.0;
227            }
228        }
229        can_recurse
230    }
231}