risingwave_common_proc_macro/
session_config.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use bae::FromAttributes;
16use proc_macro_error::{OptionExt, ResultExt, abort};
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote, quote_spanned};
19use syn::DeriveInput;
20
21#[derive(FromAttributes)]
22struct Parameter {
23    pub rename: Option<syn::LitStr>,
24    pub alias: Option<syn::Expr>,
25    pub default: syn::Expr,
26    pub flags: Option<syn::LitStr>,
27    pub check_hook: Option<syn::Expr>,
28}
29
30pub(crate) fn derive_config(input: DeriveInput) -> TokenStream {
31    let syn::Data::Struct(syn::DataStruct { fields, .. }) = input.data else {
32        abort!(input, "Only struct is supported");
33    };
34
35    let mut default_fields = vec![];
36    let mut struct_impl_set = vec![];
37    let mut struct_impl_get = vec![];
38    let mut struct_impl_reset = vec![];
39    let mut set_match_branches = vec![];
40    let mut get_match_branches = vec![];
41    let mut reset_match_branches = vec![];
42    let mut show_all_list = vec![];
43    let mut list_all_list = vec![];
44    let mut alias_to_entry_name_branches = vec![];
45    let mut entry_name_flags = vec![];
46    // Fields and entries for the generated `SessionInitConfig`, i.e. parameters flagged with
47    // `SESSION_INIT` that can be seeded from `[session_init]` in `risingwave.toml`.
48    let mut session_init_fields = vec![];
49    let mut session_init_entries = vec![];
50
51    for field in fields {
52        let field_ident = field.ident.expect_or_abort("Field need to be named");
53        let ty = field.ty;
54
55        let mut doc_list = vec![];
56        for attr in &field.attrs {
57            if attr.path.is_ident("doc") {
58                let meta = attr.parse_meta().expect_or_abort("Failed to parse meta");
59                if let syn::Meta::NameValue(val) = meta
60                    && let syn::Lit::Str(desc) = val.lit
61                {
62                    doc_list.push(desc.value().trim().to_owned());
63                }
64            }
65        }
66
67        let description: TokenStream = format!("r#\"{}\"#", doc_list.join(" ")).parse().unwrap();
68
69        let attr =
70            Parameter::from_attributes(&field.attrs).expect_or_abort("Failed to parse attribute");
71        let Parameter {
72            rename,
73            alias,
74            default,
75            flags,
76            check_hook: check_hook_name,
77        } = attr;
78
79        let entry_name = if let Some(rename) = rename {
80            if !(rename.value().is_ascii() && rename.value().to_ascii_lowercase() == rename.value())
81            {
82                abort!(rename, "Expect `rename` to be an ascii lower case string");
83            }
84            quote! {#rename}
85        } else {
86            let ident = format_ident!("{}", field_ident.to_string().to_lowercase());
87            quote! {stringify!(#ident)}
88        };
89
90        if let Some(alias) = alias {
91            alias_to_entry_name_branches.push(quote! {
92                #alias => #entry_name.to_string(),
93            })
94        }
95
96        let flags = flags.map(|f| f.value()).unwrap_or_default();
97        let flags: Vec<_> = flags.split('|').map(|str| str.trim()).collect();
98
99        default_fields.push(quote_spanned! {
100            field_ident.span()=>
101            #field_ident: #default.into(),
102        });
103
104        let set_func_name = format_ident!("set_{}_str", field_ident);
105        let set_t_func_name = format_ident!("set_{}", field_ident);
106        let set_t_inner_func_name = format_ident!("set_{}_inner", field_ident);
107        let set_t_func_doc: TokenStream =
108            format!("r#\"Set parameter {} by a typed value.\"#", entry_name)
109                .parse()
110                .unwrap();
111        let set_func_doc: TokenStream = format!("r#\"Set parameter {} by a string.\"#", entry_name)
112            .parse()
113            .unwrap();
114
115        let gen_set_func_name = if flags.contains(&"SETTER") {
116            set_t_inner_func_name.clone()
117        } else {
118            set_t_func_name.clone()
119        };
120
121        let check_hook = if let Some(check_hook_name) = check_hook_name {
122            quote! {
123                #check_hook_name(&val).map_err(|e| {
124                    SessionConfigError::InvalidValue {
125                        entry: #entry_name,
126                        value: val.to_string(),
127                        source: anyhow::anyhow!(e),
128                    }
129                })?;
130            }
131        } else {
132            quote! {}
133        };
134
135        let report_hook = if flags.contains(&"REPORT") {
136            quote! {
137                if self.#field_ident != val {
138                    reporter.report_status(#entry_name, val.to_string());
139                }
140            }
141        } else {
142            quote! {}
143        };
144
145        // An easy way to check if the type is bool and use a different parse function.
146        let parse = if quote!(#ty).to_string() == "bool" {
147            quote!(risingwave_common::cast::str_to_bool)
148        } else {
149            quote!(<#ty as ::std::str::FromStr>::from_str)
150        };
151
152        struct_impl_set.push(quote! {
153            #[doc = #set_func_doc]
154            pub fn #set_func_name(
155                &mut self,
156                val: &str,
157                reporter: &mut impl ConfigReporter
158            ) -> SessionConfigResult<String> {
159                let val_t = #parse(val).map_err(|e| {
160                    SessionConfigError::InvalidValue {
161                        entry: #entry_name,
162                        value: val.to_string(),
163                        source: anyhow::anyhow!(e),
164                    }
165                })?;
166
167                self.#set_t_func_name(val_t, reporter).map(|val| val.to_string())
168            }
169
170            #[doc = #set_t_func_doc]
171            pub fn #gen_set_func_name(
172                &mut self,
173                val: #ty,
174                reporter: &mut impl ConfigReporter
175            ) -> SessionConfigResult<#ty> {
176                #check_hook
177                #report_hook
178
179                self.#field_ident = val.clone();
180                Ok(val)
181            }
182
183        });
184
185        let reset_func_name = format_ident!("reset_{}", field_ident);
186        struct_impl_reset.push(quote! {
187
188        #[allow(clippy::useless_conversion)]
189        pub fn #reset_func_name(&mut self, reporter: &mut impl ConfigReporter) -> String {
190                let val = #default;
191                #report_hook
192                self.#field_ident = val.into();
193                self.#field_ident.to_string()
194            }
195        });
196
197        let get_func_name = format_ident!("{}_str", field_ident);
198        let get_t_func_name = format_ident!("{}", field_ident);
199        let get_func_doc: TokenStream =
200            format!("r#\"Get a value string of parameter {} \"#", entry_name)
201                .parse()
202                .unwrap();
203        let get_t_func_doc: TokenStream =
204            format!("r#\"Get a typed value of parameter {} \"#", entry_name)
205                .parse()
206                .unwrap();
207
208        struct_impl_get.push(quote! {
209            #[doc = #get_func_doc]
210            pub fn #get_func_name(&self) -> String {
211                self.#get_t_func_name().to_string()
212            }
213
214            #[doc = #get_t_func_doc]
215            pub fn #get_t_func_name(&self) -> #ty {
216                self.#field_ident.clone()
217            }
218
219        });
220
221        get_match_branches.push(quote! {
222            #entry_name => Ok(self.#get_func_name()),
223        });
224
225        set_match_branches.push(quote! {
226            #entry_name => self.#set_func_name(&value, reporter),
227        });
228
229        reset_match_branches.push(quote! {
230            #entry_name => Ok(self.#reset_func_name(reporter)),
231        });
232
233        let var_info = quote! {
234            VariableInfo {
235                name: #entry_name.to_string(),
236                setting: self.#field_ident.to_string(),
237                description : #description.to_string(),
238            },
239        };
240        list_all_list.push(var_info.clone());
241
242        let no_show_all = flags.contains(&"NO_SHOW_ALL");
243        let no_show_all_flag: TokenStream = no_show_all.to_string().parse().unwrap();
244        if !no_show_all {
245            show_all_list.push(var_info);
246        }
247
248        let no_alter_sys_flag: TokenStream =
249            flags.contains(&"NO_ALTER_SYS").to_string().parse().unwrap();
250
251        entry_name_flags.push(
252            quote! {
253                (#entry_name, ParamFlags {no_show_all: #no_show_all_flag, no_alter_sys: #no_alter_sys_flag})
254            }
255        );
256
257        // Parameters flagged with `SESSION_INIT` become a field in the generated
258        // `SessionInitConfig`. The value is kept as a raw `Option<String>` so that a parameter
259        // omitted from `risingwave.toml` (`None`) can be distinguished from one explicitly set to
260        // its logical default such as `"default"` (`Some("default")`).
261        if flags.contains(&"SESSION_INIT") {
262            let doc_string = doc_list.join(" ");
263            session_init_fields.push(quote! {
264                #[doc = #doc_string]
265                #[serde(default, with = "crate::config::none_as_empty_string")]
266                pub #field_ident: Option<String>,
267            });
268            session_init_entries.push(quote! {
269                (#entry_name, &self.#field_ident),
270            });
271        }
272    }
273
274    let struct_ident = input.ident;
275    quote! {
276        /// The section `[session_init]` in `risingwave.toml`, generated from the `SESSION_INIT`-flagged
277        /// fields of [`SessionConfig`].
278        ///
279        /// It seeds the corresponding persisted session parameters into the meta store during
280        /// **cluster bootstrap only**. The precedence is:
281        ///
282        /// 1. Persisted value in the meta store (`session_parameter`)
283        /// 2. Explicit value in `[session_init]`
284        /// 3. Built-in `SessionConfig::default()`
285        ///
286        /// Editing `[session_init]` after a cluster has been bootstrapped does not change the
287        /// effective value of an already-persisted parameter. See the RFC for details.
288        #[derive(Clone, Debug, Default, Serialize, Deserialize, ConfigDoc, PartialEq)]
289        #[serde(deny_unknown_fields)]
290        pub struct SessionInitConfig {
291            #(#session_init_fields)*
292        }
293
294        impl SessionInitConfig {
295            /// Returns the explicitly-configured `(session parameter entry name, value)` pairs.
296            /// Parameters omitted from `[session_init]` are not included.
297            pub fn entries(&self) -> Vec<(&'static str, &str)> {
298                [
299                    #(#session_init_entries)*
300                ]
301                .into_iter()
302                .filter_map(|(name, value): (&'static str, &Option<String>)| {
303                    value.as_deref().map(|value| (name, value))
304                })
305                .collect()
306            }
307        }
308
309        use std::collections::HashMap;
310        use std::sync::LazyLock;
311        static PARAM_NAME_FLAGS: LazyLock<HashMap<&'static str, ParamFlags>> = LazyLock::new(|| HashMap::from([#(#entry_name_flags, )*]));
312
313        struct ParamFlags {
314            no_show_all: bool,
315            no_alter_sys: bool,
316        }
317
318        impl Default for #struct_ident {
319            #[allow(clippy::useless_conversion)]
320            fn default() -> Self {
321                Self {
322                    #(#default_fields)*
323                }
324            }
325        }
326
327        impl #struct_ident {
328            fn new() -> Self {
329                Default::default()
330            }
331
332            pub fn alias_to_entry_name(key_name: &str) -> String {
333                let key_name = key_name.to_ascii_lowercase();
334                match key_name.as_str() {
335                    #(#alias_to_entry_name_branches)*
336                    _ => key_name,
337                }
338            }
339
340            #(#struct_impl_get)*
341
342            #(#struct_impl_set)*
343
344            #(#struct_impl_reset)*
345
346            /// Set a parameter given it's name and value string.
347            pub fn set(&mut self, key_name: &str, value: String, reporter: &mut impl ConfigReporter) -> SessionConfigResult<String> {
348                let key_name = Self::alias_to_entry_name(key_name);
349                match key_name.as_ref() {
350                    #(#set_match_branches)*
351                    _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())),
352                }
353            }
354
355            /// Get a parameter by it's name.
356            pub fn get(&self, key_name: &str) -> SessionConfigResult<String> {
357                let key_name = Self::alias_to_entry_name(key_name);
358                match key_name.as_ref() {
359                    #(#get_match_branches)*
360                    _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())),
361                }
362            }
363
364            /// Reset a parameter by it's name.
365            pub fn reset(&mut self, key_name: &str, reporter: &mut impl ConfigReporter) -> SessionConfigResult<String> {
366                let key_name = Self::alias_to_entry_name(key_name);
367                match key_name.as_ref() {
368                    #(#reset_match_branches)*
369                    _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())),
370                }
371            }
372
373            /// Show all parameters except those specified `NO_SHOW_ALL`.
374            pub fn show_all(&self) -> Vec<VariableInfo> {
375                vec![
376                    #(#show_all_list)*
377                ]
378            }
379
380            /// List all parameters
381            pub fn list_all(&self) -> Vec<VariableInfo> {
382                vec![
383                    #(#list_all_list)*
384                ]
385            }
386
387            /// Check if `SessionConfig` has a parameter.
388            pub fn contains_param(key_name: &str) -> bool {
389                let key_name = Self::alias_to_entry_name(key_name);
390                PARAM_NAME_FLAGS.contains_key(key_name.as_str())
391            }
392
393            /// Check if `SessionConfig` has a parameter.
394            pub fn check_no_alter_sys(key_name: &str) -> SessionConfigResult<bool> {
395                let key_name = Self::alias_to_entry_name(key_name);
396                let flags = PARAM_NAME_FLAGS.get(key_name.as_str()).ok_or_else(|| SessionConfigError::UnrecognizedEntry(key_name.to_string()))?;
397                Ok(flags.no_alter_sys)
398            }
399        }
400    }
401}