risingwave_common/system_param/
adaptive_parallelism_strategy.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::{max, min};
16use std::fmt::{Display, Formatter};
17use std::num::NonZeroUsize;
18use std::str::FromStr;
19
20use regex::Regex;
21use risingwave_common::system_param::ParamValue;
22use serde::{Deserialize, Serialize};
23use thiserror::Error;
24
25#[derive(PartialEq, Copy, Clone, Debug, Serialize, Deserialize, Default)]
26pub enum AdaptiveParallelismStrategy {
27    #[default]
28    Auto,
29    Full,
30    Bounded(NonZeroUsize),
31    Ratio(f32),
32}
33
34impl Display for AdaptiveParallelismStrategy {
35    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
36        match self {
37            AdaptiveParallelismStrategy::Auto => write!(f, "AUTO"),
38            AdaptiveParallelismStrategy::Full => write!(f, "FULL"),
39            AdaptiveParallelismStrategy::Bounded(n) => write!(f, "BOUNDED({})", n),
40            AdaptiveParallelismStrategy::Ratio(r) => write!(f, "RATIO({})", r),
41        }
42    }
43}
44
45impl From<AdaptiveParallelismStrategy> for String {
46    fn from(val: AdaptiveParallelismStrategy) -> Self {
47        val.to_string()
48    }
49}
50
51#[derive(Error, Debug)]
52pub enum ParallelismStrategyParseError {
53    #[error("Unsupported strategy: {0}")]
54    UnsupportedStrategy(String),
55
56    #[error("Parse error: {0}")]
57    ParseIntError(#[from] std::num::ParseIntError),
58
59    #[error("Parse error: {0}")]
60    ParseFloatError(#[from] std::num::ParseFloatError),
61
62    #[error("Invalid value for Bounded strategy: must be positive integer")]
63    InvalidBoundedValue,
64
65    #[error("Invalid value for Ratio strategy: must be between 0.0 and 1.0")]
66    InvalidRatioValue,
67}
68
69impl AdaptiveParallelismStrategy {
70    pub fn compute_target_parallelism(&self, current_parallelism: usize) -> usize {
71        match self {
72            AdaptiveParallelismStrategy::Auto | AdaptiveParallelismStrategy::Full => {
73                current_parallelism
74            }
75            AdaptiveParallelismStrategy::Bounded(n) => min(n.get(), current_parallelism),
76            AdaptiveParallelismStrategy::Ratio(r) => {
77                max((current_parallelism as f32 * r).floor() as usize, 1)
78            }
79        }
80    }
81}
82
83pub fn parse_strategy(
84    input: &str,
85) -> Result<AdaptiveParallelismStrategy, ParallelismStrategyParseError> {
86    let lower_input = input.to_lowercase();
87
88    // Handle Auto/Full case-insensitively without regex
89    match lower_input.as_str() {
90        "auto" => return Ok(AdaptiveParallelismStrategy::Auto),
91        "full" => return Ok(AdaptiveParallelismStrategy::Full),
92        _ => (),
93    }
94
95    // Compile regex patterns once using OnceLock
96    fn bounded_re() -> &'static Regex {
97        static RE: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
98        RE.get_or_init(|| Regex::new(r"(?i)^bounded\((?<value>\d+)\)$").unwrap())
99    }
100
101    fn ratio_re() -> &'static Regex {
102        static RE: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
103        RE.get_or_init(|| Regex::new(r"(?i)^ratio\((?<value>[+-]?\d+(?:\.\d+)?)\)$").unwrap())
104    }
105
106    // Try to match Bounded pattern
107    if let Some(caps) = bounded_re().captures(&lower_input) {
108        let value_str = caps.name("value").unwrap().as_str();
109        let value: usize = value_str.parse()?;
110
111        let value =
112            NonZeroUsize::new(value).ok_or(ParallelismStrategyParseError::InvalidBoundedValue)?;
113
114        return Ok(AdaptiveParallelismStrategy::Bounded(value));
115    }
116
117    // Try to match Ratio pattern
118    if let Some(caps) = ratio_re().captures(&lower_input) {
119        let value_str = caps.name("value").unwrap().as_str();
120        let value: f32 = value_str.parse()?;
121
122        if !(0.0..=1.0).contains(&value) {
123            return Err(ParallelismStrategyParseError::InvalidRatioValue);
124        }
125
126        return Ok(AdaptiveParallelismStrategy::Ratio(value));
127    }
128
129    // If no patterns matched
130    Err(ParallelismStrategyParseError::UnsupportedStrategy(
131        input.to_owned(),
132    ))
133}
134
135impl FromStr for AdaptiveParallelismStrategy {
136    type Err = ParallelismStrategyParseError;
137
138    fn from_str(s: &str) -> Result<Self, Self::Err> {
139        parse_strategy(s)
140    }
141}
142
143impl ParamValue for AdaptiveParallelismStrategy {
144    type Borrowed<'a> = AdaptiveParallelismStrategy;
145}
146
#[cfg(test)]
mod tests {
    // Unit tests for parsing, rendering, and target-parallelism computation of
    // `AdaptiveParallelismStrategy`.
    use super::*;

    #[test]
    fn test_valid_strategies() {
        assert_eq!(
            parse_strategy("Auto").unwrap(),
            AdaptiveParallelismStrategy::Auto
        );
        assert_eq!(
            parse_strategy("FULL").unwrap(),
            AdaptiveParallelismStrategy::Full
        );

        let bounded = parse_strategy("Bounded(42)").unwrap();
        assert!(matches!(bounded, AdaptiveParallelismStrategy::Bounded(n) if n.get() == 42));

        // Compare with `assert_eq!` rather than a float literal pattern.
        let ratio = parse_strategy("Ratio(0.75)").unwrap();
        assert_eq!(ratio, AdaptiveParallelismStrategy::Ratio(0.75));
    }

    #[test]
    fn test_invalid_values() {
        assert!(matches!(
            parse_strategy("Bounded(0)"),
            Err(ParallelismStrategyParseError::InvalidBoundedValue)
        ));

        assert!(matches!(
            parse_strategy("Ratio(1.1)"),
            Err(ParallelismStrategyParseError::InvalidRatioValue)
        ));

        assert!(matches!(
            parse_strategy("Ratio(-0.5)"),
            Err(ParallelismStrategyParseError::InvalidRatioValue)
        ));
    }

    #[test]
    fn test_unsupported_formats() {
        assert!(matches!(
            parse_strategy("Invalid"),
            Err(ParallelismStrategyParseError::UnsupportedStrategy(_))
        ));

        assert!(matches!(
            parse_strategy("Auto(5)"),
            Err(ParallelismStrategyParseError::UnsupportedStrategy(_))
        ));
    }

    #[test]
    fn test_display_round_trip() {
        // Every rendered strategy must be accepted again by both `parse_strategy`
        // and `FromStr`, yielding an equal value.
        let strategies = [
            AdaptiveParallelismStrategy::Auto,
            AdaptiveParallelismStrategy::Full,
            AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(42).unwrap()),
            AdaptiveParallelismStrategy::Ratio(0.0),
            AdaptiveParallelismStrategy::Ratio(0.3),
            AdaptiveParallelismStrategy::Ratio(0.75),
            AdaptiveParallelismStrategy::Ratio(1.0),
        ];

        for strategy in strategies {
            let rendered = String::from(strategy);
            assert_eq!(parse_strategy(&rendered).unwrap(), strategy, "{rendered}");
            assert_eq!(
                rendered.parse::<AdaptiveParallelismStrategy>().unwrap(),
                strategy,
                "{rendered}"
            );
        }
    }

    #[test]
    fn test_default_is_auto() {
        assert_eq!(
            AdaptiveParallelismStrategy::default(),
            AdaptiveParallelismStrategy::Auto
        );
    }

    #[test]
    fn test_auto_full_strategies() {
        let auto = AdaptiveParallelismStrategy::Auto;
        let full = AdaptiveParallelismStrategy::Full;

        // Basic cases
        assert_eq!(auto.compute_target_parallelism(1), 1);
        assert_eq!(auto.compute_target_parallelism(10), 10);
        assert_eq!(full.compute_target_parallelism(5), 5);
        assert_eq!(full.compute_target_parallelism(8), 8);

        // Edge cases
        assert_eq!(auto.compute_target_parallelism(usize::MAX), usize::MAX);
    }

    #[test]
    fn test_bounded_strategy() {
        let bounded_8 = AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(8).unwrap());
        let bounded_1 = AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(1).unwrap());

        // Below bound
        assert_eq!(bounded_8.compute_target_parallelism(5), 5);
        // Exactly at bound
        assert_eq!(bounded_8.compute_target_parallelism(8), 8);
        // Above bound
        assert_eq!(bounded_8.compute_target_parallelism(10), 8);
        // Minimum bound
        assert_eq!(bounded_1.compute_target_parallelism(1), 1);
        assert_eq!(bounded_1.compute_target_parallelism(2), 1);
    }

    #[test]
    fn test_ratio_strategy() {
        let ratio_half = AdaptiveParallelismStrategy::Ratio(0.5);
        let ratio_30pct = AdaptiveParallelismStrategy::Ratio(0.3);
        let ratio_full = AdaptiveParallelismStrategy::Ratio(1.0);

        // Normal calculations
        assert_eq!(ratio_half.compute_target_parallelism(4), 2);
        assert_eq!(ratio_half.compute_target_parallelism(5), 2);

        // Flooring behavior
        assert_eq!(ratio_30pct.compute_target_parallelism(3), 1);
        assert_eq!(ratio_30pct.compute_target_parallelism(4), 1);
        assert_eq!(ratio_30pct.compute_target_parallelism(5), 1);
        assert_eq!(ratio_30pct.compute_target_parallelism(7), 2);

        // Full ratio
        assert_eq!(ratio_full.compute_target_parallelism(5), 5);
    }

    #[test]
    fn test_edge_cases() {
        let ratio_overflow = AdaptiveParallelismStrategy::Ratio(2.5);
        assert_eq!(ratio_overflow.compute_target_parallelism(4), 10);

        let max_parallelism =
            AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(usize::MAX).unwrap());
        assert_eq!(
            max_parallelism.compute_target_parallelism(usize::MAX),
            usize::MAX
        );
    }
}