prost_helpers/
generate.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
16use quote::quote;
17use syn::ext::IdentExt;
18use syn::spanned::Spanned;
19use syn::{
20    AttrStyle, Attribute, Error, Expr, ExprLit, Field, GenericArgument, Lit, Meta, Path,
21    PathArguments, PathSegment, Result, Type,
22};
23
24fn extract_type_from_option(option_segment: &PathSegment) -> Type {
25    let generic_arg = match &option_segment.arguments {
26        PathArguments::AngleBracketed(params) => params.args.first().unwrap(),
27        _ => panic!("Option has no angle bracket"),
28    };
29    match generic_arg {
30        GenericArgument::Type(inner_ty) => inner_ty.clone(),
31        _ => panic!("Option's argument must be a type"),
32    }
33}
34
35/// For example:
36/// ```ignore
37/// #[prost(enumeration = "data_type::TypeName", tag = "1")]
38/// pub type_name: i32,
39/// ```
40///
41/// Returns `data_type::TypeName`.
42fn extract_enum_type_from_field(field: &Field) -> Option<Type> {
43    use syn::Token;
44    use syn::punctuated::Punctuated;
45
46    // The type must be i32.
47    match &field.ty {
48        Type::Path(path) => {
49            if !path.path.segments.first()?.ident.eq("i32") {
50                return None;
51            }
52        }
53        _ => return None,
54    };
55
56    // find attribute `#[prost(...)]`
57    let attr = field
58        .attrs
59        .iter()
60        .find(|attr| attr.path().is_ident("prost"))?;
61
62    // `enumeration = "data_type::TypeName", tag = "1"`
63    let args = attr
64        .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
65        .unwrap();
66
67    let enum_type = args
68        .iter()
69        .map(|meta| match meta {
70            Meta::NameValue(kv) if kv.path.is_ident("enumeration") => {
71                match &kv.value {
72                    // data_type::TypeName
73                    Expr::Lit(ExprLit {
74                        lit: Lit::Str(str), ..
75                    }) => Some(str.to_owned()),
76                    _ => None,
77                }
78            }
79            _ => None,
80        })
81        .next()??;
82
83    Some(syn::parse_str::<Type>(&enum_type.value()).unwrap())
84}
85
86fn is_deprecated(field: &Field) -> bool {
87    field.attrs.iter().any(|attr| match &attr.meta {
88        Meta::Path(path) => path.is_ident("deprecated"),
89        _ => false,
90    })
91}
92
93pub fn implement(field: &Field) -> Result<TokenStream2> {
94    let field_name = field
95        .clone()
96        .ident
97        .ok_or_else(|| Error::new(field.span(), "Expected the field to have a name"))?;
98
99    let getter_fn_name = Ident::new(&format!("get_{}", field_name.unraw()), Span::call_site());
100    let is_deprecated = is_deprecated(field);
101
102    let attr_list: Vec<Attribute> = if is_deprecated {
103        vec![Attribute {
104            pound_token: Default::default(),
105            style: AttrStyle::Outer,
106            bracket_token: Default::default(),
107            meta: Meta::from(Path::from(Ident::new("deprecated", Span::call_site()))),
108        }]
109    } else {
110        vec![]
111    };
112
113    if let Some(enum_type) = extract_enum_type_from_field(field) {
114        return Ok(quote! {
115            #(#attr_list)*
116            #[inline(always)]
117            pub fn #getter_fn_name(&self) -> std::result::Result<#enum_type, crate::PbFieldNotFound> {
118                if self.#field_name.eq(&0) {
119                    return Err(crate::PbFieldNotFound(stringify!(#field_name)));
120                }
121                #enum_type::from_i32(self.#field_name).ok_or_else(|| crate::PbFieldNotFound(stringify!(#field_name)))
122            }
123        });
124    };
125
126    let ty = field.ty.clone();
127    if let Type::Path(ref type_path) = ty {
128        let data_type = type_path.path.segments.last().unwrap();
129        if data_type.ident == "Option" {
130            // ::core::option::Option<Foo>
131            let ty = extract_type_from_option(data_type);
132            return Ok(quote! {
133                #(#attr_list)*
134                #[inline(always)]
135                pub fn #getter_fn_name(&self) -> std::result::Result<&#ty, crate::PbFieldNotFound> {
136                    self.#field_name.as_ref().ok_or_else(|| crate::PbFieldNotFound(stringify!(#field_name)))
137                }
138            });
139        } else if ["u32", "u64", "f32", "f64", "i32", "i64", "bool"]
140            .contains(&data_type.ident.to_string().as_str())
141        {
142            // Primitive types. Return value instead of reference.
143            return Ok(quote! {
144                #(#attr_list)*
145                #[inline(always)]
146                pub fn #getter_fn_name(&self) -> #ty {
147                    self.#field_name
148                }
149            });
150        }
151    }
152
153    Ok(quote! {
154        #(#attr_list)*
155        #[inline(always)]
156        pub fn #getter_fn_name(&self) -> &#ty {
157            &self.#field_name
158        }
159    })
160}