risingwave_expr_macro/
parse.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Parse the tokens of the macro.
16
17use quote::ToTokens;
18use syn::parse::{Parse, ParseStream};
19use syn::spanned::Spanned;
20use syn::{LitStr, Token};
21
22use super::*;
23
24impl Parse for FunctionAttr {
25    /// Parse the attribute of the function macro.
26    fn parse(input: ParseStream<'_>) -> Result<Self> {
27        let mut parsed = Self::default();
28
29        let sig = input.parse::<LitStr>()?;
30        let sig_str = sig.value();
31        let (name_args, ret) = match sig_str.split_once("->") {
32            Some((name_args, ret)) => (name_args, ret),
33            None => (sig_str.as_str(), "void"),
34        };
35        let (name, args) = name_args
36            .split_once('(')
37            .ok_or_else(|| Error::new_spanned(&sig, "expected '('"))?;
38        let args = args.trim_start().trim_end_matches([')', ' ']);
39        let (is_table_function, ret) = match ret.trim_start().strip_prefix("setof") {
40            Some(s) => (true, s),
41            None => (false, ret),
42        };
43        parsed.name = name.trim().to_owned();
44        parsed.args = if args.is_empty() {
45            vec![]
46        } else {
47            args.split(',').map(|s| s.trim().to_owned()).collect()
48        };
49        parsed.ret = ret.trim().to_owned();
50        parsed.is_table_function = is_table_function;
51
52        if input.parse::<Token![,]>().is_err() {
53            return Ok(parsed);
54        }
55
56        let metas = input.parse_terminated(syn::Meta::parse, Token![,])?;
57        for meta in metas {
58            let get_value = || {
59                let kv = meta.require_name_value()?;
60                let syn::Expr::Lit(lit) = &kv.value else {
61                    return Err(Error::new(kv.value.span(), "expected literal"));
62                };
63                let syn::Lit::Str(lit) = &lit.lit else {
64                    return Err(Error::new(kv.value.span(), "expected string literal"));
65                };
66                Ok(lit.value())
67            };
68            if meta.path().is_ident("batch_fn") {
69                parsed.batch_fn = Some(get_value()?);
70            } else if meta.path().is_ident("state") {
71                parsed.state = Some(get_value()?);
72            } else if meta.path().is_ident("init_state") {
73                parsed.init_state = Some(get_value()?);
74            } else if meta.path().is_ident("prebuild") {
75                parsed.prebuild = Some(get_value()?);
76            } else if meta.path().is_ident("type_infer") {
77                parsed.type_infer = Some(get_value()?);
78            } else if meta.path().is_ident("generic") {
79                parsed.generic = Some(get_value()?);
80            } else if meta.path().is_ident("volatile") {
81                parsed.volatile = true;
82            } else if meta.path().is_ident("deprecated") || meta.path().is_ident("internal") {
83                parsed.deprecated = true;
84            } else if meta.path().is_ident("rewritten") {
85                parsed.rewritten = true;
86            } else if meta.path().is_ident("append_only") {
87                parsed.append_only = true;
88            } else {
89                return Err(Error::new(
90                    meta.span(),
91                    format!("invalid property: {:?}", meta.path()),
92                ));
93            }
94        }
95        Ok(parsed)
96    }
97}
98
99impl Parse for UserFunctionAttr {
100    fn parse(input: ParseStream<'_>) -> Result<Self> {
101        let itemfn: syn::ItemFn = input.parse()?;
102        Ok(UserFunctionAttr::from(&itemfn.sig))
103    }
104}
105
106impl From<&syn::Signature> for UserFunctionAttr {
107    fn from(sig: &syn::Signature) -> Self {
108        let (return_type_kind, iterator_item_kind, core_return_type) = match &sig.output {
109            syn::ReturnType::Default => (ReturnTypeKind::T, None, "()".into()),
110            syn::ReturnType::Type(_, ty) => {
111                let (kind, inner) = check_type(ty);
112                match strip_iterator(inner) {
113                    Some(ty) => {
114                        let (inner_kind, inner) = check_type(ty);
115                        (kind, Some(inner_kind), inner.to_token_stream().to_string())
116                    }
117                    None => (kind, None, inner.to_token_stream().to_string()),
118                }
119            }
120        };
121        UserFunctionAttr {
122            name: sig.ident.to_string(),
123            async_: sig.asyncness.is_some(),
124            write: sig.inputs.iter().any(arg_is_write),
125            context: sig.inputs.iter().any(arg_is_context),
126            retract: last_arg_is_retract(sig),
127            args_option: sig.inputs.iter().map(arg_is_option).collect(),
128            first_mut_ref_arg: first_mut_ref_arg(sig),
129            return_type_kind,
130            iterator_item_kind,
131            core_return_type,
132            generic: sig.generics.params.len(),
133            return_type_span: sig.output.span(),
134        }
135    }
136}
137
138impl Parse for AggregateImpl {
139    fn parse(input: ParseStream<'_>) -> Result<Self> {
140        let itemimpl: syn::ItemImpl = input.parse()?;
141        let parse_function = |name: &str| {
142            itemimpl.items.iter().find_map(|item| match item {
143                syn::ImplItem::Fn(syn::ImplItemFn { sig, .. }) if sig.ident == name => {
144                    Some(UserFunctionAttr::from(sig))
145                }
146                _ => None,
147            })
148        };
149        let self_path = itemimpl.self_ty.to_token_stream().to_string();
150        let struct_name = match self_path.split_once('<') {
151            Some((path, _)) => path.trim().into(), // remove generic parameters
152            None => self_path,
153        };
154        Ok(AggregateImpl {
155            struct_name,
156            accumulate: parse_function("accumulate").expect("expect accumulate function"),
157            retract: parse_function("retract"),
158            merge: parse_function("merge"),
159            finalize: parse_function("finalize"),
160            create_state: parse_function("create_state"),
161            encode_state: parse_function("encode_state"),
162            decode_state: parse_function("decode_state"),
163        })
164    }
165}
166
167impl Parse for AggregateFnOrImpl {
168    fn parse(input: ParseStream<'_>) -> Result<Self> {
169        // consume attributes
170        let _ = input.call(syn::Attribute::parse_outer)?;
171        if input.peek(Token![impl]) {
172            Ok(AggregateFnOrImpl::Impl(input.parse()?))
173        } else {
174            Ok(AggregateFnOrImpl::Fn(input.parse()?))
175        }
176    }
177}
178
179/// Check if the argument is `&mut impl Write`.
180fn arg_is_write(arg: &syn::FnArg) -> bool {
181    let syn::FnArg::Typed(arg) = arg else {
182        return false;
183    };
184    let syn::Type::Reference(syn::TypeReference { elem, .. }) = arg.ty.as_ref() else {
185        return false;
186    };
187    let syn::Type::ImplTrait(syn::TypeImplTrait { bounds, .. }) = elem.as_ref() else {
188        return false;
189    };
190    let Some(syn::TypeParamBound::Trait(syn::TraitBound { path, .. })) = bounds.first() else {
191        return false;
192    };
193    let Some(seg) = path.segments.last() else {
194        return false;
195    };
196    seg.ident == "Write"
197}
198
199/// Check if the argument is `&Context`.
200fn arg_is_context(arg: &syn::FnArg) -> bool {
201    let syn::FnArg::Typed(arg) = arg else {
202        return false;
203    };
204    let syn::Type::Reference(syn::TypeReference { elem, .. }) = arg.ty.as_ref() else {
205        return false;
206    };
207    let syn::Type::Path(path) = elem.as_ref() else {
208        return false;
209    };
210    let Some(seg) = path.path.segments.last() else {
211        return false;
212    };
213    seg.ident == "Context"
214}
215
216/// Check if the last argument is `retract: bool`.
217fn last_arg_is_retract(sig: &syn::Signature) -> bool {
218    let Some(syn::FnArg::Typed(arg)) = sig.inputs.last() else {
219        return false;
220    };
221    let syn::Pat::Ident(pat) = &*arg.pat else {
222        return false;
223    };
224    pat.ident.to_string().contains("retract")
225}
226
227/// Check if the argument is `Option`.
228fn arg_is_option(arg: &syn::FnArg) -> bool {
229    let syn::FnArg::Typed(arg) = arg else {
230        return false;
231    };
232    let syn::Type::Path(path) = arg.ty.as_ref() else {
233        return false;
234    };
235    let Some(seg) = path.path.segments.last() else {
236        return false;
237    };
238    seg.ident == "Option"
239}
240
241/// Returns `T` if the first argument (except `self`) is `&mut T`.
242fn first_mut_ref_arg(sig: &syn::Signature) -> Option<String> {
243    let arg = match sig.inputs.first()? {
244        syn::FnArg::Typed(arg) => arg,
245        syn::FnArg::Receiver(_) => match sig.inputs.iter().nth(1)? {
246            syn::FnArg::Typed(arg) => arg,
247            _ => return None,
248        },
249    };
250    let syn::Type::Reference(syn::TypeReference {
251        elem,
252        mutability: Some(_),
253        ..
254    }) = arg.ty.as_ref()
255    else {
256        return None;
257    };
258    Some(elem.to_token_stream().to_string())
259}
260
261/// Check the return type.
262fn check_type(ty: &syn::Type) -> (ReturnTypeKind, &syn::Type) {
263    if let Some(inner) = strip_outer_type(ty, "Result") {
264        if let Some(inner) = strip_outer_type(inner, "Option") {
265            (ReturnTypeKind::ResultOption, inner)
266        } else {
267            (ReturnTypeKind::Result, inner)
268        }
269    } else if let Some(inner) = strip_outer_type(ty, "Option") {
270        (ReturnTypeKind::Option, inner)
271    } else if let Some(inner) = strip_outer_type(ty, "DatumRef") {
272        (ReturnTypeKind::Option, inner)
273    } else {
274        (ReturnTypeKind::T, ty)
275    }
276}
277
278/// Check if the type is `type_<T>` and return `T`.
279fn strip_outer_type<'a>(ty: &'a syn::Type, type_: &str) -> Option<&'a syn::Type> {
280    let syn::Type::Path(path) = ty else {
281        return None;
282    };
283    let seg = path.path.segments.last()?;
284    if seg.ident != type_ {
285        return None;
286    }
287    let syn::PathArguments::AngleBracketed(args) = &seg.arguments else {
288        return None;
289    };
290    let Some(syn::GenericArgument::Type(ty)) = args.args.first() else {
291        return None;
292    };
293    Some(ty)
294}
295
296/// Check if the type is `impl Iterator<Item = T>` and return `T`.
297fn strip_iterator(ty: &syn::Type) -> Option<&syn::Type> {
298    let syn::Type::ImplTrait(impl_trait) = ty else {
299        return None;
300    };
301    let syn::TypeParamBound::Trait(trait_bound) = impl_trait.bounds.first()? else {
302        return None;
303    };
304    let segment = trait_bound.path.segments.last().unwrap();
305    if segment.ident != "Iterator" {
306        return None;
307    }
308    let syn::PathArguments::AngleBracketed(angle_bracketed) = &segment.arguments else {
309        return None;
310    };
311    for arg in &angle_bracketed.args {
312        if let syn::GenericArgument::AssocType(b) = arg
313            && b.ident == "Item"
314        {
315            return Some(&b.ty);
316        }
317    }
318    None
319}