prost_helpers/
lib.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![cfg_attr(coverage, feature(coverage_attribute))]
16#![feature(iterator_try_collect)]
17
18use proc_macro::TokenStream;
19use proc_macro2::{Span, TokenStream as TokenStream2};
20use quote::{format_ident, quote};
21use syn::{Data, DataStruct, DeriveInput, Result, parse_macro_input};
22
23mod generate;
24
25/// This attribute will be placed before any pb types, including messages and enums.
26/// See `prost/helpers/README.md` for more details.
27#[cfg_attr(coverage, coverage(off))]
28#[proc_macro_derive(AnyPB)]
29pub fn any_pb(input: TokenStream) -> TokenStream {
30    // Parse the string representation
31    let ast = parse_macro_input!(input as DeriveInput);
32
33    match produce(&ast) {
34        Ok(tokens) => tokens.into(),
35        Err(e) => e.to_compile_error().into(),
36    }
37}
38
39// Procedure macros can not be tested from the same crate.
40#[cfg_attr(coverage, coverage(off))]
41fn produce(ast: &DeriveInput) -> Result<TokenStream2> {
42    let name = &ast.ident;
43
44    // Is it a struct?
45    let struct_get = if let syn::Data::Struct(DataStruct { ref fields, .. }) = ast.data {
46        let generated: Vec<_> = fields.iter().map(generate::implement).try_collect()?;
47        quote! {
48            impl #name {
49                #(#generated)*
50            }
51        }
52    } else {
53        // Do nothing.
54        quote! {}
55    };
56
57    // Add a `Pb`-prefixed alias for all types.
58    let pb_alias = {
59        let pb_name = format_ident!("Pb{name}");
60        let doc = format!("Alias for [`{name}`].");
61        quote! {
62            #[doc = #doc]
63            pub type #pb_name = #name;
64        }
65    };
66
67    Ok(quote! {
68        #pb_alias
69        #struct_get
70    })
71}
72
73#[cfg_attr(coverage, coverage(off))]
74#[proc_macro_derive(Version)]
75pub fn version(input: TokenStream) -> TokenStream {
76    fn version_inner(ast: &DeriveInput) -> syn::Result<TokenStream2> {
77        let last_variant = match &ast.data {
78            Data::Enum(v) => v.variants.iter().next_back().ok_or_else(|| {
79                syn::Error::new(
80                    Span::call_site(),
81                    "This macro requires at least one variant in the enum.",
82                )
83            })?,
84            _ => {
85                return Err(syn::Error::new(
86                    Span::call_site(),
87                    "This macro only supports enums.",
88                ));
89            }
90        };
91
92        let enum_name = &ast.ident;
93        let last_variant_name = &last_variant.ident;
94
95        Ok(quote! {
96            impl #enum_name {
97                pub const LATEST: Self = Self::#last_variant_name;
98            }
99        })
100    }
101
102    let ast = parse_macro_input!(input as DeriveInput);
103
104    match version_inner(&ast) {
105        Ok(tokens) => tokens.into(),
106        Err(e) => e.to_compile_error().into(),
107    }
108}