risingwave_sqlsmith/
config.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::fmt;
17
18use rand::Rng;
19use serde::Deserialize;
20
21#[derive(Debug, Clone, Copy, Ord, PartialOrd, PartialEq, Eq, Hash, Deserialize)]
22#[serde(rename_all = "snake_case")]
23pub enum Syntax {
24    Where,
25    Agg,
26    Join,
27}
28
29impl fmt::Display for Syntax {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        let s = match self {
32            Syntax::Where => "where",
33            Syntax::Agg => "agg",
34            Syntax::Join => "join",
35        };
36        write!(f, "{}", s)
37    }
38}
39
40#[derive(Debug, Clone, Copy, Ord, PartialOrd, PartialEq, Eq, Hash, Deserialize)]
41#[serde(rename_all = "snake_case")]
42pub enum Feature {
43    Eowc,
44    NaturalJoin,
45    UsingJoin,
46    Except,
47}
48
49impl fmt::Display for Feature {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        let s = match self {
52            Feature::Eowc => "eowc",
53            Feature::NaturalJoin => "natural join",
54            Feature::UsingJoin => "using join",
55            Feature::Except => "except",
56        };
57        write!(f, "{}", s)
58    }
59}
60
61impl From<Syntax> for GenerateItem {
62    fn from(s: Syntax) -> Self {
63        GenerateItem::Syntax(s)
64    }
65}
66
67impl From<Feature> for GenerateItem {
68    fn from(f: Feature) -> Self {
69        GenerateItem::Feature(f)
70    }
71}
72
73/// Unified abstraction for syntax and feature
74#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
75pub enum GenerateItem {
76    Syntax(Syntax),
77    Feature(Feature),
78}
79
80#[derive(Clone, Debug, Deserialize)]
81pub struct Configuration {
82    pub weight: BTreeMap<Syntax, u8>,
83
84    #[serde(default)]
85    pub feature: BTreeMap<Feature, bool>,
86}
87
88impl Default for Configuration {
89    fn default() -> Self {
90        Self::new("config.yml")
91    }
92}
93
94impl Configuration {
95    pub fn new(path: &str) -> Configuration {
96        let data = std::fs::read_to_string(path).unwrap();
97        let config: Configuration = serde_yaml::from_str(&data).unwrap();
98
99        for (syntax, weight) in &config.weight {
100            if *weight > 100 {
101                panic!(
102                    "Invalid weight {} for syntax '{}': must be in [0, 100]",
103                    weight, syntax
104                );
105            }
106        }
107
108        config
109    }
110
111    /// Decide whether to generate a syntax or enable a feature.
112    pub fn should_generate<R, T>(&self, item: T, rng: &mut R) -> bool
113    where
114        R: Rng,
115        T: Into<GenerateItem>,
116    {
117        match item.into() {
118            GenerateItem::Syntax(syntax) => {
119                let weight = self.weight.get(&syntax).cloned().unwrap_or(50);
120                rng.random_range(0..100) < weight
121            }
122            GenerateItem::Feature(feature) => *self.feature.get(&feature).unwrap_or(&false),
123        }
124    }
125
126    /// Dynamically update syntax weight.
127    pub fn set_weight(&mut self, syntax: Syntax, weight: u8) {
128        if weight > 100 {
129            panic!("Invalid weight {}: must be in [0, 100]", weight);
130        }
131
132        self.weight.insert(syntax, weight);
133    }
134
135    /// Dynamically enable/disable a feature.
136    pub fn set_enabled(&mut self, feature: Feature, enabled: bool) {
137        self.feature.insert(feature, enabled);
138    }
139
140    /// Enable features from command-line `--enable` arguments
141    pub fn enable_features_from_args(&mut self, features: &[String]) {
142        for feat in features {
143            let parsed = match feat.as_str() {
144                "eowc" => Feature::Eowc,
145                "natural_join" => Feature::NaturalJoin,
146                "using_join" => Feature::UsingJoin,
147                "except" => Feature::Except,
148                _ => panic!("Unknown feature: {}", feat),
149            };
150            self.set_enabled(parsed, true);
151        }
152    }
153}