risingwave_frontend/stream_fragmenter/
parallelism.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::session_config::parallelism::{
16    ConfigAdaptiveParallelismStrategy, ConfigParallelism,
17};
18use risingwave_common::system_param::AdaptiveParallelismStrategy;
19use risingwave_pb::stream_plan::stream_fragment_graph::Parallelism;
20
21pub(crate) fn derive_parallelism(
22    specific_type_parallelism: Option<ConfigParallelism>,
23    global_streaming_parallelism: ConfigParallelism,
24) -> Option<Parallelism> {
25    match specific_type_parallelism {
26        // fallback to global streaming_parallelism
27        Some(ConfigParallelism::Default) | None => match global_streaming_parallelism {
28            // for streaming_parallelism, `Default` is `Adaptive`
29            ConfigParallelism::Default | ConfigParallelism::Adaptive => None,
30            ConfigParallelism::Fixed(n) => Some(Parallelism {
31                parallelism: n.get(),
32            }),
33        },
34
35        // specific type parallelism is set to `Adaptive` or `Fixed(0)`
36        Some(ConfigParallelism::Adaptive) => None,
37
38        // specific type parallelism is set to `Fixed(n)
39        Some(ConfigParallelism::Fixed(n)) => Some(Parallelism {
40            parallelism: n.get(),
41        }),
42    }
43}
44
45pub(crate) fn derive_parallelism_strategy(
46    specific_strategy: Option<ConfigAdaptiveParallelismStrategy>,
47    global_strategy: ConfigAdaptiveParallelismStrategy,
48) -> Option<AdaptiveParallelismStrategy> {
49    let to_strategy =
50        |cfg: ConfigAdaptiveParallelismStrategy| -> Option<AdaptiveParallelismStrategy> {
51            cfg.into()
52        };
53
54    match specific_strategy.unwrap_or(ConfigAdaptiveParallelismStrategy::Default) {
55        ConfigAdaptiveParallelismStrategy::Default => to_strategy(global_strategy),
56        other => to_strategy(other),
57    }
58}
59
#[cfg(test)]
mod tests {
    use std::num::NonZeroU64;

    use super::*;

    /// Builds a `ConfigParallelism::Fixed` from a non-zero test literal.
    fn fixed(n: u64) -> ConfigParallelism {
        ConfigParallelism::Fixed(NonZeroU64::new(n).expect("test literal must be non-zero"))
    }

    /// Runs `derive_parallelism` and projects out the numeric degree, if any.
    fn degree(specific: Option<ConfigParallelism>, global: ConfigParallelism) -> Option<u64> {
        derive_parallelism(specific, global).map(|p| p.parallelism)
    }

    #[test]
    fn test_none_global_fixed() {
        assert_eq!(degree(None, fixed(4)), Some(4));
    }

    #[test]
    fn test_none_global_default() {
        assert_eq!(degree(None, ConfigParallelism::Default), None);
    }

    #[test]
    fn test_none_global_adaptive() {
        assert_eq!(degree(None, ConfigParallelism::Adaptive), None);
    }

    #[test]
    fn test_default_global_fixed() {
        assert_eq!(degree(Some(ConfigParallelism::Default), fixed(2)), Some(2));
    }

    #[test]
    fn test_default_global_default() {
        assert_eq!(
            degree(Some(ConfigParallelism::Default), ConfigParallelism::Default),
            None
        );
    }

    #[test]
    fn test_default_global_adaptive() {
        assert_eq!(
            degree(Some(ConfigParallelism::Default), ConfigParallelism::Adaptive),
            None
        );
    }

    /// An explicit `Adaptive` type-specific setting ignores every global value.
    #[test]
    fn test_adaptive_any_global() {
        let specific = Some(ConfigParallelism::Adaptive);
        [
            ConfigParallelism::Default,
            ConfigParallelism::Adaptive,
            fixed(8),
        ]
        .into_iter()
        .for_each(|global| assert_eq!(degree(specific, global), None));
    }

    /// An explicit `Fixed` type-specific setting overrides every global value.
    #[test]
    fn test_fixed_override_global() {
        let specific = Some(fixed(6));
        [
            ConfigParallelism::Default,
            ConfigParallelism::Adaptive,
            fixed(3),
        ]
        .into_iter()
        .for_each(|global| assert_eq!(degree(specific, global), Some(6)));
    }

    #[test]
    fn test_parallelism_strategy_fallback() {
        let from_unset =
            derive_parallelism_strategy(None, ConfigAdaptiveParallelismStrategy::Auto);
        assert_eq!(from_unset, Some(AdaptiveParallelismStrategy::Auto));

        let from_default = derive_parallelism_strategy(
            Some(ConfigAdaptiveParallelismStrategy::Default),
            ConfigAdaptiveParallelismStrategy::Full,
        );
        assert_eq!(from_default, Some(AdaptiveParallelismStrategy::Full));
    }

    #[test]
    fn test_parallelism_strategy_override() {
        let overridden = derive_parallelism_strategy(
            Some(ConfigAdaptiveParallelismStrategy::Ratio(0.5)),
            ConfigAdaptiveParallelismStrategy::Full,
        );
        assert_eq!(overridden, Some(AdaptiveParallelismStrategy::Ratio(0.5)));
    }
}