risingwave_common/system_param/
adaptive_parallelism_strategy.rs1use std::cmp::{max, min};
16use std::fmt::{Display, Formatter};
17use std::num::NonZeroUsize;
18use std::str::FromStr;
19
20use regex::Regex;
21use risingwave_common::system_param::ParamValue;
22use serde::{Deserialize, Serialize};
23use thiserror::Error;
24
25#[derive(PartialEq, Copy, Clone, Debug, Serialize, Deserialize, Default)]
26pub enum AdaptiveParallelismStrategy {
27 #[default]
28 Auto,
29 Full,
30 Bounded(NonZeroUsize),
31 Ratio(f32),
32}
33
34impl Display for AdaptiveParallelismStrategy {
35 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
36 match self {
37 AdaptiveParallelismStrategy::Auto => write!(f, "AUTO"),
38 AdaptiveParallelismStrategy::Full => write!(f, "FULL"),
39 AdaptiveParallelismStrategy::Bounded(n) => write!(f, "BOUNDED({})", n),
40 AdaptiveParallelismStrategy::Ratio(r) => write!(f, "RATIO({})", r),
41 }
42 }
43}
44
45impl From<AdaptiveParallelismStrategy> for String {
46 fn from(val: AdaptiveParallelismStrategy) -> Self {
47 val.to_string()
48 }
49}
50
51#[derive(Error, Debug)]
52pub enum ParallelismStrategyParseError {
53 #[error("Unsupported strategy: {0}")]
54 UnsupportedStrategy(String),
55
56 #[error("Parse error: {0}")]
57 ParseIntError(#[from] std::num::ParseIntError),
58
59 #[error("Parse error: {0}")]
60 ParseFloatError(#[from] std::num::ParseFloatError),
61
62 #[error("Invalid value for Bounded strategy: must be positive integer")]
63 InvalidBoundedValue,
64
65 #[error("Invalid value for Ratio strategy: must be between 0.0 and 1.0")]
66 InvalidRatioValue,
67}
68
69impl AdaptiveParallelismStrategy {
70 pub fn compute_target_parallelism(&self, current_parallelism: usize) -> usize {
71 match self {
72 AdaptiveParallelismStrategy::Auto | AdaptiveParallelismStrategy::Full => {
73 current_parallelism
74 }
75 AdaptiveParallelismStrategy::Bounded(n) => min(n.get(), current_parallelism),
76 AdaptiveParallelismStrategy::Ratio(r) => {
77 max((current_parallelism as f32 * r).floor() as usize, 1)
78 }
79 }
80 }
81}
82
83pub fn parse_strategy(
84 input: &str,
85) -> Result<AdaptiveParallelismStrategy, ParallelismStrategyParseError> {
86 let lower_input = input.to_lowercase();
87
88 match lower_input.as_str() {
90 "auto" => return Ok(AdaptiveParallelismStrategy::Auto),
91 "full" => return Ok(AdaptiveParallelismStrategy::Full),
92 _ => (),
93 }
94
95 fn bounded_re() -> &'static Regex {
97 static RE: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
98 RE.get_or_init(|| Regex::new(r"(?i)^bounded\((?<value>\d+)\)$").unwrap())
99 }
100
101 fn ratio_re() -> &'static Regex {
102 static RE: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
103 RE.get_or_init(|| Regex::new(r"(?i)^ratio\((?<value>[+-]?\d+(?:\.\d+)?)\)$").unwrap())
104 }
105
106 if let Some(caps) = bounded_re().captures(&lower_input) {
108 let value_str = caps.name("value").unwrap().as_str();
109 let value: usize = value_str.parse()?;
110
111 let value =
112 NonZeroUsize::new(value).ok_or(ParallelismStrategyParseError::InvalidBoundedValue)?;
113
114 return Ok(AdaptiveParallelismStrategy::Bounded(value));
115 }
116
117 if let Some(caps) = ratio_re().captures(&lower_input) {
119 let value_str = caps.name("value").unwrap().as_str();
120 let value: f32 = value_str.parse()?;
121
122 if !(0.0..=1.0).contains(&value) {
123 return Err(ParallelismStrategyParseError::InvalidRatioValue);
124 }
125
126 return Ok(AdaptiveParallelismStrategy::Ratio(value));
127 }
128
129 Err(ParallelismStrategyParseError::UnsupportedStrategy(
131 input.to_owned(),
132 ))
133}
134
135impl FromStr for AdaptiveParallelismStrategy {
136 type Err = ParallelismStrategyParseError;
137
138 fn from_str(s: &str) -> Result<Self, Self::Err> {
139 parse_strategy(s)
140 }
141}
142
143impl ParamValue for AdaptiveParallelismStrategy {
144 type Borrowed<'a> = AdaptiveParallelismStrategy;
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_strategies() {
        assert_eq!(
            parse_strategy("Auto").unwrap(),
            AdaptiveParallelismStrategy::Auto
        );
        assert_eq!(
            parse_strategy("FULL").unwrap(),
            AdaptiveParallelismStrategy::Full
        );

        let bounded = parse_strategy("Bounded(42)").unwrap();
        assert!(matches!(bounded, AdaptiveParallelismStrategy::Bounded(n) if n.get() == 42));

        // Compare floats with `assert_eq!` rather than a float-literal pattern.
        let ratio = parse_strategy("Ratio(0.75)").unwrap();
        assert_eq!(ratio, AdaptiveParallelismStrategy::Ratio(0.75));
    }

    #[test]
    fn test_invalid_values() {
        assert!(matches!(
            parse_strategy("Bounded(0)"),
            Err(ParallelismStrategyParseError::InvalidBoundedValue)
        ));

        assert!(matches!(
            parse_strategy("Ratio(1.1)"),
            Err(ParallelismStrategyParseError::InvalidRatioValue)
        ));

        assert!(matches!(
            parse_strategy("Ratio(-0.5)"),
            Err(ParallelismStrategyParseError::InvalidRatioValue)
        ));
    }

    #[test]
    fn test_number_parse_errors() {
        // A bound that does not fit in `usize` surfaces the underlying integer parse error.
        assert!(matches!(
            parse_strategy("Bounded(99999999999999999999999)"),
            Err(ParallelismStrategyParseError::ParseIntError(_))
        ));
    }

    #[test]
    fn test_unsupported_formats() {
        assert!(matches!(
            parse_strategy("Invalid"),
            Err(ParallelismStrategyParseError::UnsupportedStrategy(_))
        ));

        assert!(matches!(
            parse_strategy("Auto(5)"),
            Err(ParallelismStrategyParseError::UnsupportedStrategy(_))
        ));
    }

    #[test]
    fn test_display_round_trip() {
        // `Display`, `String::from` and `FromStr` must agree for every in-range strategy.
        let strategies = [
            AdaptiveParallelismStrategy::Auto,
            AdaptiveParallelismStrategy::Full,
            AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(4).unwrap()),
            AdaptiveParallelismStrategy::Ratio(0.3),
            AdaptiveParallelismStrategy::Ratio(0.5),
            AdaptiveParallelismStrategy::Ratio(1.0),
        ];
        for strategy in strategies {
            let text = strategy.to_string();
            assert_eq!(String::from(strategy), text);
            assert_eq!(
                text.parse::<AdaptiveParallelismStrategy>().unwrap(),
                strategy
            );
        }
    }

    #[test]
    fn test_auto_full_strategies() {
        let auto = AdaptiveParallelismStrategy::Auto;
        let full = AdaptiveParallelismStrategy::Full;

        // Both pass the available parallelism through unchanged.
        assert_eq!(auto.compute_target_parallelism(1), 1);
        assert_eq!(auto.compute_target_parallelism(10), 10);
        assert_eq!(full.compute_target_parallelism(5), 5);
        assert_eq!(full.compute_target_parallelism(8), 8);

        assert_eq!(auto.compute_target_parallelism(usize::MAX), usize::MAX);
    }

    #[test]
    fn test_bounded_strategy() {
        let bounded_8 = AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(8).unwrap());
        let bounded_1 = AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(1).unwrap());

        // Below, at and above the bound.
        assert_eq!(bounded_8.compute_target_parallelism(5), 5);
        assert_eq!(bounded_8.compute_target_parallelism(8), 8);
        assert_eq!(bounded_8.compute_target_parallelism(10), 8);
        assert_eq!(bounded_1.compute_target_parallelism(1), 1);
        assert_eq!(bounded_1.compute_target_parallelism(2), 1);
    }

    #[test]
    fn test_ratio_strategy() {
        let ratio_half = AdaptiveParallelismStrategy::Ratio(0.5);
        let ratio_30pct = AdaptiveParallelismStrategy::Ratio(0.3);
        let ratio_full = AdaptiveParallelismStrategy::Ratio(1.0);

        assert_eq!(ratio_half.compute_target_parallelism(4), 2);
        assert_eq!(ratio_half.compute_target_parallelism(5), 2);

        // Results are floored, but never drop below 1.
        assert_eq!(ratio_30pct.compute_target_parallelism(3), 1);
        assert_eq!(ratio_30pct.compute_target_parallelism(4), 1);
        assert_eq!(ratio_30pct.compute_target_parallelism(5), 1);
        assert_eq!(ratio_30pct.compute_target_parallelism(7), 2);

        assert_eq!(ratio_full.compute_target_parallelism(5), 5);

        let ratio_zero = AdaptiveParallelismStrategy::Ratio(0.0);
        assert_eq!(ratio_zero.compute_target_parallelism(8), 1);
    }

    #[test]
    fn test_edge_cases() {
        // A ratio above 1.0 is rejected by the parser but not clamped by the computation.
        let ratio_overflow = AdaptiveParallelismStrategy::Ratio(2.5);
        assert_eq!(ratio_overflow.compute_target_parallelism(4), 10);

        let max_parallelism =
            AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(usize::MAX).unwrap());
        assert_eq!(
            max_parallelism.compute_target_parallelism(usize::MAX),
            usize::MAX
        );
    }
}