risingwave_common_proc_macro/
session_config.rs1use bae::FromAttributes;
16use proc_macro_error::{OptionExt, ResultExt, abort};
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote, quote_spanned};
19use syn::DeriveInput;
20
21#[derive(FromAttributes)]
22struct Parameter {
23 pub rename: Option<syn::LitStr>,
24 pub alias: Option<syn::Expr>,
25 pub default: syn::Expr,
26 pub flags: Option<syn::LitStr>,
27 pub check_hook: Option<syn::Expr>,
28}
29
30pub(crate) fn derive_config(input: DeriveInput) -> TokenStream {
31 let syn::Data::Struct(syn::DataStruct { fields, .. }) = input.data else {
32 abort!(input, "Only struct is supported");
33 };
34
35 let mut default_fields = vec![];
36 let mut struct_impl_set = vec![];
37 let mut struct_impl_get = vec![];
38 let mut struct_impl_reset = vec![];
39 let mut set_match_branches = vec![];
40 let mut get_match_branches = vec![];
41 let mut reset_match_branches = vec![];
42 let mut show_all_list = vec![];
43 let mut list_all_list = vec![];
44 let mut alias_to_entry_name_branches = vec![];
45 let mut entry_name_flags = vec![];
46
47 for field in fields {
48 let field_ident = field.ident.expect_or_abort("Field need to be named");
49 let ty = field.ty;
50
51 let mut doc_list = vec![];
52 for attr in &field.attrs {
53 if attr.path.is_ident("doc") {
54 let meta = attr.parse_meta().expect_or_abort("Failed to parse meta");
55 if let syn::Meta::NameValue(val) = meta {
56 if let syn::Lit::Str(desc) = val.lit {
57 doc_list.push(desc.value().trim().to_owned());
58 }
59 }
60 }
61 }
62
63 let description: TokenStream = format!("r#\"{}\"#", doc_list.join(" ")).parse().unwrap();
64
65 let attr =
66 Parameter::from_attributes(&field.attrs).expect_or_abort("Failed to parse attribute");
67 let Parameter {
68 rename,
69 alias,
70 default,
71 flags,
72 check_hook: check_hook_name,
73 } = attr;
74
75 let entry_name = if let Some(rename) = rename {
76 if !(rename.value().is_ascii() && rename.value().to_ascii_lowercase() == rename.value())
77 {
78 abort!(rename, "Expect `rename` to be an ascii lower case string");
79 }
80 quote! {#rename}
81 } else {
82 let ident = format_ident!("{}", field_ident.to_string().to_lowercase());
83 quote! {stringify!(#ident)}
84 };
85
86 if let Some(alias) = alias {
87 alias_to_entry_name_branches.push(quote! {
88 #alias => #entry_name.to_string(),
89 })
90 }
91
92 let flags = flags.map(|f| f.value()).unwrap_or_default();
93 let flags: Vec<_> = flags.split('|').map(|str| str.trim()).collect();
94
95 default_fields.push(quote_spanned! {
96 field_ident.span()=>
97 #field_ident: #default.into(),
98 });
99
100 let set_func_name = format_ident!("set_{}_str", field_ident);
101 let set_t_func_name = format_ident!("set_{}", field_ident);
102 let set_t_inner_func_name = format_ident!("set_{}_inner", field_ident);
103 let set_t_func_doc: TokenStream =
104 format!("r#\"Set parameter {} by a typed value.\"#", entry_name)
105 .parse()
106 .unwrap();
107 let set_func_doc: TokenStream = format!("r#\"Set parameter {} by a string.\"#", entry_name)
108 .parse()
109 .unwrap();
110
111 let gen_set_func_name = if flags.contains(&"SETTER") {
112 set_t_inner_func_name.clone()
113 } else {
114 set_t_func_name.clone()
115 };
116
117 let check_hook = if let Some(check_hook_name) = check_hook_name {
118 quote! {
119 #check_hook_name(&val).map_err(|e| {
120 SessionConfigError::InvalidValue {
121 entry: #entry_name,
122 value: val.to_string(),
123 source: anyhow::anyhow!(e),
124 }
125 })?;
126 }
127 } else {
128 quote! {}
129 };
130
131 let report_hook = if flags.contains(&"REPORT") {
132 quote! {
133 if self.#field_ident != val {
134 reporter.report_status(#entry_name, val.to_string());
135 }
136 }
137 } else {
138 quote! {}
139 };
140
141 let parse = if quote!(#ty).to_string() == "bool" {
143 quote!(risingwave_common::cast::str_to_bool)
144 } else {
145 quote!(<#ty as ::std::str::FromStr>::from_str)
146 };
147
148 struct_impl_set.push(quote! {
149 #[doc = #set_func_doc]
150 pub fn #set_func_name(
151 &mut self,
152 val: &str,
153 reporter: &mut impl ConfigReporter
154 ) -> SessionConfigResult<String> {
155 let val_t = #parse(val).map_err(|e| {
156 SessionConfigError::InvalidValue {
157 entry: #entry_name,
158 value: val.to_string(),
159 source: anyhow::anyhow!(e),
160 }
161 })?;
162
163 self.#set_t_func_name(val_t, reporter).map(|val| val.to_string())
164 }
165
166 #[doc = #set_t_func_doc]
167 pub fn #gen_set_func_name(
168 &mut self,
169 val: #ty,
170 reporter: &mut impl ConfigReporter
171 ) -> SessionConfigResult<#ty> {
172 #check_hook
173 #report_hook
174
175 self.#field_ident = val.clone();
176 Ok(val)
177 }
178
179 });
180
181 let reset_func_name = format_ident!("reset_{}", field_ident);
182 struct_impl_reset.push(quote! {
183
184 #[allow(clippy::useless_conversion)]
185 pub fn #reset_func_name(&mut self, reporter: &mut impl ConfigReporter) -> String {
186 let val = #default;
187 #report_hook
188 self.#field_ident = val.into();
189 self.#field_ident.to_string()
190 }
191 });
192
193 let get_func_name = format_ident!("{}_str", field_ident);
194 let get_t_func_name = format_ident!("{}", field_ident);
195 let get_func_doc: TokenStream =
196 format!("r#\"Get a value string of parameter {} \"#", entry_name)
197 .parse()
198 .unwrap();
199 let get_t_func_doc: TokenStream =
200 format!("r#\"Get a typed value of parameter {} \"#", entry_name)
201 .parse()
202 .unwrap();
203
204 struct_impl_get.push(quote! {
205 #[doc = #get_func_doc]
206 pub fn #get_func_name(&self) -> String {
207 self.#get_t_func_name().to_string()
208 }
209
210 #[doc = #get_t_func_doc]
211 pub fn #get_t_func_name(&self) -> #ty {
212 self.#field_ident.clone()
213 }
214
215 });
216
217 get_match_branches.push(quote! {
218 #entry_name => Ok(self.#get_func_name()),
219 });
220
221 set_match_branches.push(quote! {
222 #entry_name => self.#set_func_name(&value, reporter),
223 });
224
225 reset_match_branches.push(quote! {
226 #entry_name => Ok(self.#reset_func_name(reporter)),
227 });
228
229 let var_info = quote! {
230 VariableInfo {
231 name: #entry_name.to_string(),
232 setting: self.#field_ident.to_string(),
233 description : #description.to_string(),
234 },
235 };
236 list_all_list.push(var_info.clone());
237
238 let no_show_all = flags.contains(&"NO_SHOW_ALL");
239 let no_show_all_flag: TokenStream = no_show_all.to_string().parse().unwrap();
240 if !no_show_all {
241 show_all_list.push(var_info);
242 }
243
244 let no_alter_sys_flag: TokenStream =
245 flags.contains(&"NO_ALTER_SYS").to_string().parse().unwrap();
246
247 entry_name_flags.push(
248 quote! {
249 (#entry_name, ParamFlags {no_show_all: #no_show_all_flag, no_alter_sys: #no_alter_sys_flag})
250 }
251 );
252 }
253
254 let struct_ident = input.ident;
255 quote! {
256 use std::collections::HashMap;
257 use std::sync::LazyLock;
258 static PARAM_NAME_FLAGS: LazyLock<HashMap<&'static str, ParamFlags>> = LazyLock::new(|| HashMap::from([#(#entry_name_flags, )*]));
259
260 struct ParamFlags {
261 no_show_all: bool,
262 no_alter_sys: bool,
263 }
264
265 impl Default for #struct_ident {
266 #[allow(clippy::useless_conversion)]
267 fn default() -> Self {
268 Self {
269 #(#default_fields)*
270 }
271 }
272 }
273
274 impl #struct_ident {
275 fn new() -> Self {
276 Default::default()
277 }
278
279 pub fn alias_to_entry_name(key_name: &str) -> String {
280 let key_name = key_name.to_ascii_lowercase();
281 match key_name.as_str() {
282 #(#alias_to_entry_name_branches)*
283 _ => key_name,
284 }
285 }
286
287 #(#struct_impl_get)*
288
289 #(#struct_impl_set)*
290
291 #(#struct_impl_reset)*
292
293 pub fn set(&mut self, key_name: &str, value: String, reporter: &mut impl ConfigReporter) -> SessionConfigResult<String> {
295 let key_name = Self::alias_to_entry_name(key_name);
296 match key_name.as_ref() {
297 #(#set_match_branches)*
298 _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())),
299 }
300 }
301
302 pub fn get(&self, key_name: &str) -> SessionConfigResult<String> {
304 let key_name = Self::alias_to_entry_name(key_name);
305 match key_name.as_ref() {
306 #(#get_match_branches)*
307 _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())),
308 }
309 }
310
311 pub fn reset(&mut self, key_name: &str, reporter: &mut impl ConfigReporter) -> SessionConfigResult<String> {
313 let key_name = Self::alias_to_entry_name(key_name);
314 match key_name.as_ref() {
315 #(#reset_match_branches)*
316 _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())),
317 }
318 }
319
320 pub fn show_all(&self) -> Vec<VariableInfo> {
322 vec![
323 #(#show_all_list)*
324 ]
325 }
326
327 pub fn list_all(&self) -> Vec<VariableInfo> {
329 vec![
330 #(#list_all_list)*
331 ]
332 }
333
334 pub fn contains_param(key_name: &str) -> bool {
336 let key_name = Self::alias_to_entry_name(key_name);
337 PARAM_NAME_FLAGS.contains_key(key_name.as_str())
338 }
339
340 pub fn check_no_alter_sys(key_name: &str) -> SessionConfigResult<bool> {
342 let key_name = Self::alias_to_entry_name(key_name);
343 let flags = PARAM_NAME_FLAGS.get(key_name.as_str()).ok_or_else(|| SessionConfigError::UnrecognizedEntry(key_name.to_string()))?;
344 Ok(flags.no_alter_sys)
345 }
346 }
347 }
348}