risingwave_expr_macro/
parse.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Parse the tokens of the macro.
16
17use proc_macro_error::abort;
18use quote::ToTokens;
19use syn::parse::{Parse, ParseStream};
20use syn::spanned::Spanned;
21use syn::{LitStr, Token, parse_quote};
22
23use super::*;
24
25impl Parse for FunctionAttr {
26    /// Parse the attribute of the function macro.
27    fn parse(input: ParseStream<'_>) -> Result<Self> {
28        let mut parsed = Self::default();
29
30        let sig = input.parse::<LitStr>()?;
31        let sig_str = sig.value();
32        let (name_args, ret) = match sig_str.split_once("->") {
33            Some((name_args, ret)) => (name_args, ret),
34            None => (sig_str.as_str(), "void"),
35        };
36        let (name, args) = name_args
37            .split_once('(')
38            .ok_or_else(|| Error::new_spanned(&sig, "expected '('"))?;
39        let args = args.trim_start().trim_end_matches([')', ' ']);
40        let (is_table_function, ret) = match ret.trim_start().strip_prefix("setof") {
41            Some(s) => (true, s),
42            None => (false, ret),
43        };
44        parsed.name = name.trim().to_owned();
45        parsed.args = if args.is_empty() {
46            vec![]
47        } else {
48            args.split(',').map(|s| s.trim().to_owned()).collect()
49        };
50        parsed.ret = ret.trim().to_owned();
51        parsed.is_table_function = is_table_function;
52
53        if input.parse::<Token![,]>().is_err() {
54            return Ok(parsed);
55        }
56
57        let metas = input.parse_terminated(syn::Meta::parse, Token![,])?;
58        for meta in metas {
59            let get_value = || {
60                let kv = meta.require_name_value()?;
61                let syn::Expr::Lit(lit) = &kv.value else {
62                    return Err(Error::new(kv.value.span(), "expected literal"));
63                };
64                let syn::Lit::Str(lit) = &lit.lit else {
65                    return Err(Error::new(kv.value.span(), "expected string literal"));
66                };
67                Ok(lit.value())
68            };
69            if meta.path().is_ident("batch_fn") {
70                parsed.batch_fn = Some(get_value()?);
71            } else if meta.path().is_ident("state") {
72                parsed.state = Some(get_value()?);
73            } else if meta.path().is_ident("init_state") {
74                parsed.init_state = Some(get_value()?);
75            } else if meta.path().is_ident("prebuild") {
76                parsed.prebuild = Some(get_value()?);
77            } else if meta.path().is_ident("type_infer") {
78                parsed.type_infer = Some(get_value()?);
79            } else if meta.path().is_ident("generic") {
80                parsed.generic = Some(get_value()?);
81            } else if meta.path().is_ident("volatile") {
82                parsed.volatile = true;
83            } else if meta.path().is_ident("deprecated") || meta.path().is_ident("internal") {
84                parsed.deprecated = true;
85            } else if meta.path().is_ident("rewritten") {
86                parsed.rewritten = true;
87            } else if meta.path().is_ident("append_only") {
88                parsed.append_only = true;
89            } else {
90                return Err(Error::new(
91                    meta.span(),
92                    format!("invalid property: {:?}", meta.path()),
93                ));
94            }
95        }
96        Ok(parsed)
97    }
98}
99
100impl Parse for UserFunctionAttr {
101    fn parse(input: ParseStream<'_>) -> Result<Self> {
102        let itemfn: syn::ItemFn = input.parse()?;
103        Ok(UserFunctionAttr::from(&itemfn.sig))
104    }
105}
106
107impl From<&syn::Signature> for UserFunctionAttr {
108    fn from(sig: &syn::Signature) -> Self {
109        let (return_type_kind, iterator_item_kind, core_return_type) = match &sig.output {
110            syn::ReturnType::Default => (ReturnTypeKind::T, None, "()".into()),
111            syn::ReturnType::Type(_, ty) => {
112                let (kind, inner) = check_type(ty);
113                match strip_iterator(inner) {
114                    Some(ty) => {
115                        let (inner_kind, inner) = check_type(ty);
116                        (kind, Some(inner_kind), inner.to_token_stream().to_string())
117                    }
118                    None => (kind, None, inner.to_token_stream().to_string()),
119                }
120            }
121        };
122        UserFunctionAttr {
123            name: sig.ident.to_string(),
124            async_: sig.asyncness.is_some(),
125            writer_type_kind: sig.inputs.last().and_then(arg_writer_type),
126            context: sig.inputs.iter().any(arg_is_context),
127            retract: last_arg_is_retract(sig),
128            args_option: sig.inputs.iter().map(arg_is_option).collect(),
129            first_mut_ref_arg: first_mut_ref_arg(sig),
130            return_type_kind,
131            iterator_item_kind,
132            core_return_type,
133            generic: sig.generics.params.len(),
134            return_type_span: sig.output.span(),
135        }
136    }
137}
138
139impl Parse for AggregateImpl {
140    fn parse(input: ParseStream<'_>) -> Result<Self> {
141        let itemimpl: syn::ItemImpl = input.parse()?;
142        let parse_function = |name: &str| {
143            itemimpl.items.iter().find_map(|item| match item {
144                syn::ImplItem::Fn(syn::ImplItemFn { sig, .. }) if sig.ident == name => {
145                    Some(UserFunctionAttr::from(sig))
146                }
147                _ => None,
148            })
149        };
150        let self_path = itemimpl.self_ty.to_token_stream().to_string();
151        let struct_name = match self_path.split_once('<') {
152            Some((path, _)) => path.trim().into(), // remove generic parameters
153            None => self_path,
154        };
155        Ok(AggregateImpl {
156            struct_name,
157            accumulate: parse_function("accumulate").expect("expect accumulate function"),
158            retract: parse_function("retract"),
159            merge: parse_function("merge"),
160            finalize: parse_function("finalize"),
161            create_state: parse_function("create_state"),
162            encode_state: parse_function("encode_state"),
163            decode_state: parse_function("decode_state"),
164        })
165    }
166}
167
168impl Parse for AggregateFnOrImpl {
169    fn parse(input: ParseStream<'_>) -> Result<Self> {
170        // consume attributes
171        let _ = input.call(syn::Attribute::parse_outer)?;
172        if input.peek(Token![impl]) {
173            Ok(AggregateFnOrImpl::Impl(input.parse()?))
174        } else {
175            Ok(AggregateFnOrImpl::Fn(input.parse()?))
176        }
177    }
178}
179
180/// Check if the argument is a writer and return its type.
181fn arg_writer_type(arg: &syn::FnArg) -> Option<WriterTypeKind> {
182    let syn::FnArg::Typed(arg) = arg else {
183        return None;
184    };
185    let syn::Type::Reference(syn::TypeReference { elem, .. }) = arg.ty.as_ref() else {
186        return None;
187    };
188    let elem = elem.as_ref();
189    if elem == &parse_quote!(impl std::fmt::Write) {
190        Some(WriterTypeKind::FmtWrite)
191    } else if elem == &parse_quote!(impl std::io::Write) {
192        Some(WriterTypeKind::IoWrite)
193    } else if elem == &parse_quote!(impl Write) {
194        abort! { elem, "use of ambiguous `Write` trait.";
195            note = "`function!` macro can only recognize fully-qualified type.";
196            help = "Please use `std::fmt::Write` or `std::io::Write` instead."
197        };
198    } else if elem == &parse_quote!(jsonbb::Builder) {
199        Some(WriterTypeKind::JsonbbBuilder)
200    } else if elem == &parse_quote!(Builder) {
201        abort! { elem, "use of ambiguous `Builder` type.";
202            note = "`function!` macro can only recognize fully-qualified type.";
203            help = "Please use `jsonbb::Builder` instead."
204        };
205    } else if elem == &parse_quote!(impl risingwave_common::array::ListWrite) {
206        Some(WriterTypeKind::ListWrite)
207    } else {
208        None
209    }
210}
211
212/// Check if the argument is `&Context`.
213fn arg_is_context(arg: &syn::FnArg) -> bool {
214    let syn::FnArg::Typed(arg) = arg else {
215        return false;
216    };
217    let syn::Type::Reference(syn::TypeReference { elem, .. }) = arg.ty.as_ref() else {
218        return false;
219    };
220    let syn::Type::Path(path) = elem.as_ref() else {
221        return false;
222    };
223    let Some(seg) = path.path.segments.last() else {
224        return false;
225    };
226    seg.ident == "Context"
227}
228
229/// Check if the last argument is `retract: bool`.
230fn last_arg_is_retract(sig: &syn::Signature) -> bool {
231    let Some(syn::FnArg::Typed(arg)) = sig.inputs.last() else {
232        return false;
233    };
234    let syn::Pat::Ident(pat) = &*arg.pat else {
235        return false;
236    };
237    pat.ident.to_string().contains("retract")
238}
239
240/// Check if the argument is `Option`.
241fn arg_is_option(arg: &syn::FnArg) -> bool {
242    let syn::FnArg::Typed(arg) = arg else {
243        return false;
244    };
245    let syn::Type::Path(path) = arg.ty.as_ref() else {
246        return false;
247    };
248    let Some(seg) = path.path.segments.last() else {
249        return false;
250    };
251    seg.ident == "Option"
252}
253
254/// Returns `T` if the first argument (except `self`) is `&mut T`.
255fn first_mut_ref_arg(sig: &syn::Signature) -> Option<String> {
256    let arg = match sig.inputs.first()? {
257        syn::FnArg::Typed(arg) => arg,
258        syn::FnArg::Receiver(_) => match sig.inputs.iter().nth(1)? {
259            syn::FnArg::Typed(arg) => arg,
260            _ => return None,
261        },
262    };
263    let syn::Type::Reference(syn::TypeReference {
264        elem,
265        mutability: Some(_),
266        ..
267    }) = arg.ty.as_ref()
268    else {
269        return None;
270    };
271    Some(elem.to_token_stream().to_string())
272}
273
274/// Check the return type.
275fn check_type(ty: &syn::Type) -> (ReturnTypeKind, &syn::Type) {
276    if let Some(inner) = strip_outer_type(ty, "Result") {
277        if let Some(inner) = strip_outer_type(inner, "Option") {
278            (ReturnTypeKind::ResultOption, inner)
279        } else {
280            (ReturnTypeKind::Result, inner)
281        }
282    } else if let Some(inner) = strip_outer_type(ty, "Option") {
283        (ReturnTypeKind::Option, inner)
284    } else if let Some(inner) = strip_outer_type(ty, "DatumRef") {
285        (ReturnTypeKind::Option, inner)
286    } else {
287        (ReturnTypeKind::T, ty)
288    }
289}
290
291/// Check if the type is `type_<T>` and return `T`.
292fn strip_outer_type<'a>(ty: &'a syn::Type, type_: &str) -> Option<&'a syn::Type> {
293    let syn::Type::Path(path) = ty else {
294        return None;
295    };
296    let seg = path.path.segments.last()?;
297    if seg.ident != type_ {
298        return None;
299    }
300    let syn::PathArguments::AngleBracketed(args) = &seg.arguments else {
301        return None;
302    };
303    let Some(syn::GenericArgument::Type(ty)) = args.args.first() else {
304        return None;
305    };
306    Some(ty)
307}
308
309/// Check if the type is `impl Iterator<Item = T>` and return `T`.
310fn strip_iterator(ty: &syn::Type) -> Option<&syn::Type> {
311    let syn::Type::ImplTrait(impl_trait) = ty else {
312        return None;
313    };
314    let syn::TypeParamBound::Trait(trait_bound) = impl_trait.bounds.first()? else {
315        return None;
316    };
317    let segment = trait_bound.path.segments.last().unwrap();
318    if segment.ident != "Iterator" {
319        return None;
320    }
321    let syn::PathArguments::AngleBracketed(angle_bracketed) = &segment.arguments else {
322        return None;
323    };
324    for arg in &angle_bracketed.args {
325        if let syn::GenericArgument::AssocType(b) = arg
326            && b.ident == "Item"
327        {
328            return Some(&b.ty);
329        }
330    }
331    None
332}