risingwave_common/field_generator/
numeric.rs1use std::fmt::Debug;
16use std::str::FromStr;
17
18use anyhow::Result;
19use rand::distr::uniform::SampleUniform;
20use rand::rngs::StdRng;
21use rand::{Rng, SeedableRng};
22use serde_json::json;
23
24use crate::field_generator::{NumericFieldRandomGenerator, NumericFieldSequenceGenerator};
25use crate::types::{Datum, F32, F64, Scalar};
26
27trait NumericType
28where
29 Self: FromStr
30 + Copy
31 + Debug
32 + Default
33 + PartialOrd
34 + num_traits::Num
35 + num_traits::NumAssignOps
36 + From<i16>
37 + TryFrom<u64>
38 + serde::Serialize
39 + SampleUniform,
40{
41}
42
/// Implements the `NumericType` marker trait for every type in the list.
///
/// Takes the `{ RandomAlias, SequenceAlias, type }` entries emitted by
/// `for_all_fields_variants!`; only the third element (the type) is used here.
macro_rules! impl_numeric_type {
    ($({ $random_alias:ident, $sequence_alias:ident, $ty:ty }),*) => {
        $(impl NumericType for $ty {})*
    };
}
50
/// State of a random numeric field generator.
///
/// Values are drawn from the inclusive range `min..=max`; `seed` is XOR-ed with the
/// row offset to seed the RNG for each generated value.
pub struct NumericFieldRandomConcrete<T> {
    /// Seed combined (XOR) with the row offset when seeding the RNG.
    seed: u64,
    /// Inclusive lower bound of generated values.
    min: T,
    /// Inclusive upper bound of generated values.
    max: T,
}
56
/// State of a sequence numeric field generator.
///
/// Each value is computed as `start + offset + step * cur`; values greater than
/// `end` are not emitted.
#[derive(Default)]
pub struct NumericFieldSequenceConcrete<T> {
    /// Base value of the sequence (before `offset` is added).
    start: T,
    /// Inclusive upper bound of emitted values.
    end: T,
    /// Constant added to every value.
    offset: u64,
    /// Distance between consecutive values.
    step: u64,
    /// Current index into the sequence; initialised from the event offset.
    cur: T,
}
65
66impl<T> NumericFieldRandomGenerator for NumericFieldRandomConcrete<T>
67where
68 T: NumericType + Scalar,
69 <T as FromStr>::Err: std::error::Error + Send + Sync + 'static,
70{
71 fn new(min_option: Option<String>, max_option: Option<String>, seed: u64) -> Result<Self>
72 where
73 Self: Sized,
74 {
75 let mut min = T::zero();
76 let mut max = T::from(i16::MAX);
77
78 if let Some(min_option) = min_option {
79 min = min_option.parse::<T>()?;
80 }
81 if let Some(max_option) = max_option {
82 max = max_option.parse::<T>()?;
83 }
84 assert!(min <= max);
85
86 Ok(Self { min, max, seed })
87 }
88
89 fn generate(&mut self, offset: u64) -> serde_json::Value {
90 let mut rng = StdRng::seed_from_u64(offset ^ self.seed);
91 let result = rng.random_range(self.min..=self.max);
92 json!(result)
93 }
94
95 fn generate_datum(&mut self, offset: u64) -> Datum {
96 let mut rng = StdRng::seed_from_u64(offset ^ self.seed);
97 let result = rng.random_range(self.min..=self.max);
98 Some(result.to_scalar_value())
99 }
100}
101impl<T> NumericFieldSequenceGenerator for NumericFieldSequenceConcrete<T>
102where
103 T: NumericType + Scalar,
104 <T as FromStr>::Err: std::error::Error + Send + Sync + 'static,
105 <T as TryFrom<u64>>::Error: Debug,
106{
107 fn new(
108 star_option: Option<String>,
109 end_option: Option<String>,
110 offset: u64,
111 step: u64,
112 event_offset: u64,
113 ) -> Result<Self>
114 where
115 Self: Sized,
116 {
117 let mut start = T::zero();
118 let mut end = T::from(i16::MAX);
119
120 if let Some(star_optiont) = star_option {
121 start = star_optiont.parse::<T>()?;
122 }
123 if let Some(end_option) = end_option {
124 end = end_option.parse::<T>()?;
125 }
126
127 assert!(start <= end);
128 Ok(Self {
129 start,
130 end,
131 offset,
132 step,
133 cur: T::try_from(event_offset).map_err(|_| {
134 anyhow::anyhow!("event offset is too big, offset: {}", event_offset,)
135 })?,
136 })
137 }
138
139 fn generate(&mut self) -> serde_json::Value {
140 let partition_result = self.start
141 + T::try_from(self.offset).unwrap()
142 + T::try_from(self.step).unwrap() * self.cur;
143 let partition_result = if partition_result > self.end {
144 None
145 } else {
146 Some(partition_result)
147 };
148 self.cur += T::one();
149 json!(partition_result)
150 }
151
152 fn generate_datum(&mut self) -> Datum {
153 let partition_result = self.start
154 + T::try_from(self.offset).unwrap()
155 + T::try_from(self.step).unwrap() * self.cur;
156 self.cur += T::one();
157 if partition_result > self.end {
158 None
159 } else {
160 Some(partition_result.to_scalar_value())
161 }
162 }
163}
164
/// Invokes the given macro with the full list of supported numeric field variants.
///
/// Every entry has the shape `{ RandomFieldAlias, SequenceFieldAlias, type }`.
#[macro_export]
macro_rules! for_all_fields_variants {
    ($callback:ident) => {
        $callback! {
            { I16RandomField, I16SequenceField, i16 },
            { I32RandomField, I32SequenceField, i32 },
            { I64RandomField, I64SequenceField, i64 },
            { F32RandomField, F32SequenceField, F32 },
            { F64RandomField, F64SequenceField, F64 }
        }
    };
}
177
/// Declares, for each entry, a public alias named by the entry's first identifier
/// for `NumericFieldRandomConcrete` over the entry's type.
macro_rules! gen_random_field_alias {
    ($({ $random_alias:ident, $sequence_alias:ident, $ty:ty }),*) => {
        $(pub type $random_alias = NumericFieldRandomConcrete<$ty>;)*
    };
}
185
/// Declares, for each entry, a public alias named by the entry's second identifier
/// for `NumericFieldSequenceConcrete` over the entry's type.
macro_rules! gen_sequence_field_alias {
    ($({ $random_alias:ident, $sequence_alias:ident, $ty:ty }),*) => {
        $(pub type $sequence_alias = NumericFieldSequenceConcrete<$ty>;)*
    };
}
193
194for_all_fields_variants! { impl_numeric_type }
195for_all_fields_variants! { gen_random_field_alias }
196for_all_fields_variants! { gen_sequence_field_alias }
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DefaultOrd;

    #[test]
    fn test_sequence_field_generator() {
        let mut i16_field =
            I16SequenceField::new(Some("5".to_owned()), Some("10".to_owned()), 0, 1, 0).unwrap();
        for i in 5..=10 {
            assert_eq!(i16_field.generate(), json!(i));
        }
        // Once the sequence is past `end`, JSON output is `null`.
        assert_eq!(i16_field.generate(), json!(null));
    }

    #[test]
    fn test_random_field_generator() {
        let mut i64_field =
            I64RandomField::new(Some("5".to_owned()), Some("10".to_owned()), 114).unwrap();
        for i in 0..100 {
            let res = i64_field.generate(i as u64);
            assert!(res.is_number());
            let res = res.as_i64().unwrap();
            assert!((5..=10).contains(&res));
        }

        let mut i64_field = I64RandomField::new(None, None, 114).unwrap();
        for i in 0..100 {
            let res = i64_field.generate(i as u64);
            assert!(res.is_number());
            let res = res.as_i64().unwrap();
            assert!(res >= 0);
        }
    }

    #[test]
    fn test_sequence_datum_generator() {
        let mut f32_field =
            F32SequenceField::new(Some("5.0".to_owned()), Some("10.0".to_owned()), 0, 1, 0)
                .unwrap();

        for i in 5..=10 {
            assert_eq!(
                f32_field.generate_datum(),
                Some(F32::from(i as f32).to_scalar_value())
            );
        }
        // Once the sequence is past `end`, no datum is produced.
        assert_eq!(f32_field.generate_datum(), None);
    }

    #[test]
    fn test_random_datum_generator() {
        let mut i32_field =
            I32RandomField::new(Some("-5".to_owned()), Some("5".to_owned()), 123).unwrap();
        let (lower, upper) = ((-5).to_scalar_value(), 5.to_scalar_value());
        for i in 0..100 {
            let res = i32_field.generate_datum(i as u64);
            assert!(res.is_some());
            let res = res.unwrap();
            assert!(lower.default_cmp(&res).is_le() && res.default_cmp(&upper).is_le());
        }
    }

    #[test]
    fn test_sequence_field_generator_float() {
        let mut f64_field =
            F64SequenceField::new(Some("0".to_owned()), Some("10".to_owned()), 0, 1, 0).unwrap();
        for i in 0..=10 {
            assert_eq!(f64_field.generate(), json!(i as f64));
        }

        let mut f32_field =
            F32SequenceField::new(Some("-5".to_owned()), Some("5".to_owned()), 0, 1, 0).unwrap();
        for i in -5..=5 {
            assert_eq!(f32_field.generate(), json!(i as f32));
        }
    }

    #[test]
    fn test_random_field_generator_float() {
        // The generator samples `min..=max`, so both bounds are valid outputs.
        let mut f64_field =
            F64RandomField::new(Some("5".to_owned()), Some("10".to_owned()), 114).unwrap();
        for i in 0..100 {
            let res = f64_field.generate(i as u64);
            assert!(res.is_number());
            let res = res.as_f64().unwrap();
            assert!((5. ..=10.).contains(&res));
        }

        let mut f64_field = F64RandomField::new(None, None, 114).unwrap();
        for i in 0..100 {
            let res = f64_field.generate(i as u64);
            assert!(res.is_number());
            let res = res.as_f64().unwrap();
            assert!(res >= 0.);
        }

        let mut f32_field =
            F32RandomField::new(Some("5".to_owned()), Some("10".to_owned()), 114).unwrap();
        for i in 0..100 {
            let res = f32_field.generate(i as u64);
            assert!(res.is_number());
            let res = res.as_f64().unwrap();
            assert!((5. ..=10.).contains(&res));
        }

        let mut f32_field = F32RandomField::new(None, None, 114).unwrap();
        for i in 0..100 {
            let res = f32_field.generate(i as u64);
            assert!(res.is_number());
            let res = res.as_f64().unwrap();
            assert!(res >= 0.);
        }
    }
}