risingwave_common/field_generator/
numeric.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::str::FromStr;
17
18use anyhow::Result;
19use rand::distr::uniform::SampleUniform;
20use rand::rngs::StdRng;
21use rand::{Rng, SeedableRng};
22use serde_json::json;
23
24use crate::field_generator::{NumericFieldRandomGenerator, NumericFieldSequenceGenerator};
25use crate::types::{Datum, F32, F64, Scalar};
26
27trait NumericType
28where
29    Self: FromStr
30        + Copy
31        + Debug
32        + Default
33        + PartialOrd
34        + num_traits::Num
35        + num_traits::NumAssignOps
36        + From<i16>
37        + TryFrom<u64>
38        + serde::Serialize
39        + SampleUniform,
40{
41}
42
/// Implements the `NumericType` marker trait for every field type in a
/// `for_all_fields_variants!` list.
///
/// The random/sequence alias names in each entry are accepted only so the
/// shared variant list can be reused; they are ignored here.
macro_rules! impl_numeric_type {
    ($({ $_random_alias:ident, $_sequence_alias:ident, $numeric_ty:ty }),*) => {
        $(impl NumericType for $numeric_ty {})*
    };
}
50
/// Random numeric field generator over the inclusive range `[min, max]`.
///
/// Stateless apart from its configuration: every generated value comes from an
/// RNG seeded per call, so a given offset always maps to the same value.
pub struct NumericFieldRandomConcrete<T> {
    /// Inclusive lower bound of generated values.
    min: T,
    /// Inclusive upper bound of generated values.
    max: T,
    /// Seed combined with the per-call offset to initialise the RNG.
    seed: u64,
}
56
/// Sequential numeric field generator.
///
/// Produces `start + offset + step * cur` for successive values of `cur`, and
/// stops producing values once the result exceeds `end`.
#[derive(Default)]
pub struct NumericFieldSequenceConcrete<T> {
    /// Base value of the sequence.
    start: T,
    /// Inclusive upper bound; larger values are not emitted.
    end: T,
    /// Index of the next element to generate.
    cur: T,
    /// Constant added to `start` for every element.
    offset: u64,
    /// Multiplier applied to `cur` between consecutive elements.
    step: u64,
}
65
66impl<T> NumericFieldRandomGenerator for NumericFieldRandomConcrete<T>
67where
68    T: NumericType + Scalar,
69    <T as FromStr>::Err: std::error::Error + Send + Sync + 'static,
70{
71    fn new(min_option: Option<String>, max_option: Option<String>, seed: u64) -> Result<Self>
72    where
73        Self: Sized,
74    {
75        let mut min = T::zero();
76        let mut max = T::from(i16::MAX);
77
78        if let Some(min_option) = min_option {
79            min = min_option.parse::<T>()?;
80        }
81        if let Some(max_option) = max_option {
82            max = max_option.parse::<T>()?;
83        }
84        assert!(min <= max);
85
86        Ok(Self { min, max, seed })
87    }
88
89    fn generate(&mut self, offset: u64) -> serde_json::Value {
90        let mut rng = StdRng::seed_from_u64(offset ^ self.seed);
91        let result = rng.random_range(self.min..=self.max);
92        json!(result)
93    }
94
95    fn generate_datum(&mut self, offset: u64) -> Datum {
96        let mut rng = StdRng::seed_from_u64(offset ^ self.seed);
97        let result = rng.random_range(self.min..=self.max);
98        Some(result.to_scalar_value())
99    }
100}
101impl<T> NumericFieldSequenceGenerator for NumericFieldSequenceConcrete<T>
102where
103    T: NumericType + Scalar,
104    <T as FromStr>::Err: std::error::Error + Send + Sync + 'static,
105    <T as TryFrom<u64>>::Error: Debug,
106{
107    fn new(
108        star_option: Option<String>,
109        end_option: Option<String>,
110        offset: u64,
111        step: u64,
112        event_offset: u64,
113    ) -> Result<Self>
114    where
115        Self: Sized,
116    {
117        let mut start = T::zero();
118        let mut end = T::from(i16::MAX);
119
120        if let Some(star_optiont) = star_option {
121            start = star_optiont.parse::<T>()?;
122        }
123        if let Some(end_option) = end_option {
124            end = end_option.parse::<T>()?;
125        }
126
127        assert!(start <= end);
128        Ok(Self {
129            start,
130            end,
131            offset,
132            step,
133            cur: T::try_from(event_offset).map_err(|_| {
134                anyhow::anyhow!("event offset is too big, offset: {}", event_offset,)
135            })?,
136        })
137    }
138
139    fn generate(&mut self) -> serde_json::Value {
140        let partition_result = self.start
141            + T::try_from(self.offset).unwrap()
142            + T::try_from(self.step).unwrap() * self.cur;
143        let partition_result = if partition_result > self.end {
144            None
145        } else {
146            Some(partition_result)
147        };
148        self.cur += T::one();
149        json!(partition_result)
150    }
151
152    fn generate_datum(&mut self) -> Datum {
153        let partition_result = self.start
154            + T::try_from(self.offset).unwrap()
155            + T::try_from(self.step).unwrap() * self.cur;
156        self.cur += T::one();
157        if partition_result > self.end {
158            None
159        } else {
160            Some(partition_result.to_scalar_value())
161        }
162    }
163}
164
/// Invokes the callback macro `$macro` once with every numeric datagen field
/// variant.
///
/// Each entry has the form `{ RandomAlias, SequenceAlias, field_type }`. Each
/// callback decides which of the three parts it uses.
#[macro_export]
macro_rules! for_all_fields_variants {
    ($macro:ident) => {
        $macro! {
            { I16RandomField, I16SequenceField, i16 },
            { I32RandomField, I32SequenceField, i32 },
            { I64RandomField, I64SequenceField, i64 },
            { F32RandomField, F32SequenceField, F32 },
            { F64RandomField, F64SequenceField, F64 }
        }
    };
}
177
/// Declares a public `NumericFieldRandomConcrete<T>` alias for each entry of a
/// `for_all_fields_variants!` list. The sequence alias name is ignored.
macro_rules! gen_random_field_alias {
    ($({ $random_alias:ident, $_sequence_alias:ident, $numeric_ty:ty }),*) => {
        $(
            #[doc = concat!("Random field generator producing `", stringify!($numeric_ty), "` values.")]
            pub type $random_alias = NumericFieldRandomConcrete<$numeric_ty>;
        )*
    };
}
185
/// Declares a public `NumericFieldSequenceConcrete<T>` alias for each entry of
/// a `for_all_fields_variants!` list. The random alias name is ignored.
macro_rules! gen_sequence_field_alias {
    ($({ $_random_alias:ident, $sequence_alias:ident, $numeric_ty:ty }),*) => {
        $(
            #[doc = concat!("Sequence field generator producing `", stringify!($numeric_ty), "` values.")]
            pub type $sequence_alias = NumericFieldSequenceConcrete<$numeric_ty>;
        )*
    };
}
193
194for_all_fields_variants! { impl_numeric_type }
195for_all_fields_variants! { gen_random_field_alias }
196for_all_fields_variants! { gen_sequence_field_alias }
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DefaultOrd;

    #[test]
    fn test_sequence_field_generator() {
        let mut i16_field =
            I16SequenceField::new(Some("5".to_owned()), Some("10".to_owned()), 0, 1, 0).unwrap();
        for i in 5..=10 {
            assert_eq!(i16_field.generate(), json!(i));
        }
        // Once `end` has been passed, the generator keeps emitting JSON `null`.
        for _ in 0..3 {
            assert_eq!(i16_field.generate(), serde_json::Value::Null);
        }
    }

    #[test]
    fn test_random_field_generator() {
        let mut i64_field =
            I64RandomField::new(Some("5".to_owned()), Some("10".to_owned()), 114).unwrap();
        for offset in 0..100u64 {
            let res = i64_field.generate(offset);
            assert!(res.is_number());
            let res = res.as_i64().unwrap();
            assert!((5..=10).contains(&res));
        }

        // Without explicit bounds the range defaults to `[0, i16::MAX]`.
        let mut i64_field = I64RandomField::new(None, None, 114).unwrap();
        for offset in 0..100u64 {
            let res = i64_field.generate(offset);
            assert!(res.is_number());
            let res = res.as_i64().unwrap();
            assert!((0..=i64::from(i16::MAX)).contains(&res));
        }
    }

    #[test]
    fn test_sequence_datum_generator() {
        let mut f32_field =
            F32SequenceField::new(Some("5.0".to_owned()), Some("10.0".to_owned()), 0, 1, 0)
                .unwrap();

        for i in 5..=10 {
            assert_eq!(
                f32_field.generate_datum(),
                Some(F32::from(i as f32).to_scalar_value())
            );
        }
        // Past `end`, datums are null.
        assert_eq!(f32_field.generate_datum(), None);
    }

    #[test]
    fn test_random_datum_generator() {
        let mut i32_field =
            I32RandomField::new(Some("-5".to_owned()), Some("5".to_owned()), 123).unwrap();
        let (lower, upper) = ((-5).to_scalar_value(), 5.to_scalar_value());
        for offset in 0..100u64 {
            let res = i32_field.generate_datum(offset);
            assert!(res.is_some());
            let res = res.unwrap();
            assert!(lower.default_cmp(&res).is_le() && res.default_cmp(&upper).is_le());
        }
    }

    #[test]
    fn test_sequence_field_generator_float() {
        let mut f64_field =
            F64SequenceField::new(Some("0".to_owned()), Some("10".to_owned()), 0, 1, 0).unwrap();
        for i in 0..=10 {
            assert_eq!(f64_field.generate(), json!(i as f64));
        }
        assert_eq!(f64_field.generate(), serde_json::Value::Null);

        let mut f32_field =
            F32SequenceField::new(Some("-5".to_owned()), Some("5".to_owned()), 0, 1, 0).unwrap();
        for i in -5..=5 {
            assert_eq!(f32_field.generate(), json!(i as f32));
        }
        assert_eq!(f32_field.generate(), serde_json::Value::Null);
    }

    #[test]
    fn test_random_field_generator_float() {
        // The generator samples the inclusive range `min..=max`, so `max` itself is valid.
        let mut f64_field =
            F64RandomField::new(Some("5".to_owned()), Some("10".to_owned()), 114).unwrap();
        for offset in 0..100u64 {
            let res = f64_field.generate(offset);
            assert!(res.is_number());
            let res = res.as_f64().unwrap();
            assert!((5. ..=10.).contains(&res));
        }

        // Without explicit bounds the range defaults to `[0, i16::MAX]`.
        let mut f64_field = F64RandomField::new(None, None, 114).unwrap();
        for offset in 0..100u64 {
            let res = f64_field.generate(offset);
            assert!(res.is_number());
            let res = res.as_f64().unwrap();
            assert!((0. ..=f64::from(i16::MAX)).contains(&res));
        }

        let mut f32_field =
            F32RandomField::new(Some("5".to_owned()), Some("10".to_owned()), 114).unwrap();
        for offset in 0..100u64 {
            let res = f32_field.generate(offset);
            assert!(res.is_number());
            // `serde_json::Value` has no `as_f32`; the f32 value widens losslessly.
            let res = res.as_f64().unwrap();
            assert!((5. ..=10.).contains(&res));
        }

        // Without explicit bounds the range defaults to `[0, i16::MAX]`.
        let mut f32_field = F32RandomField::new(None, None, 114).unwrap();
        for offset in 0..100u64 {
            let res = f32_field.generate(offset);
            assert!(res.is_number());
            let res = res.as_f64().unwrap();
            assert!((0. ..=f64::from(i16::MAX)).contains(&res));
        }
    }
}