risingwave_common_proc_macro/
session_config.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use bae::FromAttributes;
16use proc_macro_error::{OptionExt, ResultExt, abort};
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote, quote_spanned};
19use syn::DeriveInput;
20
21#[derive(FromAttributes)]
22struct Parameter {
23    pub rename: Option<syn::LitStr>,
24    pub alias: Option<syn::Expr>,
25    pub default: syn::Expr,
26    pub flags: Option<syn::LitStr>,
27    pub check_hook: Option<syn::Expr>,
28}
29
30pub(crate) fn derive_config(input: DeriveInput) -> TokenStream {
31    let syn::Data::Struct(syn::DataStruct { fields, .. }) = input.data else {
32        abort!(input, "Only struct is supported");
33    };
34
35    let mut default_fields = vec![];
36    let mut struct_impl_set = vec![];
37    let mut struct_impl_get = vec![];
38    let mut struct_impl_reset = vec![];
39    let mut set_match_branches = vec![];
40    let mut get_match_branches = vec![];
41    let mut reset_match_branches = vec![];
42    let mut show_all_list = vec![];
43    let mut list_all_list = vec![];
44    let mut alias_to_entry_name_branches = vec![];
45    let mut entry_name_flags = vec![];
46
47    for field in fields {
48        let field_ident = field.ident.expect_or_abort("Field need to be named");
49        let ty = field.ty;
50
51        let mut doc_list = vec![];
52        for attr in &field.attrs {
53            if attr.path.is_ident("doc") {
54                let meta = attr.parse_meta().expect_or_abort("Failed to parse meta");
55                if let syn::Meta::NameValue(val) = meta {
56                    if let syn::Lit::Str(desc) = val.lit {
57                        doc_list.push(desc.value().trim().to_owned());
58                    }
59                }
60            }
61        }
62
63        let description: TokenStream = format!("r#\"{}\"#", doc_list.join(" ")).parse().unwrap();
64
65        let attr =
66            Parameter::from_attributes(&field.attrs).expect_or_abort("Failed to parse attribute");
67        let Parameter {
68            rename,
69            alias,
70            default,
71            flags,
72            check_hook: check_hook_name,
73        } = attr;
74
75        let entry_name = if let Some(rename) = rename {
76            if !(rename.value().is_ascii() && rename.value().to_ascii_lowercase() == rename.value())
77            {
78                abort!(rename, "Expect `rename` to be an ascii lower case string");
79            }
80            quote! {#rename}
81        } else {
82            let ident = format_ident!("{}", field_ident.to_string().to_lowercase());
83            quote! {stringify!(#ident)}
84        };
85
86        if let Some(alias) = alias {
87            alias_to_entry_name_branches.push(quote! {
88                #alias => #entry_name.to_string(),
89            })
90        }
91
92        let flags = flags.map(|f| f.value()).unwrap_or_default();
93        let flags: Vec<_> = flags.split('|').map(|str| str.trim()).collect();
94
95        default_fields.push(quote_spanned! {
96            field_ident.span()=>
97            #field_ident: #default.into(),
98        });
99
100        let set_func_name = format_ident!("set_{}_str", field_ident);
101        let set_t_func_name = format_ident!("set_{}", field_ident);
102        let set_t_inner_func_name = format_ident!("set_{}_inner", field_ident);
103        let set_t_func_doc: TokenStream =
104            format!("r#\"Set parameter {} by a typed value.\"#", entry_name)
105                .parse()
106                .unwrap();
107        let set_func_doc: TokenStream = format!("r#\"Set parameter {} by a string.\"#", entry_name)
108            .parse()
109            .unwrap();
110
111        let gen_set_func_name = if flags.contains(&"SETTER") {
112            set_t_inner_func_name.clone()
113        } else {
114            set_t_func_name.clone()
115        };
116
117        let check_hook = if let Some(check_hook_name) = check_hook_name {
118            quote! {
119                #check_hook_name(&val).map_err(|e| {
120                    SessionConfigError::InvalidValue {
121                        entry: #entry_name,
122                        value: val.to_string(),
123                        source: anyhow::anyhow!(e),
124                    }
125                })?;
126            }
127        } else {
128            quote! {}
129        };
130
131        let report_hook = if flags.contains(&"REPORT") {
132            quote! {
133                if self.#field_ident != val {
134                    reporter.report_status(#entry_name, val.to_string());
135                }
136            }
137        } else {
138            quote! {}
139        };
140
141        // An easy way to check if the type is bool and use a different parse function.
142        let parse = if quote!(#ty).to_string() == "bool" {
143            quote!(risingwave_common::cast::str_to_bool)
144        } else {
145            quote!(<#ty as ::std::str::FromStr>::from_str)
146        };
147
148        struct_impl_set.push(quote! {
149            #[doc = #set_func_doc]
150            pub fn #set_func_name(
151                &mut self,
152                val: &str,
153                reporter: &mut impl ConfigReporter
154            ) -> SessionConfigResult<String> {
155                let val_t = #parse(val).map_err(|e| {
156                    SessionConfigError::InvalidValue {
157                        entry: #entry_name,
158                        value: val.to_string(),
159                        source: anyhow::anyhow!(e),
160                    }
161                })?;
162
163                self.#set_t_func_name(val_t, reporter).map(|val| val.to_string())
164            }
165
166            #[doc = #set_t_func_doc]
167            pub fn #gen_set_func_name(
168                &mut self,
169                val: #ty,
170                reporter: &mut impl ConfigReporter
171            ) -> SessionConfigResult<#ty> {
172                #check_hook
173                #report_hook
174
175                self.#field_ident = val.clone();
176                Ok(val)
177            }
178
179        });
180
181        let reset_func_name = format_ident!("reset_{}", field_ident);
182        struct_impl_reset.push(quote! {
183
184        #[allow(clippy::useless_conversion)]
185        pub fn #reset_func_name(&mut self, reporter: &mut impl ConfigReporter) -> String {
186                let val = #default;
187                #report_hook
188                self.#field_ident = val.into();
189                self.#field_ident.to_string()
190            }
191        });
192
193        let get_func_name = format_ident!("{}_str", field_ident);
194        let get_t_func_name = format_ident!("{}", field_ident);
195        let get_func_doc: TokenStream =
196            format!("r#\"Get a value string of parameter {} \"#", entry_name)
197                .parse()
198                .unwrap();
199        let get_t_func_doc: TokenStream =
200            format!("r#\"Get a typed value of parameter {} \"#", entry_name)
201                .parse()
202                .unwrap();
203
204        struct_impl_get.push(quote! {
205            #[doc = #get_func_doc]
206            pub fn #get_func_name(&self) -> String {
207                self.#get_t_func_name().to_string()
208            }
209
210            #[doc = #get_t_func_doc]
211            pub fn #get_t_func_name(&self) -> #ty {
212                self.#field_ident.clone()
213            }
214
215        });
216
217        get_match_branches.push(quote! {
218            #entry_name => Ok(self.#get_func_name()),
219        });
220
221        set_match_branches.push(quote! {
222            #entry_name => self.#set_func_name(&value, reporter),
223        });
224
225        reset_match_branches.push(quote! {
226            #entry_name => Ok(self.#reset_func_name(reporter)),
227        });
228
229        let var_info = quote! {
230            VariableInfo {
231                name: #entry_name.to_string(),
232                setting: self.#field_ident.to_string(),
233                description : #description.to_string(),
234            },
235        };
236        list_all_list.push(var_info.clone());
237
238        let no_show_all = flags.contains(&"NO_SHOW_ALL");
239        let no_show_all_flag: TokenStream = no_show_all.to_string().parse().unwrap();
240        if !no_show_all {
241            show_all_list.push(var_info);
242        }
243
244        let no_alter_sys_flag: TokenStream =
245            flags.contains(&"NO_ALTER_SYS").to_string().parse().unwrap();
246
247        entry_name_flags.push(
248            quote! {
249                (#entry_name, ParamFlags {no_show_all: #no_show_all_flag, no_alter_sys: #no_alter_sys_flag})
250            }
251        );
252    }
253
254    let struct_ident = input.ident;
255    quote! {
256        use std::collections::HashMap;
257        use std::sync::LazyLock;
258        static PARAM_NAME_FLAGS: LazyLock<HashMap<&'static str, ParamFlags>> = LazyLock::new(|| HashMap::from([#(#entry_name_flags, )*]));
259
260        struct ParamFlags {
261            no_show_all: bool,
262            no_alter_sys: bool,
263        }
264
265        impl Default for #struct_ident {
266            #[allow(clippy::useless_conversion)]
267            fn default() -> Self {
268                Self {
269                    #(#default_fields)*
270                }
271            }
272        }
273
274        impl #struct_ident {
275            fn new() -> Self {
276                Default::default()
277            }
278
279            pub fn alias_to_entry_name(key_name: &str) -> String {
280                let key_name = key_name.to_ascii_lowercase();
281                match key_name.as_str() {
282                    #(#alias_to_entry_name_branches)*
283                    _ => key_name,
284                }
285            }
286
287            #(#struct_impl_get)*
288
289            #(#struct_impl_set)*
290
291            #(#struct_impl_reset)*
292
293            /// Set a parameter given it's name and value string.
294            pub fn set(&mut self, key_name: &str, value: String, reporter: &mut impl ConfigReporter) -> SessionConfigResult<String> {
295                let key_name = Self::alias_to_entry_name(key_name);
296                match key_name.as_ref() {
297                    #(#set_match_branches)*
298                    _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())),
299                }
300            }
301
302            /// Get a parameter by it's name.
303            pub fn get(&self, key_name: &str) -> SessionConfigResult<String> {
304                let key_name = Self::alias_to_entry_name(key_name);
305                match key_name.as_ref() {
306                    #(#get_match_branches)*
307                    _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())),
308                }
309            }
310
311            /// Reset a parameter by it's name.
312            pub fn reset(&mut self, key_name: &str, reporter: &mut impl ConfigReporter) -> SessionConfigResult<String> {
313                let key_name = Self::alias_to_entry_name(key_name);
314                match key_name.as_ref() {
315                    #(#reset_match_branches)*
316                    _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())),
317                }
318            }
319
320            /// Show all parameters except those specified `NO_SHOW_ALL`.
321            pub fn show_all(&self) -> Vec<VariableInfo> {
322                vec![
323                    #(#show_all_list)*
324                ]
325            }
326
327            /// List all parameters
328            pub fn list_all(&self) -> Vec<VariableInfo> {
329                vec![
330                    #(#list_all_list)*
331                ]
332            }
333
334            /// Check if `SessionConfig` has a parameter.
335            pub fn contains_param(key_name: &str) -> bool {
336                let key_name = Self::alias_to_entry_name(key_name);
337                PARAM_NAME_FLAGS.contains_key(key_name.as_str())
338            }
339
340            /// Check if `SessionConfig` has a parameter.
341            pub fn check_no_alter_sys(key_name: &str) -> SessionConfigResult<bool> {
342                let key_name = Self::alias_to_entry_name(key_name);
343                let flags = PARAM_NAME_FLAGS.get(key_name.as_str()).ok_or_else(|| SessionConfigError::UnrecognizedEntry(key_name.to_string()))?;
344                Ok(flags.no_alter_sys)
345            }
346        }
347    }
348}