risingwave_common/config/
merge.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::str::FromStr as _;
16
17use anyhow::Context as _;
18use serde::Serialize;
19use serde::de::DeserializeOwned;
20use toml::map::Entry;
21use toml::{Table, Value};
22
23use crate::config::StreamingConfig;
24
25def_anyhow_newtype! { pub ConfigMergeError }
26
27/// Extract the section at `partial_path` from `partial`, merge it into `base` to override entries.
28///
29/// Tables will be merged recursively, while other fields (including arrays) will be replaced by
30/// the `partial` config, if exists.
31///
32/// Returns an error if any of the input is invalid, or the merged config cannot be parsed.
33/// Returns `None` if there's nothing to override.
34pub fn merge_config<C: Serialize + DeserializeOwned + Clone>(
35    base: &C,
36    partial: &str,
37    partial_path: impl IntoIterator<Item = &str>,
38) -> Result<Option<C>, ConfigMergeError> {
39    let partial_table = {
40        let mut partial_table =
41            Table::from_str(partial).context("failed to parse partial config")?;
42        for k in partial_path {
43            if let Some(v) = partial_table.remove(k)
44                && let Value::Table(t) = v
45            {
46                partial_table = t;
47            } else {
48                // The section to override is not relevant.
49                return Ok(None);
50            }
51        }
52        partial_table
53    };
54
55    if partial_table.is_empty() {
56        // Nothing to override.
57        return Ok(None);
58    }
59
60    let mut base_table = Table::try_from(base).context("failed to serialize base config")?;
61
62    fn merge_table(base_table: &mut Table, partial_table: Table) {
63        for (k, v) in partial_table {
64            match base_table.entry(k) {
65                Entry::Vacant(entry) => {
66                    // Unrecognized entry might be tolerated.
67                    // So we simply keep it and defer the error (if any) until final deserialization.
68                    entry.insert(v);
69                }
70                Entry::Occupied(mut entry) => {
71                    let base_v = entry.get_mut();
72                    merge_value(base_v, v);
73                }
74            }
75        }
76    }
77
78    fn merge_value(base: &mut Value, partial: Value) {
79        if let Value::Table(base_table) = base
80            && let Value::Table(partial_table) = partial
81        {
82            merge_table(base_table, partial_table);
83        } else {
84            // We don't validate the type, but defer until final deserialization.
85            *base = partial;
86        }
87    }
88
89    merge_table(&mut base_table, partial_table);
90
91    let merged: C = base_table
92        .try_into()
93        .context("failed to deserialize merged config")?;
94
95    Ok(Some(merged))
96}
97
98/// Extract the `streaming` section from `partial`, merge it into `base` to override entries.
99///
100/// See [`merge_config`] for more details.
101pub fn merge_streaming_config_section(
102    base: &StreamingConfig,
103    partial: &str,
104) -> Result<Option<StreamingConfig>, ConfigMergeError> {
105    merge_config(base, partial, ["streaming"])
106}
107
/// Tests for merging a partial TOML config into a [`StreamingConfig`].
#[cfg(test)]
// Boolean config values are compared against literals on purpose, so that the assertions on
// booleans read the same way as the assertions on numeric entries.
#[allow(clippy::bool_assert_comparison)]
mod tests {
    use thiserror_ext::AsReport;

    use super::*;
    use crate::config::StreamingConfig;

    /// Entries given in the partial config override the base, at every nesting level exercised
    /// here (section, sub-table and inline table), while everything else stays untouched.
    #[test]
    fn test_merge_streaming_config() {
        let base = StreamingConfig::default();
        // Make sure the values we are about to set differ from the base, otherwise the
        // assertions below would pass trivially.
        assert_ne!(base.unsafe_enable_strict_consistency, false);
        assert_ne!(base.developer.chunk_size, 114514);
        assert_ne!(
            base.developer.compute_client_config.connect_timeout_secs,
            114514
        );

        let partial = r#"
            [streaming]
            unsafe_enable_strict_consistency = false

            [streaming.developer]
            chunk_size = 114514
            compute_client_config = { connect_timeout_secs = 114514 }
        "#;
        let merged = merge_streaming_config_section(&base, partial)
            .unwrap()
            .unwrap();

        // Demonstrate that the entries are merged.
        assert_eq!(merged.unsafe_enable_strict_consistency, false);
        assert_eq!(merged.developer.chunk_size, 114514);
        assert_eq!(
            merged.developer.compute_client_config.connect_timeout_secs,
            114514
        );

        // Demonstrate that the rest of the config is not affected: after reverting the
        // overridden entries, the merged config must be identical to the base.
        {
            let mut merged = merged;
            merged.unsafe_enable_strict_consistency = base.unsafe_enable_strict_consistency;
            merged.developer.chunk_size = base.developer.chunk_size;
            merged.developer.compute_client_config.connect_timeout_secs =
                base.developer.compute_client_config.connect_timeout_secs;

            pretty_assertions::assert_eq!(format!("{base:?}"), format!("{merged:?}"));
        }
    }

    /// A partial config without a `streaming` section yields `None`.
    #[test]
    fn test_not_relevant() {
        let base = StreamingConfig::default();
        let partial = r#"
            [batch.developer]
            chunk_size = 114514
        "#;
        let merged = merge_streaming_config_section(&base, partial).unwrap();
        assert!(
            merged.is_none(),
            "nothing to override, but got: {merged:#?}"
        );
    }

    /// An empty `streaming` section yields `None`.
    #[test]
    fn test_nothing_to_override() {
        let base = StreamingConfig::default();
        let partial = r#"
            [streaming]
        "#;
        let merged = merge_streaming_config_section(&base, partial).unwrap();
        assert!(
            merged.is_none(),
            "nothing to override, but got: {merged:#?}"
        );
    }

    /// Keys unknown to the config are not rejected by the merge; they end up in the
    /// `unrecognized` field of the level they were specified at.
    #[test]
    fn test_unrecognized_entry() {
        let base = StreamingConfig::default();
        let partial = r#"
            [streaming]
            no_such_entry = 114514

            [streaming.developer]
            no_such_dev_entry = 1919810
        "#;
        let merged = merge_streaming_config_section(&base, partial)
            .unwrap()
            .unwrap();

        let unrecognized = merged.unrecognized.into_inner();
        assert_eq!(unrecognized.len(), 1);
        assert_eq!(unrecognized["no_such_entry"], 114514);

        let dev_unrecognized = merged.developer.unrecognized.into_inner();
        assert_eq!(dev_unrecognized.len(), 1);
        assert_eq!(dev_unrecognized["no_such_dev_entry"], 1919810);
    }

    /// A value of the wrong type is reported as an error when the merged config is
    /// deserialized, with the path of the offending entry.
    #[test]
    fn test_invalid_type() {
        let base = StreamingConfig::default();
        let partial = r#"
            [streaming.developer]
            chunk_size = "omakase"
        "#;
        let error = merge_streaming_config_section(&base, partial).unwrap_err();
        expect_test::expect![[r#"
            failed to deserialize merged config: invalid type: string "omakase", expected usize
            in `developer.chunk_size`
        "#]]
        .assert_eq(&error.to_report_string());
    }

    // Even though we accept `stream_` prefixed config keys when deserializing the config, the
    // merge is performed on top of the raw `toml::Value`, where we don't have any information
    // about the aliasing. Therefore, using a prefixed config key in a config override results
    // in a duplicate field error.
    #[test]
    fn test_override_with_legacy_prefixed_config() {
        let base = StreamingConfig::default();
        let partial = r#"
            [streaming.developer]
            stream_chunk_size = 114514
        "#;
        let error = merge_streaming_config_section(&base, partial).unwrap_err();
        expect_test::expect![[r#"
            failed to deserialize merged config: duplicate field `chunk_size`
            in `developer`
        "#]]
        .assert_eq(&error.to_report_string());
    }
}