prost_helpers/
generate.rs1use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
16use quote::quote;
17use syn::ext::IdentExt;
18use syn::spanned::Spanned;
19use syn::{
20 AttrStyle, Attribute, Error, Expr, ExprLit, Field, GenericArgument, Lit, Meta, Path,
21 PathArguments, PathSegment, Result, Type,
22};
23
24fn extract_type_from_option(option_segment: &PathSegment) -> Type {
25 let generic_arg = match &option_segment.arguments {
26 PathArguments::AngleBracketed(params) => params.args.first().unwrap(),
27 _ => panic!("Option has no angle bracket"),
28 };
29 match generic_arg {
30 GenericArgument::Type(inner_ty) => inner_ty.clone(),
31 _ => panic!("Option's argument must be a type"),
32 }
33}
34
35fn extract_enum_type_from_field(field: &Field) -> Option<Type> {
43 use syn::Token;
44 use syn::punctuated::Punctuated;
45
46 match &field.ty {
48 Type::Path(path) => {
49 if !path.path.segments.first()?.ident.eq("i32") {
50 return None;
51 }
52 }
53 _ => return None,
54 };
55
56 let attr = field
58 .attrs
59 .iter()
60 .find(|attr| attr.path().is_ident("prost"))?;
61
62 let args = attr
64 .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
65 .unwrap();
66
67 let enum_type = args
68 .iter()
69 .map(|meta| match meta {
70 Meta::NameValue(kv) if kv.path.is_ident("enumeration") => {
71 match &kv.value {
72 Expr::Lit(ExprLit {
74 lit: Lit::Str(str), ..
75 }) => Some(str.to_owned()),
76 _ => None,
77 }
78 }
79 _ => None,
80 })
81 .next()??;
82
83 Some(syn::parse_str::<Type>(&enum_type.value()).unwrap())
84}
85
86fn is_deprecated(field: &Field) -> bool {
87 field.attrs.iter().any(|attr| match &attr.meta {
88 Meta::Path(path) => path.is_ident("deprecated"),
89 _ => false,
90 })
91}
92
93pub fn implement(field: &Field) -> Result<TokenStream2> {
94 let field_name = field
95 .clone()
96 .ident
97 .ok_or_else(|| Error::new(field.span(), "Expected the field to have a name"))?;
98
99 let getter_fn_name = Ident::new(&format!("get_{}", field_name.unraw()), Span::call_site());
100 let is_deprecated = is_deprecated(field);
101
102 let attr_list: Vec<Attribute> = if is_deprecated {
103 vec![Attribute {
104 pound_token: Default::default(),
105 style: AttrStyle::Outer,
106 bracket_token: Default::default(),
107 meta: Meta::from(Path::from(Ident::new("deprecated", Span::call_site()))),
108 }]
109 } else {
110 vec![]
111 };
112
113 if let Some(enum_type) = extract_enum_type_from_field(field) {
114 return Ok(quote! {
115 #(#attr_list)*
116 #[inline(always)]
117 pub fn #getter_fn_name(&self) -> std::result::Result<#enum_type, crate::PbFieldNotFound> {
118 if self.#field_name.eq(&0) {
119 return Err(crate::PbFieldNotFound(stringify!(#field_name)));
120 }
121 #enum_type::from_i32(self.#field_name).ok_or_else(|| crate::PbFieldNotFound(stringify!(#field_name)))
122 }
123 });
124 };
125
126 let ty = field.ty.clone();
127 if let Type::Path(ref type_path) = ty {
128 let data_type = type_path.path.segments.last().unwrap();
129 if data_type.ident == "Option" {
130 let ty = extract_type_from_option(data_type);
132 return Ok(quote! {
133 #(#attr_list)*
134 #[inline(always)]
135 pub fn #getter_fn_name(&self) -> std::result::Result<&#ty, crate::PbFieldNotFound> {
136 self.#field_name.as_ref().ok_or_else(|| crate::PbFieldNotFound(stringify!(#field_name)))
137 }
138 });
139 } else if ["u32", "u64", "f32", "f64", "i32", "i64", "bool"]
140 .contains(&data_type.ident.to_string().as_str())
141 {
142 return Ok(quote! {
144 #(#attr_list)*
145 #[inline(always)]
146 pub fn #getter_fn_name(&self) -> #ty {
147 self.#field_name
148 }
149 });
150 }
151 }
152
153 Ok(quote! {
154 #(#attr_list)*
155 #[inline(always)]
156 pub fn #getter_fn_name(&self) -> &#ty {
157 &self.#field_name
158 }
159 })
160}