risingwave_frontend/stream_fragmenter/parallelism.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::session_config::parallelism::ConfigParallelism;
16use risingwave_pb::stream_plan::stream_fragment_graph::Parallelism;
17
18pub(crate) fn derive_parallelism(
19    specific_type_parallelism: Option<ConfigParallelism>,
20    global_streaming_parallelism: ConfigParallelism,
21) -> Option<Parallelism> {
22    match specific_type_parallelism {
23        // fallback to global streaming_parallelism
24        Some(ConfigParallelism::Default) | None => match global_streaming_parallelism {
25            // for streaming_parallelism, `Default` is `Adaptive`
26            ConfigParallelism::Default | ConfigParallelism::Adaptive => None,
27            ConfigParallelism::Fixed(n) => Some(Parallelism {
28                parallelism: n.get(),
29            }),
30        },
31
32        // specific type parallelism is set to `Adaptive` or `Fixed(0)`
33        Some(ConfigParallelism::Adaptive) => None,
34
35        // specific type parallelism is set to `Fixed(n)
36        Some(ConfigParallelism::Fixed(n)) => Some(Parallelism {
37            parallelism: n.get(),
38        }),
39    }
40}
41
#[cfg(test)]
mod tests {
    use std::num::NonZeroU64;

    use super::*;

    /// Shorthand for `ConfigParallelism::Fixed(n)`; panics if `n` is zero.
    fn fixed(n: u64) -> ConfigParallelism {
        ConfigParallelism::Fixed(NonZeroU64::new(n).unwrap())
    }

    /// Runs `derive_parallelism` and projects the result to its raw parallelism value,
    /// so expectations can be written as plain `Option<u64>`.
    fn resolved(specific: Option<ConfigParallelism>, global: ConfigParallelism) -> Option<u64> {
        derive_parallelism(specific, global).map(|p| p.parallelism)
    }

    // --- type-specific setting absent: the global setting decides ---

    #[test]
    fn test_none_global_fixed() {
        assert_eq!(resolved(None, fixed(4)), Some(4));
    }

    #[test]
    fn test_none_global_default() {
        assert_eq!(resolved(None, ConfigParallelism::Default), None);
    }

    #[test]
    fn test_none_global_adaptive() {
        assert_eq!(resolved(None, ConfigParallelism::Adaptive), None);
    }

    // --- type-specific setting left at `Default`: behaves like absent ---

    #[test]
    fn test_default_global_fixed() {
        assert_eq!(resolved(Some(ConfigParallelism::Default), fixed(2)), Some(2));
    }

    #[test]
    fn test_default_global_default() {
        assert_eq!(
            resolved(Some(ConfigParallelism::Default), ConfigParallelism::Default),
            None
        );
    }

    #[test]
    fn test_default_global_adaptive() {
        assert_eq!(
            resolved(Some(ConfigParallelism::Default), ConfigParallelism::Adaptive),
            None
        );
    }

    // --- explicit type-specific setting: overrides whatever the global setting is ---

    #[test]
    fn test_adaptive_any_global() {
        let specific = Some(ConfigParallelism::Adaptive);
        for global in [ConfigParallelism::Default, ConfigParallelism::Adaptive, fixed(8)] {
            assert_eq!(resolved(specific, global), None);
        }
    }

    #[test]
    fn test_fixed_override_global() {
        let specific = Some(fixed(6));
        for global in [ConfigParallelism::Default, ConfigParallelism::Adaptive, fixed(3)] {
            assert_eq!(resolved(specific, global), Some(6));
        }
    }
}