risingwave_common/system_param/
adaptive_parallelism_strategy.rs1use std::cmp::{max, min};
16use std::fmt::{Display, Formatter};
17use std::num::NonZeroUsize;
18use std::str::FromStr;
19
20use regex::Regex;
21use risingwave_common::system_param::ParamValue;
22use serde::{Deserialize, Serialize};
23use thiserror::Error;
24
25#[derive(PartialEq, Copy, Clone, Debug, Default, Serialize, Deserialize)]
28#[serde(try_from = "String", into = "String")]
29pub enum AdaptiveParallelismStrategy {
30 #[default]
31 Auto,
32 Full,
33 Bounded(NonZeroUsize),
34 Ratio(f32),
35}
36
37impl TryFrom<String> for AdaptiveParallelismStrategy {
38 type Error = ParallelismStrategyParseError;
39
40 fn try_from(s: String) -> Result<Self, Self::Error> {
41 s.parse()
42 }
43}
44
45impl Display for AdaptiveParallelismStrategy {
46 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
47 match self {
48 AdaptiveParallelismStrategy::Auto => write!(f, "AUTO"),
49 AdaptiveParallelismStrategy::Full => write!(f, "FULL"),
50 AdaptiveParallelismStrategy::Bounded(n) => write!(f, "BOUNDED({})", n),
51 AdaptiveParallelismStrategy::Ratio(r) => write!(f, "RATIO({})", r),
52 }
53 }
54}
55
56impl From<AdaptiveParallelismStrategy> for String {
57 fn from(val: AdaptiveParallelismStrategy) -> Self {
58 val.to_string()
59 }
60}
61
62#[derive(Error, Debug)]
63pub enum ParallelismStrategyParseError {
64 #[error("Unsupported strategy: {0}")]
65 UnsupportedStrategy(String),
66
67 #[error("Parse error: {0}")]
68 ParseIntError(#[from] std::num::ParseIntError),
69
70 #[error("Parse error: {0}")]
71 ParseFloatError(#[from] std::num::ParseFloatError),
72
73 #[error("Invalid value for Bounded strategy: must be positive integer")]
74 InvalidBoundedValue,
75
76 #[error("Invalid value for Ratio strategy: must be between 0.0 and 1.0")]
77 InvalidRatioValue,
78}
79
80impl AdaptiveParallelismStrategy {
81 pub fn compute_target_parallelism(&self, current_parallelism: usize) -> usize {
82 match self {
83 AdaptiveParallelismStrategy::Auto | AdaptiveParallelismStrategy::Full => {
84 current_parallelism
85 }
86 AdaptiveParallelismStrategy::Bounded(n) => min(n.get(), current_parallelism),
87 AdaptiveParallelismStrategy::Ratio(r) => {
88 max((current_parallelism as f32 * r).floor() as usize, 1)
89 }
90 }
91 }
92}
93
94pub fn parse_strategy(
95 input: &str,
96) -> Result<AdaptiveParallelismStrategy, ParallelismStrategyParseError> {
97 let lower_input = input.to_lowercase();
98
99 match lower_input.as_str() {
101 "auto" => return Ok(AdaptiveParallelismStrategy::Auto),
102 "full" => return Ok(AdaptiveParallelismStrategy::Full),
103 _ => (),
104 }
105
106 fn bounded_re() -> &'static Regex {
108 static RE: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
109 RE.get_or_init(|| Regex::new(r"(?i)^bounded\((?<value>\d+)\)$").unwrap())
110 }
111
112 fn ratio_re() -> &'static Regex {
113 static RE: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
114 RE.get_or_init(|| Regex::new(r"(?i)^ratio\((?<value>[+-]?\d+(?:\.\d+)?)\)$").unwrap())
115 }
116
117 if let Some(caps) = bounded_re().captures(&lower_input) {
119 let value_str = caps.name("value").unwrap().as_str();
120 let value: usize = value_str.parse()?;
121
122 let value =
123 NonZeroUsize::new(value).ok_or(ParallelismStrategyParseError::InvalidBoundedValue)?;
124
125 return Ok(AdaptiveParallelismStrategy::Bounded(value));
126 }
127
128 if let Some(caps) = ratio_re().captures(&lower_input) {
130 let value_str = caps.name("value").unwrap().as_str();
131 let value: f32 = value_str.parse()?;
132
133 if !(0.0..=1.0).contains(&value) {
134 return Err(ParallelismStrategyParseError::InvalidRatioValue);
135 }
136
137 return Ok(AdaptiveParallelismStrategy::Ratio(value));
138 }
139
140 Err(ParallelismStrategyParseError::UnsupportedStrategy(
142 input.to_owned(),
143 ))
144}
145
146impl FromStr for AdaptiveParallelismStrategy {
147 type Err = ParallelismStrategyParseError;
148
149 fn from_str(s: &str) -> Result<Self, Self::Err> {
150 parse_strategy(s)
151 }
152}
153
154impl ParamValue for AdaptiveParallelismStrategy {
155 type Borrowed<'a> = AdaptiveParallelismStrategy;
156}
157
#[cfg(test)]
mod tests {
    use super::AdaptiveParallelismStrategy as Strategy;
    use super::ParallelismStrategyParseError as ParseError;
    use super::*;

    /// Builds a `Bounded` strategy from a limit that the test knows to be non-zero.
    fn bounded(limit: usize) -> Strategy {
        Strategy::Bounded(NonZeroUsize::new(limit).expect("test bound must be non-zero"))
    }

    /// Deserializes a strategy from a JSON document.
    fn from_json(json: &str) -> Result<Strategy, serde_json::Error> {
        serde_json::from_str(json)
    }

    #[test]
    fn test_valid_strategies() {
        // Keywords are accepted regardless of case.
        assert_eq!(parse_strategy("Auto").unwrap(), Strategy::Auto);
        assert_eq!(parse_strategy("FULL").unwrap(), Strategy::Full);
        assert_eq!(parse_strategy("Bounded(42)").unwrap(), bounded(42));
        assert_eq!(parse_strategy("Ratio(0.75)").unwrap(), Strategy::Ratio(0.75));
    }

    #[test]
    fn test_invalid_values() {
        // Zero is well-formed but rejected as a bound.
        assert!(matches!(parse_strategy("Bounded(0)"), Err(ParseError::InvalidBoundedValue)));

        // Well-formed numbers outside `0.0..=1.0`.
        for input in ["Ratio(1.1)", "Ratio(-0.5)", "Ratio(-1)"] {
            assert!(
                matches!(parse_strategy(input), Err(ParseError::InvalidRatioValue)),
                "unexpected result for {input}"
            );
        }

        // Arguments that do not match the expected number syntax are not recognized at all.
        for input in ["Ratio(-0.a)", "Bounded(-5)"] {
            assert!(
                matches!(parse_strategy(input), Err(ParseError::UnsupportedStrategy(_))),
                "unexpected result for {input}"
            );
        }
    }

    #[test]
    fn test_unsupported_formats() {
        for input in ["Invalid", "Auto(5)"] {
            assert!(
                matches!(parse_strategy(input), Err(ParseError::UnsupportedStrategy(_))),
                "unexpected result for {input}"
            );
        }
    }

    #[test]
    fn test_auto_full_strategies() {
        // Both strategies pass the current parallelism through unchanged.
        for current in [1, 10, usize::MAX] {
            assert_eq!(Strategy::Auto.compute_target_parallelism(current), current);
        }
        for current in [5, 8] {
            assert_eq!(Strategy::Full.compute_target_parallelism(current), current);
        }
    }

    #[test]
    fn test_bounded_strategy() {
        // (limit, current parallelism, expected target)
        let cases = [(8, 5, 5), (8, 8, 8), (8, 10, 8), (1, 1, 1), (1, 2, 1)];
        for (limit, current, expected) in cases {
            assert_eq!(
                bounded(limit).compute_target_parallelism(current),
                expected,
                "limit {limit}, current {current}"
            );
        }
    }

    #[test]
    fn test_ratio_strategy() {
        // (ratio, current parallelism, expected target)
        let cases = [
            (0.5, 4, 2),
            (0.5, 5, 2),
            (0.3, 3, 1),
            (0.3, 4, 1),
            (0.3, 5, 1),
            (0.3, 7, 2),
            (1.0, 5, 5),
        ];
        for (ratio, current, expected) in cases {
            assert_eq!(
                Strategy::Ratio(ratio).compute_target_parallelism(current),
                expected,
                "ratio {ratio}, current {current}"
            );
        }
    }

    #[test]
    fn test_edge_cases() {
        // A ratio above 1.0 can only be constructed directly; it scales the value up.
        assert_eq!(Strategy::Ratio(2.5).compute_target_parallelism(4), 10);

        assert_eq!(
            bounded(usize::MAX).compute_target_parallelism(usize::MAX),
            usize::MAX
        );
    }

    #[test]
    fn test_serde_string_format() {
        // Deserialization accepts the same spellings as `parse_strategy`.
        assert_eq!(from_json(r#""Auto""#).unwrap(), Strategy::Auto);
        assert_eq!(from_json(r#""Full""#).unwrap(), Strategy::Full);
        assert_eq!(from_json(r#""Bounded(64)""#).unwrap(), bounded(64));

        let ratio = from_json(r#""Ratio(0.5)""#).unwrap();
        assert!(matches!(ratio, Strategy::Ratio(r) if (r - 0.5).abs() < f32::EPSILON));

        // Serialization emits the upper-case `Display` form.
        assert_eq!(serde_json::to_string(&Strategy::Auto).unwrap(), r#""AUTO""#);
        assert_eq!(serde_json::to_string(&bounded(64)).unwrap(), r#""BOUNDED(64)""#);

        // Round trip.
        let original = bounded(128);
        let serialized = serde_json::to_string(&original).unwrap();
        assert_eq!(from_json(&serialized).unwrap(), original);

        // Invalid strings are rejected during deserialization.
        for json in [
            r#""Ratio(-1)""#,
            r#""Ratio(-0.a)""#,
            r#""Bounded(0)""#,
            r#""Invalid""#,
        ] {
            assert!(from_json(json).is_err(), "expected {json} to be rejected");
        }
    }
}