risingwave_fields_derive/
lib.rs1use proc_macro2::TokenStream;
16use quote::quote;
17use syn::{Data, DeriveInput, Result};
18
19#[proc_macro_derive(Fields, attributes(primary_key, fields))]
20pub fn fields(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
21 inner(tokens.into()).into()
22}
23
24fn inner(tokens: TokenStream) -> TokenStream {
25 match r#gen(tokens) {
26 Ok(tokens) => tokens,
27 Err(err) => err.to_compile_error(),
28 }
29}
30
31fn r#gen(tokens: TokenStream) -> Result<TokenStream> {
32 let input: DeriveInput = syn::parse2(tokens)?;
33
34 let ident = &input.ident;
35 if !input.generics.params.is_empty() {
36 return Err(syn::Error::new_spanned(
37 input.generics,
38 "generics are not supported",
39 ));
40 }
41
42 let Data::Struct(struct_) = &input.data else {
43 return Err(syn::Error::new_spanned(
44 input.ident,
45 "only structs are supported",
46 ));
47 };
48
49 let style = get_style(&input);
50 if let Some(style) = &style {
51 if !["Title Case", "TITLE CASE", "snake_case"].contains(&style.value().as_str()) {
52 return Err(syn::Error::new_spanned(
53 style,
54 "only `Title Case`, `TITLE CASE`, and `snake_case` are supported",
55 ));
56 }
57 }
58
59 let fields_rw: Vec<TokenStream> = struct_
60 .fields
61 .iter()
62 .map(|field| {
63 let mut name = field.ident.as_ref().expect("field no name").to_string();
64 if name.starts_with("r#") {
66 name = name[2..].to_string();
67 }
68 match style.as_ref().map_or(String::new(), |f| f.value()).as_str() {
70 "Title Case" => name = to_title_case(&name),
71 "TITLE CASE" => name = to_title_case(&name).to_uppercase(),
72 _ => {}
73 }
74 let ty = &field.ty;
75 quote! {
76 (#name, <#ty as ::risingwave_common::types::WithDataType>::default_data_type())
77 }
78 })
79 .collect();
80 let names = struct_
81 .fields
82 .iter()
83 .map(|field| field.ident.as_ref().expect("field no name"))
84 .collect::<Vec<_>>();
85 let primary_key = get_primary_key(&input).map_or_else(
86 || quote! { None },
87 |indices| {
88 quote! { Some(&[#(#indices),*]) }
89 },
90 );
91
92 Ok(quote! {
93 impl ::risingwave_common::types::Fields for #ident {
94 const PRIMARY_KEY: Option<&'static [usize]> = #primary_key;
95
96 fn fields() -> Vec<(&'static str, ::risingwave_common::types::DataType)> {
97 vec![#(#fields_rw),*]
98 }
99 fn into_owned_row(self) -> ::risingwave_common::row::OwnedRow {
100 ::risingwave_common::row::OwnedRow::new(vec![#(
101 ::risingwave_common::types::ToOwnedDatum::to_owned_datum(self.#names)
102 ),*])
103 }
104 }
105 impl From<#ident> for ::risingwave_common::types::ScalarImpl {
106 fn from(v: #ident) -> Self {
107 ::risingwave_common::types::StructValue::new(vec![#(
108 ::risingwave_common::types::ToOwnedDatum::to_owned_datum(v.#names)
109 ),*]).into()
110 }
111 }
112 })
113}
114
115fn get_primary_key(input: &syn::DeriveInput) -> Option<Vec<usize>> {
117 let syn::Data::Struct(struct_) = &input.data else {
118 return None;
119 };
120 let composite = input.attrs.iter().find_map(|attr| match &attr.meta {
122 syn::Meta::List(list) if list.path.is_ident("primary_key") => Some(&list.tokens),
123 _ => None,
124 });
125 if let Some(keys) = composite {
126 let index = |name: &str| {
127 struct_
128 .fields
129 .iter()
130 .position(|f| f.ident.as_ref().is_some_and(|i| i == name))
131 .expect("primary key not found")
132 };
133 return Some(
134 keys.to_string()
135 .split(',')
136 .map(|s| s.trim())
137 .filter(|s| !s.is_empty())
138 .map(index)
139 .collect(),
140 );
141 }
142 for (i, field) in struct_.fields.iter().enumerate() {
144 for attr in &field.attrs {
145 if matches!(&attr.meta, syn::Meta::Path(path) if path.is_ident("primary_key")) {
146 return Some(vec![i]);
147 }
148 }
149 }
150 None
151}
152
153fn get_style(input: &syn::DeriveInput) -> Option<syn::LitStr> {
155 let style = input.attrs.iter().find_map(|attr| match &attr.meta {
156 syn::Meta::List(list) if list.path.is_ident("fields") => {
157 let name_value: syn::MetaNameValue = syn::parse2(list.tokens.clone()).ok()?;
158 if name_value.path.is_ident("style") {
159 Some(name_value.value)
160 } else {
161 None
162 }
163 }
164 _ => None,
165 })?;
166 match style {
167 syn::Expr::Lit(lit) => match lit.lit {
168 syn::Lit::Str(s) => Some(s),
169 _ => None,
170 },
171 _ => None,
172 }
173}
174
/// Converts a `snake_case` identifier into space-separated Title Case.
///
/// Each `_` becomes a single space. The first character of the string, and every
/// character following an `_`, is upper-cased; all other characters are kept as-is.
/// For example, `"foo_bar"` becomes `"Foo Bar"`.
///
/// A character whose upper-case form spans several chars (e.g. `'ß'` -> `"SS"`) is
/// expanded in full rather than truncated to its first char.
fn to_title_case(s: &str) -> String {
    let mut title = String::with_capacity(s.len());
    let mut next_upper = true;
    for c in s.chars() {
        if c == '_' {
            title.push(' ');
            next_upper = true;
        } else if next_upper {
            // `to_uppercase` may yield more than one char; keep all of them.
            title.extend(c.to_uppercase());
            next_upper = false;
        } else {
            title.push(c);
        }
    }
    title
}
192
#[cfg(test)]
mod tests {
    use indoc::indoc;
    use proc_macro2::TokenStream;
    use syn::File;

    /// Parses generated tokens as a Rust file and pretty-prints them for snapshotting.
    fn pretty_print(output: TokenStream) -> String {
        let output: File = syn::parse2(output).unwrap();
        prettyplease::unparse(&output)
    }

    /// Expands `code` and compares the pretty-printed output with the
    /// `expect_test` snapshot file at `expected_path`.
    fn do_test(code: &str, expected_path: &str) {
        let input: TokenStream = str::parse(code).unwrap();

        let output = super::r#gen(input).unwrap();

        let output = pretty_print(output);

        let expected = expect_test::expect_file![expected_path];

        expected.assert_eq(&output);
    }

    /// Expands `code`, expecting the macro to reject it, and returns the error message.
    fn gen_err(code: &str) -> String {
        let input: TokenStream = str::parse(code).unwrap();
        super::r#gen(input).unwrap_err().to_string()
    }

    #[test]
    fn test_gen() {
        let code = indoc! {r#"
            #[derive(Fields)]
            #[primary_key(v2, v1)]
            struct Data {
                v1: i16,
                v2: std::primitive::i32,
                v3: bool,
                v4: Serial,
                r#type: i32,
            }
        "#};

        do_test(code, "gen/test_output.rs");
    }

    #[test]
    fn test_no_pk() {
        let code = indoc! {r#"
            #[derive(Fields)]
            struct Data {
                v1: i16,
                v2: String,
            }
        "#};

        do_test(code, "gen/test_no_pk.rs");
    }

    #[test]
    fn test_empty_pk() {
        let code = indoc! {r#"
            #[derive(Fields)]
            #[primary_key()]
            struct Data {
                v1: i16,
                v2: String,
            }
        "#};

        do_test(code, "gen/test_empty_pk.rs");
    }

    #[test]
    fn test_reject_generics() {
        let code = indoc! {r#"
            #[derive(Fields)]
            struct Data<T> {
                v1: T,
            }
        "#};

        assert_eq!(gen_err(code), "generics are not supported");
    }

    #[test]
    fn test_reject_enum() {
        let code = indoc! {r#"
            #[derive(Fields)]
            enum Data {
                A,
            }
        "#};

        assert_eq!(gen_err(code), "only structs are supported");
    }

    #[test]
    fn test_reject_unknown_style() {
        let code = indoc! {r#"
            #[derive(Fields)]
            #[fields(style = "camelCase")]
            struct Data {
                v1: i16,
            }
        "#};

        assert_eq!(
            gen_err(code),
            "only `Title Case`, `TITLE CASE`, and `snake_case` are supported"
        );
    }

    #[test]
    fn test_title_case() {
        assert_eq!(super::to_title_case("v1"), "V1");
        assert_eq!(super::to_title_case("foo_bar_baz"), "Foo Bar Baz");
    }
}