risingwave_sqlsmith/
config.rs1use std::collections::BTreeMap;
16use std::fmt;
17
18use rand::Rng;
19use serde::Deserialize;
20
21#[derive(Debug, Clone, Copy, Ord, PartialOrd, PartialEq, Eq, Hash, Deserialize)]
22#[serde(rename_all = "snake_case")]
23pub enum Syntax {
24 Where,
25 Agg,
26 Join,
27}
28
29impl fmt::Display for Syntax {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 let s = match self {
32 Syntax::Where => "where",
33 Syntax::Agg => "agg",
34 Syntax::Join => "join",
35 };
36 write!(f, "{}", s)
37 }
38}
39
40#[derive(Debug, Clone, Copy, Ord, PartialOrd, PartialEq, Eq, Hash, Deserialize)]
41#[serde(rename_all = "snake_case")]
42pub enum Feature {
43 Eowc,
44 NaturalJoin,
45 UsingJoin,
46 Except,
47}
48
49impl fmt::Display for Feature {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 let s = match self {
52 Feature::Eowc => "eowc",
53 Feature::NaturalJoin => "natural join",
54 Feature::UsingJoin => "using join",
55 Feature::Except => "except",
56 };
57 write!(f, "{}", s)
58 }
59}
60
61impl From<Syntax> for GenerateItem {
62 fn from(s: Syntax) -> Self {
63 GenerateItem::Syntax(s)
64 }
65}
66
67impl From<Feature> for GenerateItem {
68 fn from(f: Feature) -> Self {
69 GenerateItem::Feature(f)
70 }
71}
72
73#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
75pub enum GenerateItem {
76 Syntax(Syntax),
77 Feature(Feature),
78}
79
80#[derive(Clone, Debug, Deserialize)]
81pub struct Configuration {
82 pub weight: BTreeMap<Syntax, u8>,
83
84 #[serde(default)]
85 pub feature: BTreeMap<Feature, bool>,
86}
87
88impl Default for Configuration {
89 fn default() -> Self {
90 Self::new("config.yml")
91 }
92}
93
94impl Configuration {
95 pub fn new(path: &str) -> Configuration {
96 let data = std::fs::read_to_string(path).unwrap();
97 let config: Configuration = serde_yaml::from_str(&data).unwrap();
98
99 for (syntax, weight) in &config.weight {
100 if *weight > 100 {
101 panic!(
102 "Invalid weight {} for syntax '{}': must be in [0, 100]",
103 weight, syntax
104 );
105 }
106 }
107
108 config
109 }
110
111 pub fn should_generate<R, T>(&self, item: T, rng: &mut R) -> bool
113 where
114 R: Rng,
115 T: Into<GenerateItem>,
116 {
117 match item.into() {
118 GenerateItem::Syntax(syntax) => {
119 let weight = self.weight.get(&syntax).cloned().unwrap_or(50);
120 rng.random_range(0..100) < weight
121 }
122 GenerateItem::Feature(feature) => *self.feature.get(&feature).unwrap_or(&false),
123 }
124 }
125
126 pub fn set_weight(&mut self, syntax: Syntax, weight: u8) {
128 if weight > 100 {
129 panic!("Invalid weight {}: must be in [0, 100]", weight);
130 }
131
132 self.weight.insert(syntax, weight);
133 }
134
135 pub fn set_enabled(&mut self, feature: Feature, enabled: bool) {
137 self.feature.insert(feature, enabled);
138 }
139
140 pub fn enable_features_from_args(&mut self, features: &[String]) {
142 for feat in features {
143 let parsed = match feat.as_str() {
144 "eowc" => Feature::Eowc,
145 "natural_join" => Feature::NaturalJoin,
146 "using_join" => Feature::UsingJoin,
147 "except" => Feature::Except,
148 _ => panic!("Unknown feature: {}", feat),
149 };
150 self.set_enabled(parsed, true);
151 }
152 }
153}