prost_helpers/
lib.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![feature(coverage_attribute)]
16#![feature(iterator_try_collect)]
17
18use proc_macro::TokenStream;
19use proc_macro2::{Span, TokenStream as TokenStream2};
20use quote::{format_ident, quote};
21use syn::{Data, DataStruct, DeriveInput, Result, parse_macro_input};
22
23mod generate;
24
25/// This attribute will be placed before any pb types, including messages and enums.
26/// See `prost/helpers/README.md` for more details.
27#[proc_macro_derive(AnyPB)]
28pub fn any_pb(input: TokenStream) -> TokenStream {
29    // Parse the string representation
30    let ast = parse_macro_input!(input as DeriveInput);
31
32    match produce(&ast) {
33        Ok(tokens) => tokens.into(),
34        Err(e) => e.to_compile_error().into(),
35    }
36}
37
38// Procedure macros can not be tested from the same crate.
39fn produce(ast: &DeriveInput) -> Result<TokenStream2> {
40    let name = &ast.ident;
41
42    // Is it a struct?
43    let struct_get = if let syn::Data::Struct(DataStruct { ref fields, .. }) = ast.data {
44        let generated: Vec<_> = fields.iter().map(generate::implement).try_collect()?;
45        quote! {
46            impl #name {
47                #(#generated)*
48            }
49        }
50    } else {
51        // Do nothing.
52        quote! {}
53    };
54
55    // Add a `Pb`-prefixed alias for all types.
56    // No need to add docs for this alias as rust-analyzer will forward the docs to the original type.
57    let pb_alias = {
58        let pb_name = format_ident!("Pb{name}");
59        quote! {
60            pub type #pb_name = #name;
61        }
62    };
63
64    Ok(quote! {
65        #pb_alias
66        #struct_get
67    })
68}
69
70#[proc_macro_derive(Version)]
71pub fn version(input: TokenStream) -> TokenStream {
72    fn version_inner(ast: &DeriveInput) -> syn::Result<TokenStream2> {
73        let last_variant = match &ast.data {
74            Data::Enum(v) => v.variants.iter().next_back().ok_or_else(|| {
75                syn::Error::new(
76                    Span::call_site(),
77                    "This macro requires at least one variant in the enum.",
78                )
79            })?,
80            _ => {
81                return Err(syn::Error::new(
82                    Span::call_site(),
83                    "This macro only supports enums.",
84                ));
85            }
86        };
87
88        let enum_name = &ast.ident;
89        let last_variant_name = &last_variant.ident;
90
91        Ok(quote! {
92            impl #enum_name {
93                pub const LATEST: Self = Self::#last_variant_name;
94            }
95        })
96    }
97
98    let ast = parse_macro_input!(input as DeriveInput);
99
100    match version_inner(&ast) {
101        Ok(tokens) => tokens.into(),
102        Err(e) => e.to_compile_error().into(),
103    }
104}