risingwave_expr_macro/lib.rs
1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::vec;
16
17use context::{CaptureContextAttr, DefineContextAttr, generate_captured_function};
18use proc_macro::TokenStream;
19use proc_macro2::TokenStream as TokenStream2;
20use syn::{Error, ItemFn, Result};
21
22mod context;
23mod r#gen;
24mod parse;
25mod types;
26mod utils;
27
28/// Defining the RisingWave SQL function from a Rust function.
29///
30/// [Online version of this doc.](https://risingwavelabs.github.io/risingwave/rustdoc/risingwave_expr_macro/attr.function.html)
31///
32/// # Table of Contents
33///
34/// - [SQL Function Signature](#sql-function-signature)
35/// - [Multiple Function Definitions](#multiple-function-definitions)
/// - [Type Expansion](#type-expansion-with-)
/// - [Automatic Type Inference](#automatic-type-inference-with-auto)
/// - [Custom Type Inference Function](#custom-type-inference-function-with-type_infer)
/// - [Rust Function Signature](#rust-function-signature)
/// - [Nullable Arguments](#nullable-arguments)
/// - [Return Value](#return-nulls-and-errors)
42/// - [Variadic Function](#variadic-function)
43/// - [Optimization](#optimization)
44/// - [Functions Returning Strings](#functions-returning-strings)
45/// - [Preprocessing Constant Arguments](#preprocessing-constant-arguments)
46/// - [Context](#context)
47/// - [Async Function](#async-function)
48/// - [Table Function](#table-function)
49/// - [Registration and Invocation](#registration-and-invocation)
50/// - [Appendix: Type Matrix](#appendix-type-matrix)
51///
52/// The following example demonstrates a simple usage:
53///
54/// ```ignore
55/// #[function("add(int32, int32) -> int32")]
56/// fn add(x: i32, y: i32) -> i32 {
57/// x + y
58/// }
59/// ```
60///
61/// # SQL Function Signature
62///
63/// Each function must have a signature, specified in the `function("...")` part of the macro
64/// invocation. The signature follows this pattern:
65///
66/// ```text
67/// name ( [arg_types],* [...] ) [ -> [setof] return_type ]
68/// ```
69///
70/// Where `name` is the function name in `snake_case`, which must match the function name (in `UPPER_CASE`) defined
71/// in `proto/expr.proto`.
72///
/// `arg_types` is a comma-separated list of argument types. The allowed data types are listed
74/// in the `name` column of the appendix's [type matrix]. Wildcards or `auto` can also be used, as
75/// explained below. If the function is variadic, the last argument can be denoted as `...`.
76///
77/// When `setof` appears before the return type, this indicates that the function is a set-returning
78/// function (table function), meaning it can return multiple values instead of just one. For more
79/// details, see the section on table functions.
80///
81/// If no return type is specified, the function returns `void`. However, the void type is not
82/// supported in our type system, so it now returns a null value of type int.
83///
84/// ## Multiple Function Definitions
85///
86/// Multiple `#[function]` macros can be applied to a single generic Rust function to define
87/// multiple SQL functions of different types. For example:
88///
89/// ```ignore
90/// #[function("add(int16, int16) -> int16")]
91/// #[function("add(int32, int32) -> int32")]
92/// #[function("add(int64, int64) -> int64")]
93/// fn add<T: Add>(x: T, y: T) -> T {
94/// x + y
95/// }
96/// ```
97///
98/// ## Type Expansion with `*`
99///
100/// Types can be automatically expanded to multiple types using wildcards. Here are some examples:
101///
102/// - `*`: expands to all types.
103/// - `*int`: expands to int16, int32, int64.
104/// - `*float`: expands to float32, float64.
105///
106/// For instance, `#[function("cast(varchar) -> *int")]` will be expanded to the following three
107/// functions:
108///
109/// ```ignore
110/// #[function("cast(varchar) -> int16")]
111/// #[function("cast(varchar) -> int32")]
112/// #[function("cast(varchar) -> int64")]
113/// ```
114///
115/// Please note the difference between `*` and `any`: `*` will generate a function for each type,
116/// whereas `any` will only generate one function with a dynamic data type `Scalar`.
/// This is similar to `impl T` and `dyn T` in Rust. Using `*` performs much better than `any`,
/// but better performance alone does not make `*` always preferable: in some cases, using `any` is more convenient.
119/// For example, in array functions, the element type of `ListValue` is `Scalar(Ref)Impl`.
120/// It is unnecessary to convert it from/into various `T`.
121///
122/// ## Automatic Type Inference with `auto`
123///
124/// Correspondingly, the return type can be denoted as `auto` to be automatically inferred based on
125/// the input types. It will be inferred as the _smallest type_ that can accommodate all input types.
126///
127/// For example, `#[function("add(*int, *int) -> auto")]` will be expanded to:
128///
129/// ```ignore
130/// #[function("add(int16, int16) -> int16")]
131/// #[function("add(int16, int32) -> int32")]
132/// #[function("add(int16, int64) -> int64")]
133/// #[function("add(int32, int16) -> int32")]
134/// ...
135/// ```
136///
137/// Especially when there is only one input argument, `auto` will be inferred as the type of that
138/// argument. For example, `#[function("neg(*int) -> auto")]` will be expanded to:
139///
140/// ```ignore
141/// #[function("neg(int16) -> int16")]
142/// #[function("neg(int32) -> int32")]
143/// #[function("neg(int64) -> int64")]
144/// ```
145///
146/// ## Custom Type Inference Function with `type_infer`
147///
148/// A few functions might have a return type that dynamically changes based on the input argument
149/// types, such as `unnest`. This is mainly for composite types like `anyarray`, `struct`, and `anymap`.
150///
151/// In such cases, the `type_infer` option can be used to specify a function to infer the return
152/// type based on the input argument types. Its function signature is
153///
154/// ```ignore
155/// fn(&[DataType]) -> Result<DataType>
156/// ```
157///
158/// For example:
159///
160/// ```ignore
161/// #[function(
162/// "unnest(anyarray) -> setof any",
163/// type_infer = "|args| Ok(args[0].unnest_list())"
164/// )]
165/// ```
166///
167/// This type inference function will be invoked at the frontend (`infer_type_with_sigmap`).
168///
169/// # Rust Function Signature
170///
171/// The `#[function]` macro can handle various types of Rust functions.
172///
173/// Each argument corresponds to the *reference type* in the [type matrix].
174///
175/// The return value type can be the *reference type* or *owned type* in the [type matrix].
176///
177/// For instance:
178///
179/// ```ignore
180/// #[function("trim_array(anyarray, int32) -> anyarray")]
181/// fn trim_array(array: ListRef<'_>, n: i32) -> ListValue {...}
182/// ```
183///
184/// ## Nullable Arguments
185///
186/// The functions above will only be called when all arguments are not null.
187/// It will return null if any argument is null.
188/// If null arguments need to be considered, the `Option` type can be used:
189///
190/// ```ignore
191/// #[function("trim_array(anyarray, int32) -> anyarray")]
192/// fn trim_array(array: ListRef<'_>, n: Option<i32>) -> ListValue {...}
193/// ```
194///
195/// This function will be called when `n` is null, but not when `array` is null.
196///
197/// ## Return `NULL`s and Errors
198///
199/// Similarly, the return value type can be one of the following:
200///
201/// - `T`: Indicates that a non-null value is always returned (for non-null inputs), and errors will not occur.
202/// - `Option<T>`: Indicates that a null value may be returned, but errors will not occur.
203/// - `Result<T>`: Indicates that an error may occur, but a null value will not be returned.
204/// - `Result<Option<T>>`: Indicates that a null value may be returned, and an error may also occur.
205///
206/// ## Optimization
207///
208/// When all input and output types of the function are *primitive type* (refer to the [type
209/// matrix]) and do not contain any Option or Result, the `#[function]` macro will automatically
210/// generate SIMD vectorized execution code.
211///
212/// Therefore, try to avoid returning `Option` and `Result` whenever possible.
213///
214/// ## Variadic Function
215///
/// Variadic functions accept an `impl Row` input to represent trailing arguments.
217/// For example:
218///
219/// ```ignore
220/// #[function("concat_ws(varchar, ...) -> varchar")]
221/// fn concat_ws(sep: &str, vals: impl Row) -> Option<Box<str>> {
222/// let mut string_iter = vals.iter().flatten();
223/// // ...
224/// }
225/// ```
226///
227/// See `risingwave_common::row::Row` for more details.
228///
229/// ## Functions Returning Strings
230///
231/// For functions that return varchar types, you can also use the writer style function signature to
232/// avoid memory copying and dynamic memory allocation:
233///
234/// ```ignore
235/// #[function("trim(varchar) -> varchar")]
236/// fn trim(s: &str, writer: &mut impl Write) {
237/// writer.write_str(s.trim()).unwrap();
238/// }
239/// ```
240///
241/// If errors may be returned, then the return value should be `Result<()>`:
242///
243/// ```ignore
244/// #[function("trim(varchar) -> varchar")]
245/// fn trim(s: &str, writer: &mut impl Write) -> Result<()> {
246/// writer.write_str(s.trim()).unwrap();
247/// Ok(())
248/// }
249/// ```
250///
251/// If null values may be returned, then the return value should be `Option<()>`:
252///
253/// ```ignore
254/// #[function("trim(varchar) -> varchar")]
255/// fn trim(s: &str, writer: &mut impl Write) -> Option<()> {
256/// if s.is_empty() {
257/// None
258/// } else {
259/// writer.write_str(s.trim()).unwrap();
260/// Some(())
261/// }
262/// }
263/// ```
264///
265/// ## Preprocessing Constant Arguments
266///
267/// When some input arguments of the function are constants, they can be preprocessed to avoid
268/// calculations every time the function is called.
269///
270/// A classic use case is regular expression matching:
271///
272/// ```ignore
273/// #[function(
274/// "regexp_match(varchar, varchar, varchar) -> varchar[]",
275/// prebuild = "RegexpContext::from_pattern_flags($1, $2)?"
276/// )]
277/// fn regexp_match(text: &str, regex: &RegexpContext) -> ListValue {
278/// regex.captures(text).collect()
279/// }
280/// ```
281///
282/// The `prebuild` argument can be specified, and its value is a Rust expression `Type::method(...)`
283/// used to construct a new variable of `Type` from the input arguments of the function.
284/// Here `$1`, `$2` represent the second and third arguments of the function (indexed from 0),
285/// and their types are `&str`. In the Rust function signature, these positions of parameters will
286/// be omitted, replaced by an extra new variable at the end.
287///
288/// This macro generates two versions of the function. If all the input parameters that `prebuild`
289/// depends on are constants, it will precompute them during the build function. Otherwise, it will
290/// compute them for each input row during evaluation. This way, we support both constant and variable
291/// inputs while optimizing performance for constant inputs.
292///
293/// ## Context
294///
295/// If a function needs to obtain type information at runtime, you can add an `&Context` parameter to
296/// the function signature. For example:
297///
298/// ```ignore
299/// #[function("foo(int32) -> int64")]
300/// fn foo(a: i32, ctx: &Context) -> i64 {
301/// assert_eq!(ctx.arg_types[0], DataType::Int32);
302/// assert_eq!(ctx.return_type, DataType::Int64);
303/// // ...
304/// }
305/// ```
306///
307/// ## Async Function
308///
309/// Functions can be asynchronous.
310///
311/// ```ignore
312/// #[function("pg_sleep(float64)")]
313/// async fn pg_sleep(second: F64) {
314/// tokio::time::sleep(Duration::from_secs_f64(second.0)).await;
315/// }
316/// ```
317///
318/// Asynchronous functions will be evaluated on rows sequentially.
319///
320/// # Table Function
321///
322/// A table function is a special kind of function that can return multiple values instead of just
323/// one. Its function signature must include the `setof` keyword, and the Rust function should
324/// return an iterator of the form `impl Iterator<Item = T>` or its derived types.
325///
326/// For example:
327/// ```ignore
328/// #[function("generate_series(int32, int32) -> setof int32")]
329/// fn generate_series(start: i32, stop: i32) -> impl Iterator<Item = i32> {
330/// start..=stop
331/// }
332/// ```
333///
334/// Likewise, the return value `Iterator` can include `Option` or `Result` either internally or
335/// externally. For instance:
336///
337/// - `impl Iterator<Item = Result<T>>`
338/// - `Result<impl Iterator<Item = T>>`
339/// - `Result<impl Iterator<Item = Result<Option<T>>>>`
340///
341/// Currently, table function arguments do not support the `Option` type. That is, the function will
342/// only be invoked when all arguments are not null.
343///
344/// # Registration and Invocation
345///
346/// Every function defined by `#[function]` is automatically registered in the global function
347/// table.
348///
349/// You can build expressions through the following functions:
350///
351/// ```ignore
352/// // scalar functions
353/// risingwave_expr::expr::build(...) -> BoxedExpression
354/// risingwave_expr::expr::build_from_prost(...) -> BoxedExpression
355/// // table functions
356/// risingwave_expr::table_function::build(...) -> BoxedTableFunction
357/// risingwave_expr::table_function::build_from_prost(...) -> BoxedTableFunction
358/// ```
359///
360/// Or get their metadata through the following functions:
361///
362/// ```ignore
363/// // scalar functions
364/// risingwave_expr::sig::func::FUNC_SIG_MAP::get(...)
365/// // table functions
366/// risingwave_expr::sig::table_function::FUNC_SIG_MAP::get(...)
367/// ```
368///
369/// # Appendix: Type Matrix
370///
371/// ## Base Types
372///
373/// | name | SQL type | owned type | reference type | primitive? |
374/// | ----------- | ------------------ | ------------- | ------------------ | ---------- |
375/// | boolean | `boolean` | `bool` | `bool` | yes |
376/// | int2 | `smallint` | `i16` | `i16` | yes |
377/// | int4 | `integer` | `i32` | `i32` | yes |
378/// | int8 | `bigint` | `i64` | `i64` | yes |
379/// | int256 | `rw_int256` | `Int256` | `Int256Ref<'_>` | no |
380/// | float4 | `real` | `F32` | `F32` | yes |
381/// | float8 | `double precision` | `F64` | `F64` | yes |
382/// | decimal | `numeric` | `Decimal` | `Decimal` | yes |
383/// | serial | `serial` | `Serial` | `Serial` | yes |
384/// | date | `date` | `Date` | `Date` | yes |
385/// | time | `time` | `Time` | `Time` | yes |
386/// | timestamp | `timestamp` | `Timestamp` | `Timestamp` | yes |
387/// | timestamptz | `timestamptz` | `Timestamptz` | `Timestamptz` | yes |
388/// | interval | `interval` | `Interval` | `Interval` | yes |
389/// | varchar | `varchar` | `Box<str>` | `&str` | no |
390/// | bytea | `bytea` | `Box<[u8]>` | `&[u8]` | no |
391/// | jsonb | `jsonb` | `JsonbVal` | `JsonbRef<'_>` | no |
392/// | any | `any` | `ScalarImpl` | `ScalarRefImpl<'_>`| no |
393///
394/// ## Composite Types
395///
396/// | name | SQL type | owned type | reference type |
397/// | ---------------------- | -------------------- | ------------- | ------------------ |
398/// | anyarray | `any[]` | `ListValue` | `ListRef<'_>` |
399/// | struct | `record` | `StructValue` | `StructRef<'_>` |
400/// | T[^1][] | `T[]` | `ListValue` | `ListRef<'_>` |
401/// | struct<`name_T`[^1], ..> | `struct<name T, ..>` | `(T, ..)` | `(&T, ..)` |
402///
403/// [^1]: `T` could be any base type
404///
405/// [type matrix]: #appendix-type-matrix
406#[proc_macro_attribute]
407pub fn function(attr: TokenStream, item: TokenStream) -> TokenStream {
408 fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
409 let fn_attr: FunctionAttr = syn::parse(attr)?;
410 let user_fn: UserFunctionAttr = syn::parse(item.clone())?;
411
412 let mut tokens: TokenStream2 = item.into();
413 for attr in fn_attr.expand() {
414 tokens.extend(attr.generate_function_descriptor(&user_fn, false)?);
415 }
416 Ok(tokens)
417 }
418 match inner(attr, item) {
419 Ok(tokens) => tokens.into(),
420 Err(e) => e.to_compile_error().into(),
421 }
422}
423
424/// Different from `#[function]`, which implements the `Expression` trait for a rust scalar function,
425/// `#[build_function]` is used when you already implemented `Expression` manually.
426///
427/// The expected input is a "build" function:
428/// ```ignore
429/// fn(data_type: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression>
430/// ```
431///
432/// It generates the function descriptor using the "build" function and
433/// registers the description to the `FUNC_SIG_MAP`.
434#[proc_macro_attribute]
435pub fn build_function(attr: TokenStream, item: TokenStream) -> TokenStream {
436 fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
437 let fn_attr: FunctionAttr = syn::parse(attr)?;
438 let user_fn: UserFunctionAttr = syn::parse(item.clone())?;
439
440 let mut tokens: TokenStream2 = item.into();
441 for attr in fn_attr.expand() {
442 tokens.extend(attr.generate_function_descriptor(&user_fn, true)?);
443 }
444 Ok(tokens)
445 }
446 match inner(attr, item) {
447 Ok(tokens) => tokens.into(),
448 Err(e) => e.to_compile_error().into(),
449 }
450}
451
452#[proc_macro_attribute]
453pub fn aggregate(attr: TokenStream, item: TokenStream) -> TokenStream {
454 fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
455 let fn_attr: FunctionAttr = syn::parse(attr)?;
456 let user_fn: AggregateFnOrImpl = syn::parse(item.clone())?;
457
458 let mut tokens: TokenStream2 = item.into();
459 for attr in fn_attr.expand() {
460 tokens.extend(attr.generate_aggregate_descriptor(&user_fn, false)?);
461 }
462 Ok(tokens)
463 }
464 match inner(attr, item) {
465 Ok(tokens) => tokens.into(),
466 Err(e) => e.to_compile_error().into(),
467 }
468}
469
470#[proc_macro_attribute]
471pub fn build_aggregate(attr: TokenStream, item: TokenStream) -> TokenStream {
472 fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
473 let fn_attr: FunctionAttr = syn::parse(attr)?;
474 let user_fn: AggregateFnOrImpl = syn::parse(item.clone())?;
475
476 let mut tokens: TokenStream2 = item.into();
477 for attr in fn_attr.expand() {
478 tokens.extend(attr.generate_aggregate_descriptor(&user_fn, true)?);
479 }
480 Ok(tokens)
481 }
482 match inner(attr, item) {
483 Ok(tokens) => tokens.into(),
484 Err(e) => e.to_compile_error().into(),
485 }
486}
487
/// Options parsed from the arguments of `#[function(..)]`, `#[build_function(..)]`,
/// `#[aggregate(..)]` and `#[build_aggregate(..)]`.
#[derive(Debug, Clone, Default)]
struct FunctionAttr {
    /// Name of the function.
    name: String,
    /// Types of the input arguments.
    args: Vec<String>,
    /// Type of the return value.
    ret: String,
    /// `true` if this is a table function.
    is_table_function: bool,
    /// `true` if this is an append-only aggregate function.
    append_only: bool,
    /// Function used for batch evaluation, if any.
    batch_fn: Option<String>,
    /// State type of an aggregate function.
    /// When absent, the return type is used as the state type.
    state: Option<String>,
    /// Initial state value of an aggregate function.
    /// When absent, the initial state is NULL.
    init_state: Option<String>,
    /// Prebuild function for arguments.
    /// Any Rust expression is accepted here.
    prebuild: Option<String>,
    /// Function used to infer the return type.
    type_infer: Option<String>,
    /// Generic type.
    generic: Option<String>,
    /// `true` if the function is volatile.
    volatile: bool,
    /// `true` makes the function unavailable on the frontend.
    deprecated: bool,
    /// `true` if only the signature is defined and the backend has no implementation.
    rewritten: bool,
}
522
523/// Attributes from function signature `fn(..)`
524#[derive(Debug, Clone)]
525struct UserFunctionAttr {
526 /// Function name
527 name: String,
528 /// Whether the function is async.
529 async_: bool,
530 /// Whether contains argument `&Context`.
531 context: bool,
532 /// Whether contains argument `&mut impl Write`.
533 write: bool,
534 /// Whether the last argument type is `retract: bool`.
535 retract: bool,
536 /// Whether each argument type is `Option<T>`.
537 args_option: Vec<bool>,
538 /// If the first argument type is `&mut T`, then `Some(T)`.
539 first_mut_ref_arg: Option<String>,
540 /// The return type kind.
541 return_type_kind: ReturnTypeKind,
542 /// The kind of inner type `T` in `impl Iterator<Item = T>`
543 iterator_item_kind: Option<ReturnTypeKind>,
544 /// The core return type without `Option` or `Result`.
545 core_return_type: String,
546 /// The number of generic types.
547 generic: usize,
548 /// The span of return type.
549 return_type_span: proc_macro2::Span,
550}
551
552#[derive(Debug, Clone)]
553struct AggregateImpl {
554 struct_name: String,
555 accumulate: UserFunctionAttr,
556 retract: Option<UserFunctionAttr>,
557 #[allow(dead_code)] // TODO(wrj): add merge to trait
558 merge: Option<UserFunctionAttr>,
559 finalize: Option<UserFunctionAttr>,
560 create_state: Option<UserFunctionAttr>,
561 #[allow(dead_code)] // TODO(wrj): support encode
562 encode_state: Option<UserFunctionAttr>,
563 #[allow(dead_code)] // TODO(wrj): support decode
564 decode_state: Option<UserFunctionAttr>,
565}
566
567#[derive(Debug, Clone)]
568#[allow(clippy::large_enum_variant)]
569enum AggregateFnOrImpl {
570 /// A simple accumulate/retract function.
571 Fn(UserFunctionAttr),
572 /// A full impl block.
573 Impl(AggregateImpl),
574}
575
576impl AggregateFnOrImpl {
577 fn as_fn(&self) -> &UserFunctionAttr {
578 match self {
579 AggregateFnOrImpl::Fn(attr) => attr,
580 _ => panic!("expect fn"),
581 }
582 }
583
584 fn accumulate(&self) -> &UserFunctionAttr {
585 match self {
586 AggregateFnOrImpl::Fn(attr) => attr,
587 AggregateFnOrImpl::Impl(impl_) => &impl_.accumulate,
588 }
589 }
590
591 fn has_retract(&self) -> bool {
592 match self {
593 AggregateFnOrImpl::Fn(fn_) => fn_.retract,
594 AggregateFnOrImpl::Impl(impl_) => impl_.retract.is_some(),
595 }
596 }
597}
598
/// How the value returned by a user function is wrapped.
///
/// The variant order is significant because `PartialOrd`/`Ord` are derived.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
enum ReturnTypeKind {
    /// `T`: no `Option` and no `Result` around the value.
    T,
    /// `Option<T>`: the value may be null.
    Option,
    /// `Result<T>`: an error may occur.
    Result,
    /// `Result<Option<T>>`: the value may be null and an error may occur.
    ResultOption,
}
606
607impl FunctionAttr {
608 /// Return a unique name that can be used as an identifier.
609 fn ident_name(&self) -> String {
610 format!("{}_{}_{}", self.name, self.args.join("_"), self.ret)
611 .replace("[]", "array")
612 .replace("...", "variadic")
613 .replace(['<', '>', ' ', ','], "_")
614 .replace("__", "_")
615 }
616}
617
618impl UserFunctionAttr {
619 /// Returns true if the function is like `fn(T1, T2, .., Tn) -> T`.
620 fn is_pure(&self) -> bool {
621 !self.async_
622 && !self.write
623 && !self.context
624 && self.args_option.iter().all(|b| !b)
625 && self.return_type_kind == ReturnTypeKind::T
626 }
627}
628
629/// Define the context variables which can be used by risingwave expressions.
630#[proc_macro]
631pub fn define_context(def: TokenStream) -> TokenStream {
632 fn inner(def: TokenStream) -> Result<TokenStream2> {
633 let attr: DefineContextAttr = syn::parse(def)?;
634 attr.r#gen()
635 }
636
637 match inner(def) {
638 Ok(tokens) => tokens.into(),
639 Err(e) => e.to_compile_error().into(),
640 }
641}
642
643/// Capture the context from the local context to the function impl.
644/// TODO: The macro will be merged to [`#[function(.., capture_context(..))]`](macro@function) later.
645///
646/// Currently, we should use the macro separately with a simple wrapper.
647#[proc_macro_attribute]
648pub fn capture_context(attr: TokenStream, item: TokenStream) -> TokenStream {
649 fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
650 let attr: CaptureContextAttr = syn::parse(attr)?;
651 let user_fn: ItemFn = syn::parse(item)?;
652
653 // Generate captured function
654 generate_captured_function(attr, user_fn)
655 }
656 match inner(attr, item) {
657 Ok(tokens) => tokens.into(),
658 Err(e) => e.to_compile_error().into(),
659 }
660}