risingwave_common_proc_macro/
session_config.rs1use bae::FromAttributes;
16use proc_macro_error::{OptionExt, ResultExt, abort};
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote, quote_spanned};
19use syn::DeriveInput;
20
21#[derive(FromAttributes)]
22struct Parameter {
23 pub rename: Option<syn::LitStr>,
24 pub alias: Option<syn::Expr>,
25 pub default: syn::Expr,
26 pub flags: Option<syn::LitStr>,
27 pub check_hook: Option<syn::Expr>,
28}
29
30pub(crate) fn derive_config(input: DeriveInput) -> TokenStream {
31 let syn::Data::Struct(syn::DataStruct { fields, .. }) = input.data else {
32 abort!(input, "Only struct is supported");
33 };
34
35 let mut default_fields = vec![];
36 let mut struct_impl_set = vec![];
37 let mut struct_impl_get = vec![];
38 let mut struct_impl_reset = vec![];
39 let mut set_match_branches = vec![];
40 let mut get_match_branches = vec![];
41 let mut reset_match_branches = vec![];
42 let mut show_all_list = vec![];
43 let mut list_all_list = vec![];
44 let mut alias_to_entry_name_branches = vec![];
45 let mut entry_name_flags = vec![];
46 let mut session_init_fields = vec![];
49 let mut session_init_entries = vec![];
50
51 for field in fields {
52 let field_ident = field.ident.expect_or_abort("Field need to be named");
53 let ty = field.ty;
54
55 let mut doc_list = vec![];
56 for attr in &field.attrs {
57 if attr.path.is_ident("doc") {
58 let meta = attr.parse_meta().expect_or_abort("Failed to parse meta");
59 if let syn::Meta::NameValue(val) = meta
60 && let syn::Lit::Str(desc) = val.lit
61 {
62 doc_list.push(desc.value().trim().to_owned());
63 }
64 }
65 }
66
67 let description: TokenStream = format!("r#\"{}\"#", doc_list.join(" ")).parse().unwrap();
68
69 let attr =
70 Parameter::from_attributes(&field.attrs).expect_or_abort("Failed to parse attribute");
71 let Parameter {
72 rename,
73 alias,
74 default,
75 flags,
76 check_hook: check_hook_name,
77 } = attr;
78
79 let entry_name = if let Some(rename) = rename {
80 if !(rename.value().is_ascii() && rename.value().to_ascii_lowercase() == rename.value())
81 {
82 abort!(rename, "Expect `rename` to be an ascii lower case string");
83 }
84 quote! {#rename}
85 } else {
86 let ident = format_ident!("{}", field_ident.to_string().to_lowercase());
87 quote! {stringify!(#ident)}
88 };
89
90 if let Some(alias) = alias {
91 alias_to_entry_name_branches.push(quote! {
92 #alias => #entry_name.to_string(),
93 })
94 }
95
96 let flags = flags.map(|f| f.value()).unwrap_or_default();
97 let flags: Vec<_> = flags.split('|').map(|str| str.trim()).collect();
98
99 default_fields.push(quote_spanned! {
100 field_ident.span()=>
101 #field_ident: #default.into(),
102 });
103
104 let set_func_name = format_ident!("set_{}_str", field_ident);
105 let set_t_func_name = format_ident!("set_{}", field_ident);
106 let set_t_inner_func_name = format_ident!("set_{}_inner", field_ident);
107 let set_t_func_doc: TokenStream =
108 format!("r#\"Set parameter {} by a typed value.\"#", entry_name)
109 .parse()
110 .unwrap();
111 let set_func_doc: TokenStream = format!("r#\"Set parameter {} by a string.\"#", entry_name)
112 .parse()
113 .unwrap();
114
115 let gen_set_func_name = if flags.contains(&"SETTER") {
116 set_t_inner_func_name.clone()
117 } else {
118 set_t_func_name.clone()
119 };
120
121 let check_hook = if let Some(check_hook_name) = check_hook_name {
122 quote! {
123 #check_hook_name(&val).map_err(|e| {
124 SessionConfigError::InvalidValue {
125 entry: #entry_name,
126 value: val.to_string(),
127 source: anyhow::anyhow!(e),
128 }
129 })?;
130 }
131 } else {
132 quote! {}
133 };
134
135 let report_hook = if flags.contains(&"REPORT") {
136 quote! {
137 if self.#field_ident != val {
138 reporter.report_status(#entry_name, val.to_string());
139 }
140 }
141 } else {
142 quote! {}
143 };
144
145 let parse = if quote!(#ty).to_string() == "bool" {
147 quote!(risingwave_common::cast::str_to_bool)
148 } else {
149 quote!(<#ty as ::std::str::FromStr>::from_str)
150 };
151
152 struct_impl_set.push(quote! {
153 #[doc = #set_func_doc]
154 pub fn #set_func_name(
155 &mut self,
156 val: &str,
157 reporter: &mut impl ConfigReporter
158 ) -> SessionConfigResult<String> {
159 let val_t = #parse(val).map_err(|e| {
160 SessionConfigError::InvalidValue {
161 entry: #entry_name,
162 value: val.to_string(),
163 source: anyhow::anyhow!(e),
164 }
165 })?;
166
167 self.#set_t_func_name(val_t, reporter).map(|val| val.to_string())
168 }
169
170 #[doc = #set_t_func_doc]
171 pub fn #gen_set_func_name(
172 &mut self,
173 val: #ty,
174 reporter: &mut impl ConfigReporter
175 ) -> SessionConfigResult<#ty> {
176 #check_hook
177 #report_hook
178
179 self.#field_ident = val.clone();
180 Ok(val)
181 }
182
183 });
184
185 let reset_func_name = format_ident!("reset_{}", field_ident);
186 struct_impl_reset.push(quote! {
187
188 #[allow(clippy::useless_conversion)]
189 pub fn #reset_func_name(&mut self, reporter: &mut impl ConfigReporter) -> String {
190 let val = #default;
191 #report_hook
192 self.#field_ident = val.into();
193 self.#field_ident.to_string()
194 }
195 });
196
197 let get_func_name = format_ident!("{}_str", field_ident);
198 let get_t_func_name = format_ident!("{}", field_ident);
199 let get_func_doc: TokenStream =
200 format!("r#\"Get a value string of parameter {} \"#", entry_name)
201 .parse()
202 .unwrap();
203 let get_t_func_doc: TokenStream =
204 format!("r#\"Get a typed value of parameter {} \"#", entry_name)
205 .parse()
206 .unwrap();
207
208 struct_impl_get.push(quote! {
209 #[doc = #get_func_doc]
210 pub fn #get_func_name(&self) -> String {
211 self.#get_t_func_name().to_string()
212 }
213
214 #[doc = #get_t_func_doc]
215 pub fn #get_t_func_name(&self) -> #ty {
216 self.#field_ident.clone()
217 }
218
219 });
220
221 get_match_branches.push(quote! {
222 #entry_name => Ok(self.#get_func_name()),
223 });
224
225 set_match_branches.push(quote! {
226 #entry_name => self.#set_func_name(&value, reporter),
227 });
228
229 reset_match_branches.push(quote! {
230 #entry_name => Ok(self.#reset_func_name(reporter)),
231 });
232
233 let var_info = quote! {
234 VariableInfo {
235 name: #entry_name.to_string(),
236 setting: self.#field_ident.to_string(),
237 description : #description.to_string(),
238 },
239 };
240 list_all_list.push(var_info.clone());
241
242 let no_show_all = flags.contains(&"NO_SHOW_ALL");
243 let no_show_all_flag: TokenStream = no_show_all.to_string().parse().unwrap();
244 if !no_show_all {
245 show_all_list.push(var_info);
246 }
247
248 let no_alter_sys_flag: TokenStream =
249 flags.contains(&"NO_ALTER_SYS").to_string().parse().unwrap();
250
251 entry_name_flags.push(
252 quote! {
253 (#entry_name, ParamFlags {no_show_all: #no_show_all_flag, no_alter_sys: #no_alter_sys_flag})
254 }
255 );
256
257 if flags.contains(&"SESSION_INIT") {
262 let doc_string = doc_list.join(" ");
263 session_init_fields.push(quote! {
264 #[doc = #doc_string]
265 #[serde(default, with = "crate::config::none_as_empty_string")]
266 pub #field_ident: Option<String>,
267 });
268 session_init_entries.push(quote! {
269 (#entry_name, &self.#field_ident),
270 });
271 }
272 }
273
274 let struct_ident = input.ident;
275 quote! {
276 #[derive(Clone, Debug, Default, Serialize, Deserialize, ConfigDoc, PartialEq)]
289 #[serde(deny_unknown_fields)]
290 pub struct SessionInitConfig {
291 #(#session_init_fields)*
292 }
293
294 impl SessionInitConfig {
295 pub fn entries(&self) -> Vec<(&'static str, &str)> {
298 [
299 #(#session_init_entries)*
300 ]
301 .into_iter()
302 .filter_map(|(name, value): (&'static str, &Option<String>)| {
303 value.as_deref().map(|value| (name, value))
304 })
305 .collect()
306 }
307 }
308
309 use std::collections::HashMap;
310 use std::sync::LazyLock;
311 static PARAM_NAME_FLAGS: LazyLock<HashMap<&'static str, ParamFlags>> = LazyLock::new(|| HashMap::from([#(#entry_name_flags, )*]));
312
313 struct ParamFlags {
314 no_show_all: bool,
315 no_alter_sys: bool,
316 }
317
318 impl Default for #struct_ident {
319 #[allow(clippy::useless_conversion)]
320 fn default() -> Self {
321 Self {
322 #(#default_fields)*
323 }
324 }
325 }
326
327 impl #struct_ident {
328 fn new() -> Self {
329 Default::default()
330 }
331
332 pub fn alias_to_entry_name(key_name: &str) -> String {
333 let key_name = key_name.to_ascii_lowercase();
334 match key_name.as_str() {
335 #(#alias_to_entry_name_branches)*
336 _ => key_name,
337 }
338 }
339
340 #(#struct_impl_get)*
341
342 #(#struct_impl_set)*
343
344 #(#struct_impl_reset)*
345
346 pub fn set(&mut self, key_name: &str, value: String, reporter: &mut impl ConfigReporter) -> SessionConfigResult<String> {
348 let key_name = Self::alias_to_entry_name(key_name);
349 match key_name.as_ref() {
350 #(#set_match_branches)*
351 _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())),
352 }
353 }
354
355 pub fn get(&self, key_name: &str) -> SessionConfigResult<String> {
357 let key_name = Self::alias_to_entry_name(key_name);
358 match key_name.as_ref() {
359 #(#get_match_branches)*
360 _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())),
361 }
362 }
363
364 pub fn reset(&mut self, key_name: &str, reporter: &mut impl ConfigReporter) -> SessionConfigResult<String> {
366 let key_name = Self::alias_to_entry_name(key_name);
367 match key_name.as_ref() {
368 #(#reset_match_branches)*
369 _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())),
370 }
371 }
372
373 pub fn show_all(&self) -> Vec<VariableInfo> {
375 vec![
376 #(#show_all_list)*
377 ]
378 }
379
380 pub fn list_all(&self) -> Vec<VariableInfo> {
382 vec![
383 #(#list_all_list)*
384 ]
385 }
386
387 pub fn contains_param(key_name: &str) -> bool {
389 let key_name = Self::alias_to_entry_name(key_name);
390 PARAM_NAME_FLAGS.contains_key(key_name.as_str())
391 }
392
393 pub fn check_no_alter_sys(key_name: &str) -> SessionConfigResult<bool> {
395 let key_name = Self::alias_to_entry_name(key_name);
396 let flags = PARAM_NAME_FLAGS.get(key_name.as_str()).ok_or_else(|| SessionConfigError::UnrecognizedEntry(key_name.to_string()))?;
397 Ok(flags.no_alter_sys)
398 }
399 }
400 }
401}