risingwave_fields_derive/
lib.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use proc_macro2::TokenStream;
16use quote::quote;
17use syn::{Data, DeriveInput, Result};
18
19#[proc_macro_derive(Fields, attributes(primary_key, fields))]
20pub fn fields(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
21    inner(tokens.into()).into()
22}
23
24fn inner(tokens: TokenStream) -> TokenStream {
25    match r#gen(tokens) {
26        Ok(tokens) => tokens,
27        Err(err) => err.to_compile_error(),
28    }
29}
30
31fn r#gen(tokens: TokenStream) -> Result<TokenStream> {
32    let input: DeriveInput = syn::parse2(tokens)?;
33
34    let ident = &input.ident;
35    if !input.generics.params.is_empty() {
36        return Err(syn::Error::new_spanned(
37            input.generics,
38            "generics are not supported",
39        ));
40    }
41
42    let Data::Struct(struct_) = &input.data else {
43        return Err(syn::Error::new_spanned(
44            input.ident,
45            "only structs are supported",
46        ));
47    };
48
49    let style = get_style(&input);
50    if let Some(style) = &style {
51        if !["Title Case", "TITLE CASE", "snake_case"].contains(&style.value().as_str()) {
52            return Err(syn::Error::new_spanned(
53                style,
54                "only `Title Case`, `TITLE CASE`, and `snake_case` are supported",
55            ));
56        }
57    }
58
59    let fields_rw: Vec<TokenStream> = struct_
60        .fields
61        .iter()
62        .map(|field| {
63            let mut name = field.ident.as_ref().expect("field no name").to_string();
64            // strip leading `r#`
65            if name.starts_with("r#") {
66                name = name[2..].to_string();
67            }
68            // cast style
69            match style.as_ref().map_or(String::new(), |f| f.value()).as_str() {
70                "Title Case" => name = to_title_case(&name),
71                "TITLE CASE" => name = to_title_case(&name).to_uppercase(),
72                _ => {}
73            }
74            let ty = &field.ty;
75            quote! {
76                (#name, <#ty as ::risingwave_common::types::WithDataType>::default_data_type())
77            }
78        })
79        .collect();
80    let names = struct_
81        .fields
82        .iter()
83        .map(|field| field.ident.as_ref().expect("field no name"))
84        .collect::<Vec<_>>();
85    let primary_key = get_primary_key(&input).map_or_else(
86        || quote! { None },
87        |indices| {
88            quote! { Some(&[#(#indices),*]) }
89        },
90    );
91
92    Ok(quote! {
93        impl ::risingwave_common::types::Fields for #ident {
94            const PRIMARY_KEY: Option<&'static [usize]> = #primary_key;
95
96            fn fields() -> Vec<(&'static str, ::risingwave_common::types::DataType)> {
97                vec![#(#fields_rw),*]
98            }
99            fn into_owned_row(self) -> ::risingwave_common::row::OwnedRow {
100                ::risingwave_common::row::OwnedRow::new(vec![#(
101                    ::risingwave_common::types::ToOwnedDatum::to_owned_datum(self.#names)
102                ),*])
103            }
104        }
105        impl From<#ident> for ::risingwave_common::types::ScalarImpl {
106            fn from(v: #ident) -> Self {
107                ::risingwave_common::types::StructValue::new(vec![#(
108                    ::risingwave_common::types::ToOwnedDatum::to_owned_datum(v.#names)
109                ),*]).into()
110            }
111        }
112    })
113}
114
115/// Get primary key indices from `#[primary_key]` attribute.
116fn get_primary_key(input: &syn::DeriveInput) -> Option<Vec<usize>> {
117    let syn::Data::Struct(struct_) = &input.data else {
118        return None;
119    };
120    // find `#[primary_key(k1, k2, ...)]` on struct
121    let composite = input.attrs.iter().find_map(|attr| match &attr.meta {
122        syn::Meta::List(list) if list.path.is_ident("primary_key") => Some(&list.tokens),
123        _ => None,
124    });
125    if let Some(keys) = composite {
126        let index = |name: &str| {
127            struct_
128                .fields
129                .iter()
130                .position(|f| f.ident.as_ref().is_some_and(|i| i == name))
131                .expect("primary key not found")
132        };
133        return Some(
134            keys.to_string()
135                .split(',')
136                .map(|s| s.trim())
137                .filter(|s| !s.is_empty())
138                .map(index)
139                .collect(),
140        );
141    }
142    // find `#[primary_key]` on fields
143    for (i, field) in struct_.fields.iter().enumerate() {
144        for attr in &field.attrs {
145            if matches!(&attr.meta, syn::Meta::Path(path) if path.is_ident("primary_key")) {
146                return Some(vec![i]);
147            }
148        }
149    }
150    None
151}
152
153/// Get name style from `#[fields(style = "xxx")]` attribute.
154fn get_style(input: &syn::DeriveInput) -> Option<syn::LitStr> {
155    let style = input.attrs.iter().find_map(|attr| match &attr.meta {
156        syn::Meta::List(list) if list.path.is_ident("fields") => {
157            let name_value: syn::MetaNameValue = syn::parse2(list.tokens.clone()).ok()?;
158            if name_value.path.is_ident("style") {
159                Some(name_value.value)
160            } else {
161                None
162            }
163        }
164        _ => None,
165    })?;
166    match style {
167        syn::Expr::Lit(lit) => match lit.lit {
168            syn::Lit::Str(s) => Some(s),
169            _ => None,
170        },
171        _ => None,
172    }
173}
174
/// Convert `snake_case` to `Title Case`.
///
/// Each `_` becomes a single space. The first character of the string, and the first
/// character after each `_`, are upper-cased; all other characters are kept as-is.
/// Upper-casing uses the full Unicode mapping, so a character whose uppercase form is
/// several characters long (e.g. `ß` -> `SS`) is expanded rather than truncated.
fn to_title_case(s: &str) -> String {
    let mut title = String::with_capacity(s.len());
    let mut next_upper = true;
    for c in s.chars() {
        if c == '_' {
            title.push(' ');
            next_upper = true;
        } else if next_upper {
            // `to_uppercase` may yield more than one char; keep all of them.
            title.extend(c.to_uppercase());
            next_upper = false;
        } else {
            title.push(c);
        }
    }
    title
}
192
#[cfg(test)]
mod tests {
    use indoc::indoc;
    use proc_macro2::TokenStream;
    use syn::File;

    /// Render generated tokens as formatted Rust source so snapshots are readable.
    fn pretty_print(output: TokenStream) -> String {
        let file = syn::parse2::<File>(output).unwrap();
        prettyplease::unparse(&file)
    }

    /// Expand `code` with the derive and compare the pretty-printed expansion against the
    /// snapshot file at `expected_path` (resolved relative to this source file).
    fn do_test(code: &str, expected_path: &str) {
        let expanded = super::r#gen(code.parse::<TokenStream>().unwrap()).unwrap();
        expect_test::expect_file![expected_path].assert_eq(&pretty_print(expanded));
    }

    #[test]
    fn test_gen() {
        let code = indoc! {r#"
            #[derive(Fields)]
            #[primary_key(v2, v1)]
            struct Data {
                v1: i16,
                v2: std::primitive::i32,
                v3: bool,
                v4: Serial,
                r#type: i32,
            }
        "#};

        do_test(code, "gen/test_output.rs");
    }

    #[test]
    fn test_no_pk() {
        let code = indoc! {r#"
            #[derive(Fields)]
            struct Data {
                v1: i16,
                v2: String,
            }
        "#};

        do_test(code, "gen/test_no_pk.rs");
    }

    #[test]
    fn test_empty_pk() {
        let code = indoc! {r#"
            #[derive(Fields)]
            #[primary_key()]
            struct Data {
                v1: i16,
                v2: String,
            }
        "#};

        do_test(code, "gen/test_empty_pk.rs");
    }
}