risingwave_common/system_param/
adaptive_parallelism_strategy.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::{max, min};
16use std::fmt::{Display, Formatter};
17use std::num::NonZeroUsize;
18use std::str::FromStr;
19
20use regex::Regex;
21use risingwave_common::system_param::ParamValue;
22use serde::{Deserialize, Serialize};
23use thiserror::Error;
24
25/// Use `#[serde(try_from, into)]` to serialize/deserialize as string format (e.g., "Bounded(64)"),
26/// which is consistent with `ALTER SYSTEM SET` command.
27#[derive(PartialEq, Copy, Clone, Debug, Default, Serialize, Deserialize)]
28#[serde(try_from = "String", into = "String")]
29pub enum AdaptiveParallelismStrategy {
30    #[default]
31    Auto,
32    Full,
33    Bounded(NonZeroUsize),
34    Ratio(f32),
35}
36
37impl TryFrom<String> for AdaptiveParallelismStrategy {
38    type Error = ParallelismStrategyParseError;
39
40    fn try_from(s: String) -> Result<Self, Self::Error> {
41        s.parse()
42    }
43}
44
45impl Display for AdaptiveParallelismStrategy {
46    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
47        match self {
48            AdaptiveParallelismStrategy::Auto => write!(f, "AUTO"),
49            AdaptiveParallelismStrategy::Full => write!(f, "FULL"),
50            AdaptiveParallelismStrategy::Bounded(n) => write!(f, "BOUNDED({})", n),
51            AdaptiveParallelismStrategy::Ratio(r) => write!(f, "RATIO({})", r),
52        }
53    }
54}
55
56impl From<AdaptiveParallelismStrategy> for String {
57    fn from(val: AdaptiveParallelismStrategy) -> Self {
58        val.to_string()
59    }
60}
61
62#[derive(Error, Debug)]
63pub enum ParallelismStrategyParseError {
64    #[error("Unsupported strategy: {0}")]
65    UnsupportedStrategy(String),
66
67    #[error("Parse error: {0}")]
68    ParseIntError(#[from] std::num::ParseIntError),
69
70    #[error("Parse error: {0}")]
71    ParseFloatError(#[from] std::num::ParseFloatError),
72
73    #[error("Invalid value for Bounded strategy: must be positive integer")]
74    InvalidBoundedValue,
75
76    #[error("Invalid value for Ratio strategy: must be between 0.0 and 1.0")]
77    InvalidRatioValue,
78}
79
80impl AdaptiveParallelismStrategy {
81    pub fn compute_target_parallelism(&self, current_parallelism: usize) -> usize {
82        match self {
83            AdaptiveParallelismStrategy::Auto | AdaptiveParallelismStrategy::Full => {
84                current_parallelism
85            }
86            AdaptiveParallelismStrategy::Bounded(n) => min(n.get(), current_parallelism),
87            AdaptiveParallelismStrategy::Ratio(r) => {
88                max((current_parallelism as f32 * r).floor() as usize, 1)
89            }
90        }
91    }
92}
93
94pub fn parse_strategy(
95    input: &str,
96) -> Result<AdaptiveParallelismStrategy, ParallelismStrategyParseError> {
97    let lower_input = input.to_lowercase();
98
99    // Handle Auto/Full case-insensitively without regex
100    match lower_input.as_str() {
101        "auto" => return Ok(AdaptiveParallelismStrategy::Auto),
102        "full" => return Ok(AdaptiveParallelismStrategy::Full),
103        _ => (),
104    }
105
106    // Compile regex patterns once using OnceLock
107    fn bounded_re() -> &'static Regex {
108        static RE: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
109        RE.get_or_init(|| Regex::new(r"(?i)^bounded\((?<value>\d+)\)$").unwrap())
110    }
111
112    fn ratio_re() -> &'static Regex {
113        static RE: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
114        RE.get_or_init(|| Regex::new(r"(?i)^ratio\((?<value>[+-]?\d+(?:\.\d+)?)\)$").unwrap())
115    }
116
117    // Try to match Bounded pattern
118    if let Some(caps) = bounded_re().captures(&lower_input) {
119        let value_str = caps.name("value").unwrap().as_str();
120        let value: usize = value_str.parse()?;
121
122        let value =
123            NonZeroUsize::new(value).ok_or(ParallelismStrategyParseError::InvalidBoundedValue)?;
124
125        return Ok(AdaptiveParallelismStrategy::Bounded(value));
126    }
127
128    // Try to match Ratio pattern
129    if let Some(caps) = ratio_re().captures(&lower_input) {
130        let value_str = caps.name("value").unwrap().as_str();
131        let value: f32 = value_str.parse()?;
132
133        if !(0.0..=1.0).contains(&value) {
134            return Err(ParallelismStrategyParseError::InvalidRatioValue);
135        }
136
137        return Ok(AdaptiveParallelismStrategy::Ratio(value));
138    }
139
140    // If no patterns matched
141    Err(ParallelismStrategyParseError::UnsupportedStrategy(
142        input.to_owned(),
143    ))
144}
145
146impl FromStr for AdaptiveParallelismStrategy {
147    type Err = ParallelismStrategyParseError;
148
149    fn from_str(s: &str) -> Result<Self, Self::Err> {
150        parse_strategy(s)
151    }
152}
153
154impl ParamValue for AdaptiveParallelismStrategy {
155    type Borrowed<'a> = AdaptiveParallelismStrategy;
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_strategies() {
        assert_eq!(
            parse_strategy("Auto").unwrap(),
            AdaptiveParallelismStrategy::Auto
        );
        assert_eq!(
            parse_strategy("FULL").unwrap(),
            AdaptiveParallelismStrategy::Full
        );

        let bounded = parse_strategy("Bounded(42)").unwrap();
        assert!(matches!(bounded, AdaptiveParallelismStrategy::Bounded(n) if n.get() == 42));

        let ratio = parse_strategy("Ratio(0.75)").unwrap();
        assert_eq!(ratio, AdaptiveParallelismStrategy::Ratio(0.75));

        // Both ends of the inclusive ratio range are accepted.
        assert_eq!(
            parse_strategy("Ratio(0)").unwrap(),
            AdaptiveParallelismStrategy::Ratio(0.0)
        );
        assert_eq!(
            parse_strategy("Ratio(1.0)").unwrap(),
            AdaptiveParallelismStrategy::Ratio(1.0)
        );
    }

    /// Keywords are matched case-insensitively.
    #[test]
    fn test_case_insensitive_parsing() {
        assert_eq!(
            parse_strategy("auto").unwrap(),
            AdaptiveParallelismStrategy::Auto
        );
        assert_eq!(
            parse_strategy("fUlL").unwrap(),
            AdaptiveParallelismStrategy::Full
        );
        assert_eq!(
            parse_strategy("bOuNdEd(3)").unwrap(),
            AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(3).unwrap())
        );
        assert_eq!(
            parse_strategy("RATIO(0.25)").unwrap(),
            AdaptiveParallelismStrategy::Ratio(0.25)
        );
    }

    #[test]
    fn test_invalid_values() {
        // Bounded(0) is invalid - must be positive
        assert!(matches!(
            parse_strategy("Bounded(0)"),
            Err(ParallelismStrategyParseError::InvalidBoundedValue)
        ));

        // Bound that does not fit in usize (2^128)
        assert!(matches!(
            parse_strategy("Bounded(340282366920938463463374607431768211456)"),
            Err(ParallelismStrategyParseError::ParseIntError(_))
        ));

        // Ratio out of range [0.0, 1.0]
        assert!(matches!(
            parse_strategy("Ratio(1.1)"),
            Err(ParallelismStrategyParseError::InvalidRatioValue)
        ));
        assert!(matches!(
            parse_strategy("Ratio(-0.5)"),
            Err(ParallelismStrategyParseError::InvalidRatioValue)
        ));
        assert!(matches!(
            parse_strategy("Ratio(-1)"),
            Err(ParallelismStrategyParseError::InvalidRatioValue)
        ));

        // Invalid number format - regex won't match
        assert!(matches!(
            parse_strategy("Ratio(-0.a)"),
            Err(ParallelismStrategyParseError::UnsupportedStrategy(_))
        ));

        // Negative bounded - regex won't match (only digits allowed)
        assert!(matches!(
            parse_strategy("Bounded(-5)"),
            Err(ParallelismStrategyParseError::UnsupportedStrategy(_))
        ));
    }

    #[test]
    fn test_unsupported_formats() {
        assert!(matches!(
            parse_strategy("Invalid"),
            Err(ParallelismStrategyParseError::UnsupportedStrategy(_))
        ));

        assert!(matches!(
            parse_strategy("Auto(5)"),
            Err(ParallelismStrategyParseError::UnsupportedStrategy(_))
        ));
    }

    #[test]
    fn test_auto_full_strategies() {
        let auto = AdaptiveParallelismStrategy::Auto;
        let full = AdaptiveParallelismStrategy::Full;

        // Basic cases
        assert_eq!(auto.compute_target_parallelism(1), 1);
        assert_eq!(auto.compute_target_parallelism(10), 10);
        assert_eq!(full.compute_target_parallelism(5), 5);
        assert_eq!(full.compute_target_parallelism(8), 8);

        // Edge cases
        assert_eq!(auto.compute_target_parallelism(usize::MAX), usize::MAX);
    }

    #[test]
    fn test_bounded_strategy() {
        let bounded_8 = AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(8).unwrap());
        let bounded_1 = AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(1).unwrap());

        // Below bound
        assert_eq!(bounded_8.compute_target_parallelism(5), 5);
        // Exactly at bound
        assert_eq!(bounded_8.compute_target_parallelism(8), 8);
        // Above bound
        assert_eq!(bounded_8.compute_target_parallelism(10), 8);
        // Minimum bound
        assert_eq!(bounded_1.compute_target_parallelism(1), 1);
        assert_eq!(bounded_1.compute_target_parallelism(2), 1);
    }

    #[test]
    fn test_ratio_strategy() {
        let ratio_half = AdaptiveParallelismStrategy::Ratio(0.5);
        let ratio_30pct = AdaptiveParallelismStrategy::Ratio(0.3);
        let ratio_full = AdaptiveParallelismStrategy::Ratio(1.0);

        // Normal calculations
        assert_eq!(ratio_half.compute_target_parallelism(4), 2);
        assert_eq!(ratio_half.compute_target_parallelism(5), 2);

        // Flooring behavior
        assert_eq!(ratio_30pct.compute_target_parallelism(3), 1);
        assert_eq!(ratio_30pct.compute_target_parallelism(4), 1);
        assert_eq!(ratio_30pct.compute_target_parallelism(5), 1);
        assert_eq!(ratio_30pct.compute_target_parallelism(7), 2);

        // Full ratio
        assert_eq!(ratio_full.compute_target_parallelism(5), 5);
    }

    #[test]
    fn test_edge_cases() {
        let ratio_overflow = AdaptiveParallelismStrategy::Ratio(2.5);
        assert_eq!(ratio_overflow.compute_target_parallelism(4), 10);

        let max_parallelism =
            AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(usize::MAX).unwrap());
        assert_eq!(
            max_parallelism.compute_target_parallelism(usize::MAX),
            usize::MAX
        );
    }

    /// Every variant must survive both the `Display` -> `parse_strategy` path and the serde
    /// path unchanged, so values written by one side can always be read back.
    #[test]
    fn test_display_and_serde_roundtrip() {
        let strategies = [
            AdaptiveParallelismStrategy::Auto,
            AdaptiveParallelismStrategy::Full,
            AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(1).unwrap()),
            AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(usize::MAX).unwrap()),
            AdaptiveParallelismStrategy::Ratio(0.0),
            AdaptiveParallelismStrategy::Ratio(0.3),
            AdaptiveParallelismStrategy::Ratio(1.0),
        ];

        for strategy in strategies {
            let text = strategy.to_string();
            assert_eq!(
                parse_strategy(&text).unwrap(),
                strategy,
                "display form: {text}"
            );

            let json = serde_json::to_string(&strategy).unwrap();
            let back: AdaptiveParallelismStrategy = serde_json::from_str(&json).unwrap();
            assert_eq!(back, strategy, "serialized form: {json}");
        }
    }

    /// Test serde serialization/deserialization uses string format,
    /// which is consistent with `ALTER SYSTEM SET` command.
    #[test]
    fn test_serde_string_format() {
        // Test deserialization from string (as used in config files)
        let auto: AdaptiveParallelismStrategy = serde_json::from_str(r#""Auto""#).unwrap();
        assert_eq!(auto, AdaptiveParallelismStrategy::Auto);

        let full: AdaptiveParallelismStrategy = serde_json::from_str(r#""Full""#).unwrap();
        assert_eq!(full, AdaptiveParallelismStrategy::Full);

        let bounded: AdaptiveParallelismStrategy =
            serde_json::from_str(r#""Bounded(64)""#).unwrap();
        assert!(matches!(bounded, AdaptiveParallelismStrategy::Bounded(n) if n.get() == 64));

        let ratio: AdaptiveParallelismStrategy = serde_json::from_str(r#""Ratio(0.5)""#).unwrap();
        assert!(
            matches!(ratio, AdaptiveParallelismStrategy::Ratio(r) if (r - 0.5).abs() < f32::EPSILON)
        );

        // Test serialization to string
        let auto = AdaptiveParallelismStrategy::Auto;
        assert_eq!(serde_json::to_string(&auto).unwrap(), r#""AUTO""#);

        let full = AdaptiveParallelismStrategy::Full;
        assert_eq!(serde_json::to_string(&full).unwrap(), r#""FULL""#);

        let bounded = AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(64).unwrap());
        assert_eq!(serde_json::to_string(&bounded).unwrap(), r#""BOUNDED(64)""#);

        let ratio = AdaptiveParallelismStrategy::Ratio(0.5);
        assert_eq!(serde_json::to_string(&ratio).unwrap(), r#""RATIO(0.5)""#);

        // Test roundtrip
        let original = AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(128).unwrap());
        let serialized = serde_json::to_string(&original).unwrap();
        let deserialized: AdaptiveParallelismStrategy = serde_json::from_str(&serialized).unwrap();
        assert_eq!(original, deserialized);

        // Test deserialization errors (same validation as ALTER SYSTEM SET)
        assert!(serde_json::from_str::<AdaptiveParallelismStrategy>(r#""Ratio(-1)""#).is_err());
        assert!(serde_json::from_str::<AdaptiveParallelismStrategy>(r#""Ratio(-0.a)""#).is_err());
        assert!(serde_json::from_str::<AdaptiveParallelismStrategy>(r#""Bounded(0)""#).is_err());
        assert!(serde_json::from_str::<AdaptiveParallelismStrategy>(r#""Invalid""#).is_err());
    }
}