risingwave_expr_macro/
parse.rs1use proc_macro_error::abort;
18use quote::ToTokens;
19use syn::parse::{Parse, ParseStream};
20use syn::spanned::Spanned;
21use syn::{LitStr, Token, parse_quote};
22
23use super::*;
24
25impl Parse for FunctionAttr {
26 fn parse(input: ParseStream<'_>) -> Result<Self> {
28 let mut parsed = Self::default();
29
30 let sig = input.parse::<LitStr>()?;
31 let sig_str = sig.value();
32 let (name_args, ret) = match sig_str.split_once("->") {
33 Some((name_args, ret)) => (name_args, ret),
34 None => (sig_str.as_str(), "void"),
35 };
36 let (name, args) = name_args
37 .split_once('(')
38 .ok_or_else(|| Error::new_spanned(&sig, "expected '('"))?;
39 let args = args.trim_start().trim_end_matches([')', ' ']);
40 let (is_table_function, ret) = match ret.trim_start().strip_prefix("setof") {
41 Some(s) => (true, s),
42 None => (false, ret),
43 };
44 parsed.name = name.trim().to_owned();
45 parsed.args = if args.is_empty() {
46 vec![]
47 } else {
48 args.split(',').map(|s| s.trim().to_owned()).collect()
49 };
50 parsed.ret = ret.trim().to_owned();
51 parsed.is_table_function = is_table_function;
52
53 if input.parse::<Token![,]>().is_err() {
54 return Ok(parsed);
55 }
56
57 let metas = input.parse_terminated(syn::Meta::parse, Token![,])?;
58 for meta in metas {
59 let get_value = || {
60 let kv = meta.require_name_value()?;
61 let syn::Expr::Lit(lit) = &kv.value else {
62 return Err(Error::new(kv.value.span(), "expected literal"));
63 };
64 let syn::Lit::Str(lit) = &lit.lit else {
65 return Err(Error::new(kv.value.span(), "expected string literal"));
66 };
67 Ok(lit.value())
68 };
69 if meta.path().is_ident("batch_fn") {
70 parsed.batch_fn = Some(get_value()?);
71 } else if meta.path().is_ident("state") {
72 parsed.state = Some(get_value()?);
73 } else if meta.path().is_ident("init_state") {
74 parsed.init_state = Some(get_value()?);
75 } else if meta.path().is_ident("prebuild") {
76 parsed.prebuild = Some(get_value()?);
77 } else if meta.path().is_ident("type_infer") {
78 parsed.type_infer = Some(get_value()?);
79 } else if meta.path().is_ident("generic") {
80 parsed.generic = Some(get_value()?);
81 } else if meta.path().is_ident("volatile") {
82 parsed.volatile = true;
83 } else if meta.path().is_ident("deprecated") || meta.path().is_ident("internal") {
84 parsed.deprecated = true;
85 } else if meta.path().is_ident("rewritten") {
86 parsed.rewritten = true;
87 } else if meta.path().is_ident("append_only") {
88 parsed.append_only = true;
89 } else {
90 return Err(Error::new(
91 meta.span(),
92 format!("invalid property: {:?}", meta.path()),
93 ));
94 }
95 }
96 Ok(parsed)
97 }
98}
99
100impl Parse for UserFunctionAttr {
101 fn parse(input: ParseStream<'_>) -> Result<Self> {
102 let itemfn: syn::ItemFn = input.parse()?;
103 Ok(UserFunctionAttr::from(&itemfn.sig))
104 }
105}
106
107impl From<&syn::Signature> for UserFunctionAttr {
108 fn from(sig: &syn::Signature) -> Self {
109 let (return_type_kind, iterator_item_kind, core_return_type) = match &sig.output {
110 syn::ReturnType::Default => (ReturnTypeKind::T, None, "()".into()),
111 syn::ReturnType::Type(_, ty) => {
112 let (kind, inner) = check_type(ty);
113 match strip_iterator(inner) {
114 Some(ty) => {
115 let (inner_kind, inner) = check_type(ty);
116 (kind, Some(inner_kind), inner.to_token_stream().to_string())
117 }
118 None => (kind, None, inner.to_token_stream().to_string()),
119 }
120 }
121 };
122 UserFunctionAttr {
123 name: sig.ident.to_string(),
124 async_: sig.asyncness.is_some(),
125 writer_type_kind: sig.inputs.last().and_then(arg_writer_type),
126 context: sig.inputs.iter().any(arg_is_context),
127 retract: last_arg_is_retract(sig),
128 args_option: sig.inputs.iter().map(arg_is_option).collect(),
129 first_mut_ref_arg: first_mut_ref_arg(sig),
130 return_type_kind,
131 iterator_item_kind,
132 core_return_type,
133 generic: sig.generics.params.len(),
134 return_type_span: sig.output.span(),
135 }
136 }
137}
138
139impl Parse for AggregateImpl {
140 fn parse(input: ParseStream<'_>) -> Result<Self> {
141 let itemimpl: syn::ItemImpl = input.parse()?;
142 let parse_function = |name: &str| {
143 itemimpl.items.iter().find_map(|item| match item {
144 syn::ImplItem::Fn(syn::ImplItemFn { sig, .. }) if sig.ident == name => {
145 Some(UserFunctionAttr::from(sig))
146 }
147 _ => None,
148 })
149 };
150 let self_path = itemimpl.self_ty.to_token_stream().to_string();
151 let struct_name = match self_path.split_once('<') {
152 Some((path, _)) => path.trim().into(), None => self_path,
154 };
155 Ok(AggregateImpl {
156 struct_name,
157 accumulate: parse_function("accumulate").expect("expect accumulate function"),
158 retract: parse_function("retract"),
159 merge: parse_function("merge"),
160 finalize: parse_function("finalize"),
161 create_state: parse_function("create_state"),
162 encode_state: parse_function("encode_state"),
163 decode_state: parse_function("decode_state"),
164 })
165 }
166}
167
168impl Parse for AggregateFnOrImpl {
169 fn parse(input: ParseStream<'_>) -> Result<Self> {
170 let _ = input.call(syn::Attribute::parse_outer)?;
172 if input.peek(Token![impl]) {
173 Ok(AggregateFnOrImpl::Impl(input.parse()?))
174 } else {
175 Ok(AggregateFnOrImpl::Fn(input.parse()?))
176 }
177 }
178}
179
180fn arg_writer_type(arg: &syn::FnArg) -> Option<WriterTypeKind> {
182 let syn::FnArg::Typed(arg) = arg else {
183 return None;
184 };
185 let syn::Type::Reference(syn::TypeReference { elem, .. }) = arg.ty.as_ref() else {
186 return None;
187 };
188 let elem = elem.as_ref();
189 if elem == &parse_quote!(impl std::fmt::Write) {
190 Some(WriterTypeKind::FmtWrite)
191 } else if elem == &parse_quote!(impl std::io::Write) {
192 Some(WriterTypeKind::IoWrite)
193 } else if elem == &parse_quote!(impl Write) {
194 abort! { elem, "use of ambiguous `Write` trait.";
195 note = "`function!` macro can only recognize fully-qualified type.";
196 help = "Please use `std::fmt::Write` or `std::io::Write` instead."
197 };
198 } else if elem == &parse_quote!(jsonbb::Builder) {
199 Some(WriterTypeKind::JsonbbBuilder)
200 } else if elem == &parse_quote!(Builder) {
201 abort! { elem, "use of ambiguous `Builder` type.";
202 note = "`function!` macro can only recognize fully-qualified type.";
203 help = "Please use `jsonbb::Builder` instead."
204 };
205 } else if elem == &parse_quote!(impl risingwave_common::array::ListWrite) {
206 Some(WriterTypeKind::ListWrite)
207 } else {
208 None
209 }
210}
211
212fn arg_is_context(arg: &syn::FnArg) -> bool {
214 let syn::FnArg::Typed(arg) = arg else {
215 return false;
216 };
217 let syn::Type::Reference(syn::TypeReference { elem, .. }) = arg.ty.as_ref() else {
218 return false;
219 };
220 let syn::Type::Path(path) = elem.as_ref() else {
221 return false;
222 };
223 let Some(seg) = path.path.segments.last() else {
224 return false;
225 };
226 seg.ident == "Context"
227}
228
229fn last_arg_is_retract(sig: &syn::Signature) -> bool {
231 let Some(syn::FnArg::Typed(arg)) = sig.inputs.last() else {
232 return false;
233 };
234 let syn::Pat::Ident(pat) = &*arg.pat else {
235 return false;
236 };
237 pat.ident.to_string().contains("retract")
238}
239
240fn arg_is_option(arg: &syn::FnArg) -> bool {
242 let syn::FnArg::Typed(arg) = arg else {
243 return false;
244 };
245 let syn::Type::Path(path) = arg.ty.as_ref() else {
246 return false;
247 };
248 let Some(seg) = path.path.segments.last() else {
249 return false;
250 };
251 seg.ident == "Option"
252}
253
254fn first_mut_ref_arg(sig: &syn::Signature) -> Option<String> {
256 let arg = match sig.inputs.first()? {
257 syn::FnArg::Typed(arg) => arg,
258 syn::FnArg::Receiver(_) => match sig.inputs.iter().nth(1)? {
259 syn::FnArg::Typed(arg) => arg,
260 _ => return None,
261 },
262 };
263 let syn::Type::Reference(syn::TypeReference {
264 elem,
265 mutability: Some(_),
266 ..
267 }) = arg.ty.as_ref()
268 else {
269 return None;
270 };
271 Some(elem.to_token_stream().to_string())
272}
273
274fn check_type(ty: &syn::Type) -> (ReturnTypeKind, &syn::Type) {
276 if let Some(inner) = strip_outer_type(ty, "Result") {
277 if let Some(inner) = strip_outer_type(inner, "Option") {
278 (ReturnTypeKind::ResultOption, inner)
279 } else {
280 (ReturnTypeKind::Result, inner)
281 }
282 } else if let Some(inner) = strip_outer_type(ty, "Option") {
283 (ReturnTypeKind::Option, inner)
284 } else if let Some(inner) = strip_outer_type(ty, "DatumRef") {
285 (ReturnTypeKind::Option, inner)
286 } else {
287 (ReturnTypeKind::T, ty)
288 }
289}
290
291fn strip_outer_type<'a>(ty: &'a syn::Type, type_: &str) -> Option<&'a syn::Type> {
293 let syn::Type::Path(path) = ty else {
294 return None;
295 };
296 let seg = path.path.segments.last()?;
297 if seg.ident != type_ {
298 return None;
299 }
300 let syn::PathArguments::AngleBracketed(args) = &seg.arguments else {
301 return None;
302 };
303 let Some(syn::GenericArgument::Type(ty)) = args.args.first() else {
304 return None;
305 };
306 Some(ty)
307}
308
309fn strip_iterator(ty: &syn::Type) -> Option<&syn::Type> {
311 let syn::Type::ImplTrait(impl_trait) = ty else {
312 return None;
313 };
314 let syn::TypeParamBound::Trait(trait_bound) = impl_trait.bounds.first()? else {
315 return None;
316 };
317 let segment = trait_bound.path.segments.last().unwrap();
318 if segment.ident != "Iterator" {
319 return None;
320 }
321 let syn::PathArguments::AngleBracketed(angle_bracketed) = &segment.arguments else {
322 return None;
323 };
324 for arg in &angle_bracketed.args {
325 if let syn::GenericArgument::AssocType(b) = arg
326 && b.ident == "Item"
327 {
328 return Some(&b.ty);
329 }
330 }
331 None
332}