risingwave_expr_macro/
parse.rs1use quote::ToTokens;
18use syn::parse::{Parse, ParseStream};
19use syn::spanned::Spanned;
20use syn::{LitStr, Token};
21
22use super::*;
23
24impl Parse for FunctionAttr {
25 fn parse(input: ParseStream<'_>) -> Result<Self> {
27 let mut parsed = Self::default();
28
29 let sig = input.parse::<LitStr>()?;
30 let sig_str = sig.value();
31 let (name_args, ret) = match sig_str.split_once("->") {
32 Some((name_args, ret)) => (name_args, ret),
33 None => (sig_str.as_str(), "void"),
34 };
35 let (name, args) = name_args
36 .split_once('(')
37 .ok_or_else(|| Error::new_spanned(&sig, "expected '('"))?;
38 let args = args.trim_start().trim_end_matches([')', ' ']);
39 let (is_table_function, ret) = match ret.trim_start().strip_prefix("setof") {
40 Some(s) => (true, s),
41 None => (false, ret),
42 };
43 parsed.name = name.trim().to_owned();
44 parsed.args = if args.is_empty() {
45 vec![]
46 } else {
47 args.split(',').map(|s| s.trim().to_owned()).collect()
48 };
49 parsed.ret = ret.trim().to_owned();
50 parsed.is_table_function = is_table_function;
51
52 if input.parse::<Token![,]>().is_err() {
53 return Ok(parsed);
54 }
55
56 let metas = input.parse_terminated(syn::Meta::parse, Token![,])?;
57 for meta in metas {
58 let get_value = || {
59 let kv = meta.require_name_value()?;
60 let syn::Expr::Lit(lit) = &kv.value else {
61 return Err(Error::new(kv.value.span(), "expected literal"));
62 };
63 let syn::Lit::Str(lit) = &lit.lit else {
64 return Err(Error::new(kv.value.span(), "expected string literal"));
65 };
66 Ok(lit.value())
67 };
68 if meta.path().is_ident("batch_fn") {
69 parsed.batch_fn = Some(get_value()?);
70 } else if meta.path().is_ident("state") {
71 parsed.state = Some(get_value()?);
72 } else if meta.path().is_ident("init_state") {
73 parsed.init_state = Some(get_value()?);
74 } else if meta.path().is_ident("prebuild") {
75 parsed.prebuild = Some(get_value()?);
76 } else if meta.path().is_ident("type_infer") {
77 parsed.type_infer = Some(get_value()?);
78 } else if meta.path().is_ident("generic") {
79 parsed.generic = Some(get_value()?);
80 } else if meta.path().is_ident("volatile") {
81 parsed.volatile = true;
82 } else if meta.path().is_ident("deprecated") || meta.path().is_ident("internal") {
83 parsed.deprecated = true;
84 } else if meta.path().is_ident("rewritten") {
85 parsed.rewritten = true;
86 } else if meta.path().is_ident("append_only") {
87 parsed.append_only = true;
88 } else {
89 return Err(Error::new(
90 meta.span(),
91 format!("invalid property: {:?}", meta.path()),
92 ));
93 }
94 }
95 Ok(parsed)
96 }
97}
98
99impl Parse for UserFunctionAttr {
100 fn parse(input: ParseStream<'_>) -> Result<Self> {
101 let itemfn: syn::ItemFn = input.parse()?;
102 Ok(UserFunctionAttr::from(&itemfn.sig))
103 }
104}
105
106impl From<&syn::Signature> for UserFunctionAttr {
107 fn from(sig: &syn::Signature) -> Self {
108 let (return_type_kind, iterator_item_kind, core_return_type) = match &sig.output {
109 syn::ReturnType::Default => (ReturnTypeKind::T, None, "()".into()),
110 syn::ReturnType::Type(_, ty) => {
111 let (kind, inner) = check_type(ty);
112 match strip_iterator(inner) {
113 Some(ty) => {
114 let (inner_kind, inner) = check_type(ty);
115 (kind, Some(inner_kind), inner.to_token_stream().to_string())
116 }
117 None => (kind, None, inner.to_token_stream().to_string()),
118 }
119 }
120 };
121 UserFunctionAttr {
122 name: sig.ident.to_string(),
123 async_: sig.asyncness.is_some(),
124 write: sig.inputs.iter().any(arg_is_write),
125 context: sig.inputs.iter().any(arg_is_context),
126 retract: last_arg_is_retract(sig),
127 args_option: sig.inputs.iter().map(arg_is_option).collect(),
128 first_mut_ref_arg: first_mut_ref_arg(sig),
129 return_type_kind,
130 iterator_item_kind,
131 core_return_type,
132 generic: sig.generics.params.len(),
133 return_type_span: sig.output.span(),
134 }
135 }
136}
137
138impl Parse for AggregateImpl {
139 fn parse(input: ParseStream<'_>) -> Result<Self> {
140 let itemimpl: syn::ItemImpl = input.parse()?;
141 let parse_function = |name: &str| {
142 itemimpl.items.iter().find_map(|item| match item {
143 syn::ImplItem::Fn(syn::ImplItemFn { sig, .. }) if sig.ident == name => {
144 Some(UserFunctionAttr::from(sig))
145 }
146 _ => None,
147 })
148 };
149 let self_path = itemimpl.self_ty.to_token_stream().to_string();
150 let struct_name = match self_path.split_once('<') {
151 Some((path, _)) => path.trim().into(), None => self_path,
153 };
154 Ok(AggregateImpl {
155 struct_name,
156 accumulate: parse_function("accumulate").expect("expect accumulate function"),
157 retract: parse_function("retract"),
158 merge: parse_function("merge"),
159 finalize: parse_function("finalize"),
160 create_state: parse_function("create_state"),
161 encode_state: parse_function("encode_state"),
162 decode_state: parse_function("decode_state"),
163 })
164 }
165}
166
167impl Parse for AggregateFnOrImpl {
168 fn parse(input: ParseStream<'_>) -> Result<Self> {
169 let _ = input.call(syn::Attribute::parse_outer)?;
171 if input.peek(Token![impl]) {
172 Ok(AggregateFnOrImpl::Impl(input.parse()?))
173 } else {
174 Ok(AggregateFnOrImpl::Fn(input.parse()?))
175 }
176 }
177}
178
179fn arg_is_write(arg: &syn::FnArg) -> bool {
181 let syn::FnArg::Typed(arg) = arg else {
182 return false;
183 };
184 let syn::Type::Reference(syn::TypeReference { elem, .. }) = arg.ty.as_ref() else {
185 return false;
186 };
187 let syn::Type::ImplTrait(syn::TypeImplTrait { bounds, .. }) = elem.as_ref() else {
188 return false;
189 };
190 let Some(syn::TypeParamBound::Trait(syn::TraitBound { path, .. })) = bounds.first() else {
191 return false;
192 };
193 let Some(seg) = path.segments.last() else {
194 return false;
195 };
196 seg.ident == "Write"
197}
198
199fn arg_is_context(arg: &syn::FnArg) -> bool {
201 let syn::FnArg::Typed(arg) = arg else {
202 return false;
203 };
204 let syn::Type::Reference(syn::TypeReference { elem, .. }) = arg.ty.as_ref() else {
205 return false;
206 };
207 let syn::Type::Path(path) = elem.as_ref() else {
208 return false;
209 };
210 let Some(seg) = path.path.segments.last() else {
211 return false;
212 };
213 seg.ident == "Context"
214}
215
216fn last_arg_is_retract(sig: &syn::Signature) -> bool {
218 let Some(syn::FnArg::Typed(arg)) = sig.inputs.last() else {
219 return false;
220 };
221 let syn::Pat::Ident(pat) = &*arg.pat else {
222 return false;
223 };
224 pat.ident.to_string().contains("retract")
225}
226
227fn arg_is_option(arg: &syn::FnArg) -> bool {
229 let syn::FnArg::Typed(arg) = arg else {
230 return false;
231 };
232 let syn::Type::Path(path) = arg.ty.as_ref() else {
233 return false;
234 };
235 let Some(seg) = path.path.segments.last() else {
236 return false;
237 };
238 seg.ident == "Option"
239}
240
241fn first_mut_ref_arg(sig: &syn::Signature) -> Option<String> {
243 let arg = match sig.inputs.first()? {
244 syn::FnArg::Typed(arg) => arg,
245 syn::FnArg::Receiver(_) => match sig.inputs.iter().nth(1)? {
246 syn::FnArg::Typed(arg) => arg,
247 _ => return None,
248 },
249 };
250 let syn::Type::Reference(syn::TypeReference {
251 elem,
252 mutability: Some(_),
253 ..
254 }) = arg.ty.as_ref()
255 else {
256 return None;
257 };
258 Some(elem.to_token_stream().to_string())
259}
260
261fn check_type(ty: &syn::Type) -> (ReturnTypeKind, &syn::Type) {
263 if let Some(inner) = strip_outer_type(ty, "Result") {
264 if let Some(inner) = strip_outer_type(inner, "Option") {
265 (ReturnTypeKind::ResultOption, inner)
266 } else {
267 (ReturnTypeKind::Result, inner)
268 }
269 } else if let Some(inner) = strip_outer_type(ty, "Option") {
270 (ReturnTypeKind::Option, inner)
271 } else if let Some(inner) = strip_outer_type(ty, "DatumRef") {
272 (ReturnTypeKind::Option, inner)
273 } else {
274 (ReturnTypeKind::T, ty)
275 }
276}
277
278fn strip_outer_type<'a>(ty: &'a syn::Type, type_: &str) -> Option<&'a syn::Type> {
280 let syn::Type::Path(path) = ty else {
281 return None;
282 };
283 let seg = path.path.segments.last()?;
284 if seg.ident != type_ {
285 return None;
286 }
287 let syn::PathArguments::AngleBracketed(args) = &seg.arguments else {
288 return None;
289 };
290 let Some(syn::GenericArgument::Type(ty)) = args.args.first() else {
291 return None;
292 };
293 Some(ty)
294}
295
296fn strip_iterator(ty: &syn::Type) -> Option<&syn::Type> {
298 let syn::Type::ImplTrait(impl_trait) = ty else {
299 return None;
300 };
301 let syn::TypeParamBound::Trait(trait_bound) = impl_trait.bounds.first()? else {
302 return None;
303 };
304 let segment = trait_bound.path.segments.last().unwrap();
305 if segment.ident != "Iterator" {
306 return None;
307 }
308 let syn::PathArguments::AngleBracketed(angle_bracketed) = &segment.arguments else {
309 return None;
310 };
311 for arg in &angle_bracketed.args {
312 if let syn::GenericArgument::AssocType(b) = arg
313 && b.ident == "Item"
314 {
315 return Some(&b.ty);
316 }
317 }
318 None
319}