risingwave_common/session_config/
parallelism.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::num::{NonZeroU64, NonZeroUsize, ParseIntError};
16use std::str::FromStr;
17
18use risingwave_common::system_param::adaptive_parallelism_strategy::{
19    AdaptiveParallelismStrategy, ParallelismStrategyParseError, parse_strategy,
20};
21
/// Spelling of `ConfigParallelism::Default`, both when parsing and when displaying.
const KEYWORD_DEFAULT: &str = "default";
/// Canonical spelling of `ConfigParallelism::Adaptive`; this is what `Display` writes.
const KEYWORD_ADAPTIVE: &str = "adaptive";
/// Parse-only alias for `ConfigParallelism::Adaptive`.
const KEYWORD_AUTO: &str = "auto";
/// Spelling of `ConfigAdaptiveParallelismStrategy::Default`; intentionally the same word as
/// [`KEYWORD_DEFAULT`].
const KEYWORD_DEFAULT_STRATEGY: &str = KEYWORD_DEFAULT;
26
/// Parallelism setting as written in a session configuration variable.
///
/// See the `FromStr` and `Display` impls for the accepted and produced spellings.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ConfigParallelism {
    /// Spelled `default`; this is also the `Default` value of the type.
    #[default]
    Default,
    /// An explicit, non-zero degree of parallelism.
    Fixed(NonZeroU64),
    /// Spelled `adaptive`, `auto`, or `0` when parsed.
    Adaptive,
}
34
35impl FromStr for ConfigParallelism {
36    type Err = ParseIntError;
37
38    fn from_str(s: &str) -> Result<Self, Self::Err> {
39        match s.to_lowercase().as_str() {
40            KEYWORD_DEFAULT => Ok(ConfigParallelism::Default),
41            KEYWORD_ADAPTIVE | KEYWORD_AUTO => Ok(ConfigParallelism::Adaptive),
42            s => {
43                let parsed = s.parse::<u64>()?;
44                if parsed == 0 {
45                    Ok(ConfigParallelism::Adaptive)
46                } else {
47                    Ok(ConfigParallelism::Fixed(NonZeroU64::new(parsed).unwrap()))
48                }
49            }
50        }
51    }
52}
53
54impl std::fmt::Display for ConfigParallelism {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        match *self {
57            ConfigParallelism::Adaptive => {
58                write!(f, "{}", KEYWORD_ADAPTIVE)
59            }
60            ConfigParallelism::Default => {
61                write!(f, "{}", KEYWORD_DEFAULT)
62            }
63            ConfigParallelism::Fixed(n) => {
64                write!(f, "{}", n)
65            }
66        }
67    }
68}
69
/// Adaptive parallelism strategy as written in a session configuration variable.
///
/// Mirrors `AdaptiveParallelismStrategy` variant by variant, plus a `Default` variant that
/// converts to `None` when turned into an `Option<AdaptiveParallelismStrategy>`.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum ConfigAdaptiveParallelismStrategy {
    /// Spelled `default`; this is also the `Default` value of the type.
    #[default]
    Default,
    /// Counterpart of `AdaptiveParallelismStrategy::Auto`.
    Auto,
    /// Counterpart of `AdaptiveParallelismStrategy::Full`.
    Full,
    /// Counterpart of `AdaptiveParallelismStrategy::Bounded`; the bound is non-zero.
    Bounded(NonZeroU64),
    /// Counterpart of `AdaptiveParallelismStrategy::Ratio`.
    Ratio(f32),
}
79
80impl FromStr for ConfigAdaptiveParallelismStrategy {
81    type Err = ParallelismStrategyParseError;
82
83    fn from_str(s: &str) -> Result<Self, Self::Err> {
84        if s.eq_ignore_ascii_case(KEYWORD_DEFAULT_STRATEGY) {
85            return Ok(Self::Default);
86        }
87
88        let strategy = parse_strategy(s)?;
89        Ok(strategy.into())
90    }
91}
92
93impl From<AdaptiveParallelismStrategy> for ConfigAdaptiveParallelismStrategy {
94    fn from(value: AdaptiveParallelismStrategy) -> Self {
95        match value {
96            AdaptiveParallelismStrategy::Auto => Self::Auto,
97            AdaptiveParallelismStrategy::Full => Self::Full,
98            AdaptiveParallelismStrategy::Bounded(n) => {
99                // Safe to unwrap since `n` is non-zero.
100                Self::Bounded(NonZeroU64::new(n.get() as u64).unwrap())
101            }
102            AdaptiveParallelismStrategy::Ratio(r) => Self::Ratio(r),
103        }
104    }
105}
106
107impl From<ConfigAdaptiveParallelismStrategy> for Option<AdaptiveParallelismStrategy> {
108    fn from(value: ConfigAdaptiveParallelismStrategy) -> Self {
109        match value {
110            ConfigAdaptiveParallelismStrategy::Default => None,
111            ConfigAdaptiveParallelismStrategy::Auto => Some(AdaptiveParallelismStrategy::Auto),
112            ConfigAdaptiveParallelismStrategy::Full => Some(AdaptiveParallelismStrategy::Full),
113            ConfigAdaptiveParallelismStrategy::Bounded(n) => {
114                Some(AdaptiveParallelismStrategy::Bounded(
115                    NonZeroUsize::new(n.get() as usize)
116                        // Bounded strategy requires non-zero; `NonZeroU64` guarantees this.
117                        .unwrap(),
118                ))
119            }
120            ConfigAdaptiveParallelismStrategy::Ratio(r) => {
121                Some(AdaptiveParallelismStrategy::Ratio(r))
122            }
123        }
124    }
125}
126
127impl std::fmt::Display for ConfigAdaptiveParallelismStrategy {
128    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129        match self {
130            ConfigAdaptiveParallelismStrategy::Default => {
131                write!(f, "{}", KEYWORD_DEFAULT_STRATEGY)
132            }
133            ConfigAdaptiveParallelismStrategy::Auto => AdaptiveParallelismStrategy::Auto.fmt(f),
134            ConfigAdaptiveParallelismStrategy::Full => AdaptiveParallelismStrategy::Full.fmt(f),
135            ConfigAdaptiveParallelismStrategy::Bounded(n) => {
136                AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(n.get() as usize).unwrap())
137                    .fmt(f)
138            }
139            ConfigAdaptiveParallelismStrategy::Ratio(r) => {
140                AdaptiveParallelismStrategy::Ratio(*r).fmt(f)
141            }
142        }
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parallelism_parse_keywords() {
        for input in ["default", "DEFAULT", "Default"] {
            assert_eq!(
                input.parse::<ConfigParallelism>().unwrap(),
                ConfigParallelism::Default
            );
        }
        // `auto` and `0` are accepted aliases for adaptive parallelism.
        for input in ["adaptive", "ADAPTIVE", "auto", "Auto", "0"] {
            assert_eq!(
                input.parse::<ConfigParallelism>().unwrap(),
                ConfigParallelism::Adaptive
            );
        }
    }

    #[test]
    fn test_parallelism_parse_fixed() {
        assert_eq!(
            "8".parse::<ConfigParallelism>().unwrap(),
            ConfigParallelism::Fixed(NonZeroU64::new(8).unwrap())
        );
    }

    #[test]
    fn test_parallelism_parse_invalid() {
        for input in ["", "-1", "fixed", "1.5"] {
            assert!(
                input.parse::<ConfigParallelism>().is_err(),
                "{input:?} should be rejected"
            );
        }
    }

    #[test]
    fn test_parallelism_display_round_trip() {
        for value in [
            ConfigParallelism::Default,
            ConfigParallelism::Adaptive,
            ConfigParallelism::Fixed(NonZeroU64::new(16).unwrap()),
        ] {
            assert_eq!(
                value.to_string().parse::<ConfigParallelism>().unwrap(),
                value
            );
        }
    }

    #[test]
    fn test_strategy_parse_default() {
        assert_eq!(
            "default"
                .parse::<ConfigAdaptiveParallelismStrategy>()
                .unwrap(),
            ConfigAdaptiveParallelismStrategy::Default
        );
    }

    #[test]
    fn test_strategy_parse_default_case_insensitive() {
        assert_eq!(
            "DEFAULT"
                .parse::<ConfigAdaptiveParallelismStrategy>()
                .unwrap(),
            ConfigAdaptiveParallelismStrategy::Default
        );
    }

    #[test]
    fn test_strategy_parse_ratio() {
        let strategy: ConfigAdaptiveParallelismStrategy = "Ratio(0.5)".parse().unwrap();
        assert_eq!(strategy, ConfigAdaptiveParallelismStrategy::Ratio(0.5));
    }

    #[test]
    fn test_strategy_into_option() {
        let opt: Option<AdaptiveParallelismStrategy> =
            ConfigAdaptiveParallelismStrategy::Full.into();
        assert_eq!(opt, Some(AdaptiveParallelismStrategy::Full));
    }

    #[test]
    fn test_strategy_default_into_none() {
        let opt: Option<AdaptiveParallelismStrategy> =
            ConfigAdaptiveParallelismStrategy::Default.into();
        assert_eq!(opt, None);
    }

    #[test]
    fn test_strategy_bounded_round_trip() {
        let config = ConfigAdaptiveParallelismStrategy::Bounded(NonZeroU64::new(4).unwrap());
        let opt: Option<AdaptiveParallelismStrategy> = config.into();
        assert_eq!(
            opt,
            Some(AdaptiveParallelismStrategy::Bounded(
                NonZeroUsize::new(4).unwrap()
            ))
        );
        let back: ConfigAdaptiveParallelismStrategy = opt.unwrap().into();
        assert_eq!(back, config);
    }

    #[test]
    fn test_strategy_display_default() {
        assert_eq!(
            ConfigAdaptiveParallelismStrategy::Default.to_string(),
            "default"
        );
    }
}
172}