risingwave_expr_macro/
lib.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![feature(let_chains)]
16
17use std::vec;
18
19use context::{CaptureContextAttr, DefineContextAttr, generate_captured_function};
20use proc_macro::TokenStream;
21use proc_macro2::TokenStream as TokenStream2;
22use syn::{Error, ItemFn, Result};
23
24mod context;
25mod r#gen;
26mod parse;
27mod types;
28mod utils;
29
/// Defines a RisingWave SQL function from a Rust function.
31///
32/// [Online version of this doc.](https://risingwavelabs.github.io/risingwave/rustdoc/risingwave_expr_macro/attr.function.html)
33///
34/// # Table of Contents
35///
36/// - [SQL Function Signature](#sql-function-signature)
37///     - [Multiple Function Definitions](#multiple-function-definitions)
38///     - [Type Expansion](#type-expansion)
39///     - [Automatic Type Inference](#automatic-type-inference)
40///     - [Custom Type Inference Function](#custom-type-inference-function)
41/// - [Rust Function Signature](#rust-function-signature)
42///     - [Nullable Arguments](#nullable-arguments)
43///     - [Return Value](#return-value)
///     - [Optimization](#optimization)
///     - [Variadic Function](#variadic-function)
46///     - [Functions Returning Strings](#functions-returning-strings)
47///     - [Preprocessing Constant Arguments](#preprocessing-constant-arguments)
48///     - [Context](#context)
49///     - [Async Function](#async-function)
50/// - [Table Function](#table-function)
51/// - [Registration and Invocation](#registration-and-invocation)
52/// - [Appendix: Type Matrix](#appendix-type-matrix)
53///
54/// The following example demonstrates a simple usage:
55///
56/// ```ignore
57/// #[function("add(int32, int32) -> int32")]
58/// fn add(x: i32, y: i32) -> i32 {
59///     x + y
60/// }
61/// ```
62///
63/// # SQL Function Signature
64///
65/// Each function must have a signature, specified in the `function("...")` part of the macro
66/// invocation. The signature follows this pattern:
67///
68/// ```text
69/// name ( [arg_types],* [...] ) [ -> [setof] return_type ]
70/// ```
71///
72/// Where `name` is the function name in `snake_case`, which must match the function name (in `UPPER_CASE`) defined
73/// in `proto/expr.proto`.
74///
/// `arg_types` is a comma-separated list of argument types. The allowed data types are listed
/// in the `name` column of the appendix's [type matrix]. Wildcards or `auto` can also be used, as
77/// explained below. If the function is variadic, the last argument can be denoted as `...`.
78///
79/// When `setof` appears before the return type, this indicates that the function is a set-returning
80/// function (table function), meaning it can return multiple values instead of just one. For more
81/// details, see the section on table functions.
82///
83/// If no return type is specified, the function returns `void`. However, the void type is not
84/// supported in our type system, so it now returns a null value of type int.
85///
86/// ## Multiple Function Definitions
87///
88/// Multiple `#[function]` macros can be applied to a single generic Rust function to define
89/// multiple SQL functions of different types. For example:
90///
91/// ```ignore
92/// #[function("add(int16, int16) -> int16")]
93/// #[function("add(int32, int32) -> int32")]
94/// #[function("add(int64, int64) -> int64")]
95/// fn add<T: Add>(x: T, y: T) -> T {
96///     x + y
97/// }
98/// ```
99///
/// ## Type Expansion
///
/// Types can be automatically expanded to multiple types using `*` wildcards. Here are some
/// examples:
103///
104/// - `*`: expands to all types.
105/// - `*int`: expands to int16, int32, int64.
106/// - `*float`: expands to float32, float64.
107///
108/// For instance, `#[function("cast(varchar) -> *int")]` will be expanded to the following three
109/// functions:
110///
111/// ```ignore
112/// #[function("cast(varchar) -> int16")]
113/// #[function("cast(varchar) -> int32")]
114/// #[function("cast(varchar) -> int64")]
115/// ```
116///
117/// Please note the difference between `*` and `any`: `*` will generate a function for each type,
118/// whereas `any` will only generate one function with a dynamic data type `Scalar`.
/// This is similar to `impl T` and `dyn T` in Rust. Using `*` performs much better than `any`,
/// but better performance alone does not make `*` the right choice: in some cases `any` is more
/// convenient. For example, in array functions, the element type of `ListValue` is
/// `Scalar(Ref)Impl`, so there is no need to convert it from/into various `T`.
123///
/// ## Automatic Type Inference
125///
126/// Correspondingly, the return type can be denoted as `auto` to be automatically inferred based on
127/// the input types. It will be inferred as the _smallest type_ that can accommodate all input types.
128///
129/// For example, `#[function("add(*int, *int) -> auto")]` will be expanded to:
130///
131/// ```ignore
132/// #[function("add(int16, int16) -> int16")]
133/// #[function("add(int16, int32) -> int32")]
134/// #[function("add(int16, int64) -> int64")]
135/// #[function("add(int32, int16) -> int32")]
136/// ...
137/// ```
138///
139/// Especially when there is only one input argument, `auto` will be inferred as the type of that
140/// argument. For example, `#[function("neg(*int) -> auto")]` will be expanded to:
141///
142/// ```ignore
143/// #[function("neg(int16) -> int16")]
144/// #[function("neg(int32) -> int32")]
145/// #[function("neg(int64) -> int64")]
146/// ```
147///
/// ## Custom Type Inference Function
149///
150/// A few functions might have a return type that dynamically changes based on the input argument
151/// types, such as `unnest`. This is mainly for composite types like `anyarray`, `struct`, and `anymap`.
152///
153/// In such cases, the `type_infer` option can be used to specify a function to infer the return
154/// type based on the input argument types. Its function signature is
155///
156/// ```ignore
157/// fn(&[DataType]) -> Result<DataType>
158/// ```
159///
160/// For example:
161///
162/// ```ignore
163/// #[function(
164///     "unnest(anyarray) -> setof any",
165///     type_infer = "|args| Ok(args[0].unnest_list())"
166/// )]
167/// ```
168///
169/// This type inference function will be invoked at the frontend (`infer_type_with_sigmap`).
170///
171/// # Rust Function Signature
172///
173/// The `#[function]` macro can handle various types of Rust functions.
174///
175/// Each argument corresponds to the *reference type* in the [type matrix].
176///
177/// The return value type can be the *reference type* or *owned type* in the [type matrix].
178///
179/// For instance:
180///
181/// ```ignore
182/// #[function("trim_array(anyarray, int32) -> anyarray")]
183/// fn trim_array(array: ListRef<'_>, n: i32) -> ListValue {...}
184/// ```
185///
186/// ## Nullable Arguments
187///
188/// The functions above will only be called when all arguments are not null.
189/// It will return null if any argument is null.
190/// If null arguments need to be considered, the `Option` type can be used:
191///
192/// ```ignore
193/// #[function("trim_array(anyarray, int32) -> anyarray")]
194/// fn trim_array(array: ListRef<'_>, n: Option<i32>) -> ListValue {...}
195/// ```
196///
197/// This function will be called when `n` is null, but not when `array` is null.
198///
/// ## Return Value
200///
201/// Similarly, the return value type can be one of the following:
202///
203/// - `T`: Indicates that a non-null value is always returned (for non-null inputs), and errors will not occur.
204/// - `Option<T>`: Indicates that a null value may be returned, but errors will not occur.
205/// - `Result<T>`: Indicates that an error may occur, but a null value will not be returned.
206/// - `Result<Option<T>>`: Indicates that a null value may be returned, and an error may also occur.
207///
208/// ## Optimization
209///
210/// When all input and output types of the function are *primitive type* (refer to the [type
211/// matrix]) and do not contain any Option or Result, the `#[function]` macro will automatically
212/// generate SIMD vectorized execution code.
213///
214/// Therefore, try to avoid returning `Option` and `Result` whenever possible.
215///
216/// ## Variadic Function
217///
/// Variadic functions accept an `impl Row` input to represent the trailing arguments.
219/// For example:
220///
221/// ```ignore
222/// #[function("concat_ws(varchar, ...) -> varchar")]
223/// fn concat_ws(sep: &str, vals: impl Row) -> Option<Box<str>> {
224///     let mut string_iter = vals.iter().flatten();
225///     // ...
226/// }
227/// ```
228///
229/// See `risingwave_common::row::Row` for more details.
230///
231/// ## Functions Returning Strings
232///
233/// For functions that return varchar types, you can also use the writer style function signature to
234/// avoid memory copying and dynamic memory allocation:
235///
236/// ```ignore
237/// #[function("trim(varchar) -> varchar")]
238/// fn trim(s: &str, writer: &mut impl Write) {
239///     writer.write_str(s.trim()).unwrap();
240/// }
241/// ```
242///
243/// If errors may be returned, then the return value should be `Result<()>`:
244///
245/// ```ignore
246/// #[function("trim(varchar) -> varchar")]
247/// fn trim(s: &str, writer: &mut impl Write) -> Result<()> {
248///     writer.write_str(s.trim()).unwrap();
249///     Ok(())
250/// }
251/// ```
252///
253/// If null values may be returned, then the return value should be `Option<()>`:
254///
255/// ```ignore
256/// #[function("trim(varchar) -> varchar")]
257/// fn trim(s: &str, writer: &mut impl Write) -> Option<()> {
258///     if s.is_empty() {
259///         None
260///     } else {
261///         writer.write_str(s.trim()).unwrap();
262///         Some(())
263///     }
264/// }
265/// ```
266///
267/// ## Preprocessing Constant Arguments
268///
269/// When some input arguments of the function are constants, they can be preprocessed to avoid
270/// calculations every time the function is called.
271///
272/// A classic use case is regular expression matching:
273///
274/// ```ignore
275/// #[function(
276///     "regexp_match(varchar, varchar, varchar) -> varchar[]",
277///     prebuild = "RegexpContext::from_pattern_flags($1, $2)?"
278/// )]
279/// fn regexp_match(text: &str, regex: &RegexpContext) -> ListValue {
280///     regex.captures(text).collect()
281/// }
282/// ```
283///
284/// The `prebuild` argument can be specified, and its value is a Rust expression `Type::method(...)`
285/// used to construct a new variable of `Type` from the input arguments of the function.
286/// Here `$1`, `$2` represent the second and third arguments of the function (indexed from 0),
287/// and their types are `&str`. In the Rust function signature, these positions of parameters will
288/// be omitted, replaced by an extra new variable at the end.
289///
290/// This macro generates two versions of the function. If all the input parameters that `prebuild`
291/// depends on are constants, it will precompute them during the build function. Otherwise, it will
292/// compute them for each input row during evaluation. This way, we support both constant and variable
293/// inputs while optimizing performance for constant inputs.
294///
295/// ## Context
296///
297/// If a function needs to obtain type information at runtime, you can add an `&Context` parameter to
298/// the function signature. For example:
299///
300/// ```ignore
301/// #[function("foo(int32) -> int64")]
302/// fn foo(a: i32, ctx: &Context) -> i64 {
303///    assert_eq!(ctx.arg_types[0], DataType::Int32);
304///    assert_eq!(ctx.return_type, DataType::Int64);
305///    // ...
306/// }
307/// ```
308///
309/// ## Async Function
310///
311/// Functions can be asynchronous.
312///
313/// ```ignore
314/// #[function("pg_sleep(float64)")]
315/// async fn pg_sleep(second: F64) {
316///     tokio::time::sleep(Duration::from_secs_f64(second.0)).await;
317/// }
318/// ```
319///
320/// Asynchronous functions will be evaluated on rows sequentially.
321///
322/// # Table Function
323///
324/// A table function is a special kind of function that can return multiple values instead of just
325/// one. Its function signature must include the `setof` keyword, and the Rust function should
326/// return an iterator of the form `impl Iterator<Item = T>` or its derived types.
327///
328/// For example:
329/// ```ignore
330/// #[function("generate_series(int32, int32) -> setof int32")]
331/// fn generate_series(start: i32, stop: i32) -> impl Iterator<Item = i32> {
332///     start..=stop
333/// }
334/// ```
335///
336/// Likewise, the return value `Iterator` can include `Option` or `Result` either internally or
337/// externally. For instance:
338///
339/// - `impl Iterator<Item = Result<T>>`
340/// - `Result<impl Iterator<Item = T>>`
341/// - `Result<impl Iterator<Item = Result<Option<T>>>>`
342///
343/// Currently, table function arguments do not support the `Option` type. That is, the function will
344/// only be invoked when all arguments are not null.
345///
346/// # Registration and Invocation
347///
348/// Every function defined by `#[function]` is automatically registered in the global function
349/// table.
350///
351/// You can build expressions through the following functions:
352///
353/// ```ignore
354/// // scalar functions
355/// risingwave_expr::expr::build(...) -> BoxedExpression
356/// risingwave_expr::expr::build_from_prost(...) -> BoxedExpression
357/// // table functions
358/// risingwave_expr::table_function::build(...) -> BoxedTableFunction
359/// risingwave_expr::table_function::build_from_prost(...) -> BoxedTableFunction
360/// ```
361///
362/// Or get their metadata through the following functions:
363///
364/// ```ignore
365/// // scalar functions
366/// risingwave_expr::sig::func::FUNC_SIG_MAP::get(...)
367/// // table functions
368/// risingwave_expr::sig::table_function::FUNC_SIG_MAP::get(...)
369/// ```
370///
371/// # Appendix: Type Matrix
372///
373/// ## Base Types
374///
375/// | name        | SQL type           | owned type    | reference type     | primitive? |
376/// | ----------- | ------------------ | ------------- | ------------------ | ---------- |
377/// | boolean     | `boolean`          | `bool`        | `bool`             | yes        |
378/// | int2        | `smallint`         | `i16`         | `i16`              | yes        |
379/// | int4        | `integer`          | `i32`         | `i32`              | yes        |
380/// | int8        | `bigint`           | `i64`         | `i64`              | yes        |
381/// | int256      | `rw_int256`        | `Int256`      | `Int256Ref<'_>`    | no         |
382/// | float4      | `real`             | `F32`         | `F32`              | yes        |
383/// | float8      | `double precision` | `F64`         | `F64`              | yes        |
384/// | decimal     | `numeric`          | `Decimal`     | `Decimal`          | yes        |
385/// | serial      | `serial`           | `Serial`      | `Serial`           | yes        |
386/// | date        | `date`             | `Date`        | `Date`             | yes        |
387/// | time        | `time`             | `Time`        | `Time`             | yes        |
388/// | timestamp   | `timestamp`        | `Timestamp`   | `Timestamp`        | yes        |
389/// | timestamptz | `timestamptz`      | `Timestamptz` | `Timestamptz`      | yes        |
390/// | interval    | `interval`         | `Interval`    | `Interval`         | yes        |
391/// | varchar     | `varchar`          | `Box<str>`    | `&str`             | no         |
392/// | bytea       | `bytea`            | `Box<[u8]>`   | `&[u8]`            | no         |
393/// | jsonb       | `jsonb`            | `JsonbVal`    | `JsonbRef<'_>`     | no         |
394/// | any         | `any`              | `ScalarImpl`  | `ScalarRefImpl<'_>`| no         |
395///
396/// ## Composite Types
397///
398/// | name                   | SQL type             | owned type    | reference type     |
399/// | ---------------------- | -------------------- | ------------- | ------------------ |
400/// | anyarray               | `any[]`              | `ListValue`   | `ListRef<'_>`      |
401/// | struct                 | `record`             | `StructValue` | `StructRef<'_>`    |
402/// | T[^1][]                | `T[]`                | `ListValue`   | `ListRef<'_>`      |
403/// | struct<`name_T`[^1], ..> | `struct<name T, ..>` | `(T, ..)`     | `(&T, ..)`         |
404///
405/// [^1]: `T` could be any base type
406///
407/// [type matrix]: #appendix-type-matrix
408#[proc_macro_attribute]
409pub fn function(attr: TokenStream, item: TokenStream) -> TokenStream {
410    fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
411        let fn_attr: FunctionAttr = syn::parse(attr)?;
412        let user_fn: UserFunctionAttr = syn::parse(item.clone())?;
413
414        let mut tokens: TokenStream2 = item.into();
415        for attr in fn_attr.expand() {
416            tokens.extend(attr.generate_function_descriptor(&user_fn, false)?);
417        }
418        Ok(tokens)
419    }
420    match inner(attr, item) {
421        Ok(tokens) => tokens.into(),
422        Err(e) => e.to_compile_error().into(),
423    }
424}
425
426/// Different from `#[function]`, which implements the `Expression` trait for a rust scalar function,
427/// `#[build_function]` is used when you already implemented `Expression` manually.
428///
429/// The expected input is a "build" function:
430/// ```ignore
431/// fn(data_type: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression>
432/// ```
433///
434/// It generates the function descriptor using the "build" function and
435/// registers the description to the `FUNC_SIG_MAP`.
436#[proc_macro_attribute]
437pub fn build_function(attr: TokenStream, item: TokenStream) -> TokenStream {
438    fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
439        let fn_attr: FunctionAttr = syn::parse(attr)?;
440        let user_fn: UserFunctionAttr = syn::parse(item.clone())?;
441
442        let mut tokens: TokenStream2 = item.into();
443        for attr in fn_attr.expand() {
444            tokens.extend(attr.generate_function_descriptor(&user_fn, true)?);
445        }
446        Ok(tokens)
447    }
448    match inner(attr, item) {
449        Ok(tokens) => tokens.into(),
450        Err(e) => e.to_compile_error().into(),
451    }
452}
453
454#[proc_macro_attribute]
455pub fn aggregate(attr: TokenStream, item: TokenStream) -> TokenStream {
456    fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
457        let fn_attr: FunctionAttr = syn::parse(attr)?;
458        let user_fn: AggregateFnOrImpl = syn::parse(item.clone())?;
459
460        let mut tokens: TokenStream2 = item.into();
461        for attr in fn_attr.expand() {
462            tokens.extend(attr.generate_aggregate_descriptor(&user_fn, false)?);
463        }
464        Ok(tokens)
465    }
466    match inner(attr, item) {
467        Ok(tokens) => tokens.into(),
468        Err(e) => e.to_compile_error().into(),
469    }
470}
471
472#[proc_macro_attribute]
473pub fn build_aggregate(attr: TokenStream, item: TokenStream) -> TokenStream {
474    fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
475        let fn_attr: FunctionAttr = syn::parse(attr)?;
476        let user_fn: AggregateFnOrImpl = syn::parse(item.clone())?;
477
478        let mut tokens: TokenStream2 = item.into();
479        for attr in fn_attr.expand() {
480            tokens.extend(attr.generate_aggregate_descriptor(&user_fn, true)?);
481        }
482        Ok(tokens)
483    }
484    match inner(attr, item) {
485        Ok(tokens) => tokens.into(),
486        Err(e) => e.to_compile_error().into(),
487    }
488}
489
/// Options parsed from the arguments of the `#[function]`, `#[build_function]`, `#[aggregate]`
/// and `#[build_aggregate]` attributes.
#[derive(Clone, Debug, Default)]
struct FunctionAttr {
    /// SQL function name.
    name: String,
    /// Argument type names, in declaration order.
    args: Vec<String>,
    /// Return type name.
    ret: String,
    /// `true` for a table (set-returning) function.
    is_table_function: bool,
    /// `true` for an append-only aggregate function.
    append_only: bool,
    /// Optional function used for batch evaluation.
    batch_fn: Option<String>,
    /// State type of an aggregate function; the return type is used when this is unset.
    state: Option<String>,
    /// Initial state value of an aggregate function; NULL when this is unset.
    init_state: Option<String>,
    /// Arbitrary Rust expression used to prebuild some of the arguments.
    prebuild: Option<String>,
    /// Custom return-type inference function.
    type_infer: Option<String>,
    /// Generic type.
    generic: Option<String>,
    /// `true` if the function is volatile.
    volatile: bool,
    /// `true` if the function is unavailable on the frontend.
    deprecated: bool,
    /// `true` if only the signature is defined and the backend provides no implementation.
    rewritten: bool,
}
524
525/// Attributes from function signature `fn(..)`
526#[derive(Debug, Clone)]
527struct UserFunctionAttr {
528    /// Function name
529    name: String,
530    /// Whether the function is async.
531    async_: bool,
532    /// Whether contains argument `&Context`.
533    context: bool,
534    /// Whether contains argument `&mut impl Write`.
535    write: bool,
536    /// Whether the last argument type is `retract: bool`.
537    retract: bool,
538    /// Whether each argument type is `Option<T>`.
539    args_option: Vec<bool>,
540    /// If the first argument type is `&mut T`, then `Some(T)`.
541    first_mut_ref_arg: Option<String>,
542    /// The return type kind.
543    return_type_kind: ReturnTypeKind,
544    /// The kind of inner type `T` in `impl Iterator<Item = T>`
545    iterator_item_kind: Option<ReturnTypeKind>,
546    /// The core return type without `Option` or `Result`.
547    core_return_type: String,
548    /// The number of generic types.
549    generic: usize,
550    /// The span of return type.
551    return_type_span: proc_macro2::Span,
552}
553
554#[derive(Debug, Clone)]
555struct AggregateImpl {
556    struct_name: String,
557    accumulate: UserFunctionAttr,
558    retract: Option<UserFunctionAttr>,
559    #[allow(dead_code)] // TODO(wrj): add merge to trait
560    merge: Option<UserFunctionAttr>,
561    finalize: Option<UserFunctionAttr>,
562    create_state: Option<UserFunctionAttr>,
563    #[allow(dead_code)] // TODO(wrj): support encode
564    encode_state: Option<UserFunctionAttr>,
565    #[allow(dead_code)] // TODO(wrj): support decode
566    decode_state: Option<UserFunctionAttr>,
567}
568
569#[derive(Debug, Clone)]
570#[allow(clippy::large_enum_variant)]
571enum AggregateFnOrImpl {
572    /// A simple accumulate/retract function.
573    Fn(UserFunctionAttr),
574    /// A full impl block.
575    Impl(AggregateImpl),
576}
577
578impl AggregateFnOrImpl {
579    fn as_fn(&self) -> &UserFunctionAttr {
580        match self {
581            AggregateFnOrImpl::Fn(attr) => attr,
582            _ => panic!("expect fn"),
583        }
584    }
585
586    fn accumulate(&self) -> &UserFunctionAttr {
587        match self {
588            AggregateFnOrImpl::Fn(attr) => attr,
589            AggregateFnOrImpl::Impl(impl_) => &impl_.accumulate,
590        }
591    }
592
593    fn has_retract(&self) -> bool {
594        match self {
595            AggregateFnOrImpl::Fn(fn_) => fn_.retract,
596            AggregateFnOrImpl::Impl(impl_) => impl_.retract.is_some(),
597        }
598    }
599}
600
/// How a user function's return type wraps its core value type `T`.
///
/// Declaration order matters: `Ord` is derived, so kinds compare in the order listed here.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum ReturnTypeKind {
    /// A bare `T`: never NULL, never fails.
    T,
    /// `Option<T>`: may be NULL, never fails.
    Option,
    /// `Result<T>`: never NULL, may fail.
    Result,
    /// `Result<Option<T>>`: may be NULL and may fail.
    ResultOption,
}
608
609impl FunctionAttr {
610    /// Return a unique name that can be used as an identifier.
611    fn ident_name(&self) -> String {
612        format!("{}_{}_{}", self.name, self.args.join("_"), self.ret)
613            .replace("[]", "array")
614            .replace("...", "variadic")
615            .replace(['<', '>', ' ', ','], "_")
616            .replace("__", "_")
617    }
618}
619
620impl UserFunctionAttr {
621    /// Returns true if the function is like `fn(T1, T2, .., Tn) -> T`.
622    fn is_pure(&self) -> bool {
623        !self.async_
624            && !self.write
625            && !self.context
626            && self.args_option.iter().all(|b| !b)
627            && self.return_type_kind == ReturnTypeKind::T
628    }
629}
630
631/// Define the context variables which can be used by risingwave expressions.
632#[proc_macro]
633pub fn define_context(def: TokenStream) -> TokenStream {
634    fn inner(def: TokenStream) -> Result<TokenStream2> {
635        let attr: DefineContextAttr = syn::parse(def)?;
636        attr.r#gen()
637    }
638
639    match inner(def) {
640        Ok(tokens) => tokens.into(),
641        Err(e) => e.to_compile_error().into(),
642    }
643}
644
645/// Capture the context from the local context to the function impl.
646/// TODO: The macro will be merged to [`#[function(.., capture_context(..))]`](macro@function) later.
647///
648/// Currently, we should use the macro separately with a simple wrapper.
649#[proc_macro_attribute]
650pub fn capture_context(attr: TokenStream, item: TokenStream) -> TokenStream {
651    fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
652        let attr: CaptureContextAttr = syn::parse(attr)?;
653        let user_fn: ItemFn = syn::parse(item)?;
654
655        // Generate captured function
656        generate_captured_function(attr, user_fn)
657    }
658    match inner(attr, item) {
659        Ok(tokens) => tokens.into(),
660        Err(e) => e.to_compile_error().into(),
661    }
662}