risingwave_expr_macro/
lib.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::vec;
16
17use context::{CaptureContextAttr, DefineContextAttr, generate_captured_function};
18use proc_macro::TokenStream;
19use proc_macro_error::proc_macro_error;
20use proc_macro2::TokenStream as TokenStream2;
21use syn::{Error, ItemFn, Result};
22
23mod context;
24mod r#gen;
25mod parse;
26mod types;
27mod utils;
28
29/// Defining the RisingWave SQL function from a Rust function.
30///
31/// [Online version of this doc.](https://risingwavelabs.github.io/risingwave/rustdoc/risingwave_expr_macro/attr.function.html)
32///
33/// # Table of Contents
34///
35/// - [SQL Function Signature](#sql-function-signature)
36///     - [Multiple Function Definitions](#multiple-function-definitions)
///     - [Type Expansion](#type-expansion-with-)
///     - [Automatic Type Inference](#automatic-type-inference-with-auto)
///     - [Custom Type Inference Function](#custom-type-inference-function-with-type_infer)
40/// - [Rust Function Signature](#rust-function-signature)
41///     - [Nullable Arguments](#nullable-arguments)
///     - [Return Value](#return-nulls-and-errors)
43///     - [Variadic Function](#variadic-function)
44///     - [Optimization](#optimization)
45///     - [Writer Style Function](#writer-style-function)
46///     - [Preprocessing Constant Arguments](#preprocessing-constant-arguments)
47///     - [Context](#context)
48///     - [Async Function](#async-function)
49/// - [Table Function](#table-function)
50/// - [Registration and Invocation](#registration-and-invocation)
51/// - [Appendix: Type Matrix](#appendix-type-matrix)
52///
53/// The following example demonstrates a simple usage:
54///
55/// ```ignore
56/// #[function("add(int32, int32) -> int32")]
57/// fn add(x: i32, y: i32) -> i32 {
58///     x + y
59/// }
60/// ```
61///
62/// # SQL Function Signature
63///
64/// Each function must have a signature, specified in the `function("...")` part of the macro
65/// invocation. The signature follows this pattern:
66///
67/// ```text
68/// name ( [arg_types],* [...] ) [ -> [setof] return_type ]
69/// ```
70///
71/// Where `name` is the function name in `snake_case`, which must match the function name (in `UPPER_CASE`) defined
72/// in `proto/expr.proto`.
73///
/// `arg_types` is a comma-separated list of argument types. The allowed data types are listed in
/// the `name` column of the appendix's [type matrix]. Wildcards or `auto` can also be used, as
76/// explained below. If the function is variadic, the last argument can be denoted as `...`.
77///
78/// When `setof` appears before the return type, this indicates that the function is a set-returning
79/// function (table function), meaning it can return multiple values instead of just one. For more
80/// details, see the section on table functions.
81///
82/// If no return type is specified, the function returns `void`. However, the void type is not
83/// supported in our type system, so it now returns a null value of type int.
84///
85/// ## Multiple Function Definitions
86///
87/// Multiple `#[function]` macros can be applied to a single generic Rust function to define
88/// multiple SQL functions of different types. For example:
89///
90/// ```ignore
91/// #[function("add(int16, int16) -> int16")]
92/// #[function("add(int32, int32) -> int32")]
93/// #[function("add(int64, int64) -> int64")]
94/// fn add<T: Add>(x: T, y: T) -> T {
95///     x + y
96/// }
97/// ```
98///
99/// ## Type Expansion with `*`
100///
101/// Types can be automatically expanded to multiple types using wildcards. Here are some examples:
102///
103/// - `*`: expands to all types.
104/// - `*int`: expands to int16, int32, int64.
105/// - `*float`: expands to float32, float64.
106///
107/// For instance, `#[function("cast(varchar) -> *int")]` will be expanded to the following three
108/// functions:
109///
110/// ```ignore
111/// #[function("cast(varchar) -> int16")]
112/// #[function("cast(varchar) -> int32")]
113/// #[function("cast(varchar) -> int64")]
114/// ```
115///
116/// Please note the difference between `*` and `any`: `*` will generate a function for each type,
117/// whereas `any` will only generate one function with a dynamic data type `Scalar`.
118/// This is similar to `impl T` and `dyn T` in Rust. The performance of using `*` would be much better than `any`.
/// However, `*` is not always preferable despite its better performance. In some cases, using `any` is more convenient.
120/// For example, in array functions, the element type of `ListValue` is `Scalar(Ref)Impl`.
121/// It is unnecessary to convert it from/into various `T`.
122///
123/// ## Automatic Type Inference with `auto`
124///
125/// Correspondingly, the return type can be denoted as `auto` to be automatically inferred based on
126/// the input types. It will be inferred as the _smallest type_ that can accommodate all input types.
127///
128/// For example, `#[function("add(*int, *int) -> auto")]` will be expanded to:
129///
130/// ```ignore
131/// #[function("add(int16, int16) -> int16")]
132/// #[function("add(int16, int32) -> int32")]
133/// #[function("add(int16, int64) -> int64")]
134/// #[function("add(int32, int16) -> int32")]
135/// ...
136/// ```
137///
138/// Especially when there is only one input argument, `auto` will be inferred as the type of that
139/// argument. For example, `#[function("neg(*int) -> auto")]` will be expanded to:
140///
141/// ```ignore
142/// #[function("neg(int16) -> int16")]
143/// #[function("neg(int32) -> int32")]
144/// #[function("neg(int64) -> int64")]
145/// ```
146///
147/// ## Custom Type Inference Function with `type_infer`
148///
149/// A few functions might have a return type that dynamically changes based on the input argument
150/// types, such as `unnest`. This is mainly for composite types like `anyarray`, `struct`, and `anymap`.
151///
152/// In such cases, the `type_infer` option can be used to specify a function to infer the return
153/// type based on the input argument types. Its function signature is
154///
155/// ```ignore
156/// fn(&[DataType]) -> Result<DataType>
157/// ```
158///
159/// For example:
160///
161/// ```ignore
162/// #[function(
163///     "unnest(anyarray) -> setof any",
164///     type_infer = "|args| Ok(args[0].unnest_list())"
165/// )]
166/// ```
167///
168/// This type inference function will be invoked at the frontend (`infer_type_with_sigmap`).
169///
170/// # Rust Function Signature
171///
172/// The `#[function]` macro can handle various types of Rust functions.
173///
174/// Each argument corresponds to the *reference type* in the [type matrix].
175///
176/// The return value type can be the *reference type* or *owned type* in the [type matrix].
177///
178/// For instance:
179///
180/// ```ignore
181/// #[function("trim_array(anyarray, int32) -> anyarray")]
182/// fn trim_array(array: ListRef<'_>, n: i32) -> ListValue {...}
183/// ```
184///
185/// ## Nullable Arguments
186///
187/// The functions above will only be called when all arguments are not null.
188/// It will return null if any argument is null.
189/// If null arguments need to be considered, the `Option` type can be used:
190///
191/// ```ignore
192/// #[function("trim_array(anyarray, int32) -> anyarray")]
193/// fn trim_array(array: ListRef<'_>, n: Option<i32>) -> ListValue {...}
194/// ```
195///
196/// This function will be called when `n` is null, but not when `array` is null.
197///
198/// ## Return `NULL`s and Errors
199///
200/// Similarly, the return value type can be one of the following:
201///
202/// - `T`: Indicates that a non-null value is always returned (for non-null inputs), and errors will not occur.
203/// - `Option<T>`: Indicates that a null value may be returned, but errors will not occur.
204/// - `Result<T>`: Indicates that an error may occur, but a null value will not be returned.
205/// - `Result<Option<T>>`: Indicates that a null value may be returned, and an error may also occur.
206///
207/// ## Optimization
208///
209/// When all input and output types of the function are *primitive type* (refer to the [type
210/// matrix]) and do not contain any Option or Result, the `#[function]` macro will automatically
211/// generate SIMD vectorized execution code.
212///
213/// Therefore, try to avoid returning `Option` and `Result` whenever possible.
214///
215/// ## Variadic Function
216///
/// Variadic functions accept an `impl Row` input to represent trailing arguments.
218/// For example:
219///
220/// ```ignore
221/// #[function("concat_ws(varchar, ...) -> varchar")]
222/// fn concat_ws(sep: &str, vals: impl Row) -> Option<Box<str>> {
223///     let mut string_iter = vals.iter().flatten();
224///     // ...
225/// }
226/// ```
227///
228/// See `risingwave_common::row::Row` for more details.
229///
230/// ## Writer Style Function
231///
232/// For functions that return large or variable-length values (varchar, bytea, jsonb, anyarray),
233/// prefer a writer-style signature to avoid extra allocations and copying.
234///
235/// The evaluation framework uses builders and per-row writers:
236/// - Allocate a column builder for the result once.
237/// - For each row, create a writer backed by the builder (no per-row heap alloc).
238/// - Call the function with the writer so it writes the output directly into the builder.
239/// - If the call succeeds, finalize the writer; on error or null, rollback and append NULL.
240///
241/// This pattern minimizes heap allocations and copies for these types.
242///
243/// ```ignore
244/// #[function("trim(varchar) -> varchar")]
245/// fn trim(s: &str, writer: &mut impl std::fmt::Write) {
246///     writer.write_str(s.trim()).unwrap();
247/// }
248/// ```
249///
250/// If the function may return an error, use `Result<()>`:
251///
252/// ```ignore
253/// #[function("trim(varchar) -> varchar")]
254/// fn trim(s: &str, writer: &mut impl std::fmt::Write) -> Result<()> {
255///     writer.write_str(s.trim()).unwrap();
256///     Ok(())
257/// }
258/// ```
259///
260/// If the function may return NULL, use `Option<()>`:
261///
262/// ```ignore
263/// #[function("trim(varchar) -> varchar")]
264/// fn trim(s: &str, writer: &mut impl std::fmt::Write) -> Option<()> {
265///     if s.is_empty() {
266///         None
267///     } else {
268///         writer.write_str(s.trim()).unwrap();
269///         Some(())
270///     }
271/// }
272/// ```
273///
274/// Writer types:
275/// - For `varchar`: `&mut impl std::fmt::Write`
276/// - For `bytea`: `&mut impl std::io::Write`
277/// - For `jsonb`: `&mut jsonbb::Builder`
278/// - For `anyarray`: `&mut impl risingwave_common::array::ListWrite`
279///
280/// Note: Use fully-qualified trait paths (for example, `impl std::fmt::Write`).
281/// Partial or relative paths (such as `impl Write` or `impl ::std::fmt::Write`) are not recognized.
282///
283/// ## Preprocessing Constant Arguments
284///
285/// When some input arguments of the function are constants, they can be preprocessed to avoid
286/// calculations every time the function is called.
287///
288/// A classic use case is regular expression matching:
289///
290/// ```ignore
291/// #[function(
292///     "regexp_match(varchar, varchar, varchar) -> varchar[]",
293///     prebuild = "RegexpContext::from_pattern_flags($1, $2)?"
294/// )]
295/// fn regexp_match(text: &str, regex: &RegexpContext) -> ListValue {
296///     regex.captures(text).collect()
297/// }
298/// ```
299///
300/// The `prebuild` argument can be specified, and its value is a Rust expression `Type::method(...)`
301/// used to construct a new variable of `Type` from the input arguments of the function.
302/// Here `$1`, `$2` represent the second and third arguments of the function (indexed from 0),
303/// and their types are `&str`. In the Rust function signature, these positions of parameters will
304/// be omitted, replaced by an extra new variable at the end.
305///
306/// This macro generates two versions of the function. If all the input parameters that `prebuild`
307/// depends on are constants, it will precompute them during the build function. Otherwise, it will
308/// compute them for each input row during evaluation. This way, we support both constant and variable
309/// inputs while optimizing performance for constant inputs.
310///
311/// ## Context
312///
/// If a function needs to obtain type information at runtime, you can add a `&Context` parameter to
314/// the function signature. For example:
315///
316/// ```ignore
317/// #[function("foo(int32) -> int64")]
318/// fn foo(a: i32, ctx: &Context) -> i64 {
319///    assert_eq!(ctx.arg_types[0], DataType::Int32);
320///    assert_eq!(ctx.return_type, DataType::Int64);
321///    // ...
322/// }
323/// ```
324///
325/// ## Async Function
326///
327/// Functions can be asynchronous.
328///
329/// ```ignore
330/// #[function("pg_sleep(float64)")]
331/// async fn pg_sleep(second: F64) {
332///     tokio::time::sleep(Duration::from_secs_f64(second.0)).await;
333/// }
334/// ```
335///
336/// Asynchronous functions will be evaluated on rows sequentially.
337///
338/// # Table Function
339///
340/// A table function is a special kind of function that can return multiple values instead of just
341/// one. Its function signature must include the `setof` keyword, and the Rust function should
342/// return an iterator of the form `impl Iterator<Item = T>` or its derived types.
343///
344/// For example:
345/// ```ignore
346/// #[function("generate_series(int32, int32) -> setof int32")]
347/// fn generate_series(start: i32, stop: i32) -> impl Iterator<Item = i32> {
348///     start..=stop
349/// }
350/// ```
351///
352/// Likewise, the return value `Iterator` can include `Option` or `Result` either internally or
353/// externally. For instance:
354///
355/// - `impl Iterator<Item = Result<T>>`
356/// - `Result<impl Iterator<Item = T>>`
357/// - `Result<impl Iterator<Item = Result<Option<T>>>>`
358///
359/// Currently, table function arguments do not support the `Option` type. That is, the function will
360/// only be invoked when all arguments are not null.
361///
362/// # Registration and Invocation
363///
364/// Every function defined by `#[function]` is automatically registered in the global function
365/// table.
366///
367/// You can build expressions through the following functions:
368///
369/// ```ignore
370/// // scalar functions
371/// risingwave_expr::expr::build(...) -> BoxedExpression
372/// risingwave_expr::expr::build_from_prost(...) -> BoxedExpression
373/// // table functions
374/// risingwave_expr::table_function::build(...) -> BoxedTableFunction
375/// risingwave_expr::table_function::build_from_prost(...) -> BoxedTableFunction
376/// ```
377///
378/// Or get their metadata through the following functions:
379///
380/// ```ignore
381/// // scalar functions
382/// risingwave_expr::sig::func::FUNC_SIG_MAP::get(...)
383/// // table functions
384/// risingwave_expr::sig::table_function::FUNC_SIG_MAP::get(...)
385/// ```
386///
387/// # Appendix: Type Matrix
388///
389/// ## Base Types
390///
391/// | name        | SQL type           | owned type    | reference type     | primitive? |
392/// | ----------- | ------------------ | ------------- | ------------------ | ---------- |
393/// | boolean     | `boolean`          | `bool`        | `bool`             | yes        |
394/// | int2        | `smallint`         | `i16`         | `i16`              | yes        |
395/// | int4        | `integer`          | `i32`         | `i32`              | yes        |
396/// | int8        | `bigint`           | `i64`         | `i64`              | yes        |
397/// | int256      | `rw_int256`        | `Int256`      | `Int256Ref<'_>`    | no         |
398/// | float4      | `real`             | `F32`         | `F32`              | yes        |
399/// | float8      | `double precision` | `F64`         | `F64`              | yes        |
400/// | decimal     | `numeric`          | `Decimal`     | `Decimal`          | yes        |
401/// | serial      | `serial`           | `Serial`      | `Serial`           | yes        |
402/// | date        | `date`             | `Date`        | `Date`             | yes        |
403/// | time        | `time`             | `Time`        | `Time`             | yes        |
404/// | timestamp   | `timestamp`        | `Timestamp`   | `Timestamp`        | yes        |
405/// | timestamptz | `timestamptz`      | `Timestamptz` | `Timestamptz`      | yes        |
406/// | interval    | `interval`         | `Interval`    | `Interval`         | yes        |
407/// | varchar     | `varchar`          | `Box<str>`    | `&str`             | no         |
408/// | bytea       | `bytea`            | `Box<[u8]>`   | `&[u8]`            | no         |
409/// | jsonb       | `jsonb`            | `JsonbVal`    | `JsonbRef<'_>`     | no         |
410/// | any         | `any`              | `ScalarImpl`  | `ScalarRefImpl<'_>`| no         |
411///
412/// ## Composite Types
413///
414/// | name                   | SQL type             | owned type    | reference type     |
415/// | ---------------------- | -------------------- | ------------- | ------------------ |
416/// | anyarray               | `any[]`              | `ListValue`   | `ListRef<'_>`      |
417/// | struct                 | `record`             | `StructValue` | `StructRef<'_>`    |
418/// | T[^1][]                | `T[]`                | `ListValue`   | `ListRef<'_>`      |
419/// | struct<`name_T`[^1], ..> | `struct<name T, ..>` | `(T, ..)`     | `(&T, ..)`         |
420///
421/// [^1]: `T` could be any base type
422///
423/// [type matrix]: #appendix-type-matrix
424#[proc_macro_attribute]
425#[proc_macro_error]
426pub fn function(attr: TokenStream, item: TokenStream) -> TokenStream {
427    fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
428        let fn_attr: FunctionAttr = syn::parse(attr)?;
429        let user_fn: UserFunctionAttr = syn::parse(item.clone())?;
430
431        let mut tokens: TokenStream2 = item.into();
432        for attr in fn_attr.expand() {
433            tokens.extend(attr.generate_function_descriptor(&user_fn, false)?);
434        }
435        Ok(tokens)
436    }
437    match inner(attr, item) {
438        Ok(tokens) => tokens.into(),
439        Err(e) => e.to_compile_error().into(),
440    }
441}
442
443/// Different from `#[function]`, which implements the `Expression` trait for a rust scalar function,
444/// `#[build_function]` is used when you already implemented `Expression` manually.
445///
446/// The expected input is a "build" function:
447/// ```ignore
448/// fn(data_type: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression>
449/// ```
450///
451/// It generates the function descriptor using the "build" function and
452/// registers the description to the `FUNC_SIG_MAP`.
453#[proc_macro_attribute]
454pub fn build_function(attr: TokenStream, item: TokenStream) -> TokenStream {
455    fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
456        let fn_attr: FunctionAttr = syn::parse(attr)?;
457        let user_fn: UserFunctionAttr = syn::parse(item.clone())?;
458
459        let mut tokens: TokenStream2 = item.into();
460        for attr in fn_attr.expand() {
461            tokens.extend(attr.generate_function_descriptor(&user_fn, true)?);
462        }
463        Ok(tokens)
464    }
465    match inner(attr, item) {
466        Ok(tokens) => tokens.into(),
467        Err(e) => e.to_compile_error().into(),
468    }
469}
470
471#[proc_macro_attribute]
472pub fn aggregate(attr: TokenStream, item: TokenStream) -> TokenStream {
473    fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
474        let fn_attr: FunctionAttr = syn::parse(attr)?;
475        let user_fn: AggregateFnOrImpl = syn::parse(item.clone())?;
476
477        let mut tokens: TokenStream2 = item.into();
478        for attr in fn_attr.expand() {
479            tokens.extend(attr.generate_aggregate_descriptor(&user_fn, false)?);
480        }
481        Ok(tokens)
482    }
483    match inner(attr, item) {
484        Ok(tokens) => tokens.into(),
485        Err(e) => e.to_compile_error().into(),
486    }
487}
488
489#[proc_macro_attribute]
490pub fn build_aggregate(attr: TokenStream, item: TokenStream) -> TokenStream {
491    fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
492        let fn_attr: FunctionAttr = syn::parse(attr)?;
493        let user_fn: AggregateFnOrImpl = syn::parse(item.clone())?;
494
495        let mut tokens: TokenStream2 = item.into();
496        for attr in fn_attr.expand() {
497            tokens.extend(attr.generate_aggregate_descriptor(&user_fn, true)?);
498        }
499        Ok(tokens)
500    }
501    match inner(attr, item) {
502        Ok(tokens) => tokens.into(),
503        Err(e) => e.to_compile_error().into(),
504    }
505}
506
/// Options parsed from the arguments of the `#[function(..)]`, `#[build_function(..)]`,
/// `#[aggregate(..)]` and `#[build_aggregate(..)]` attributes.
///
/// Each macro entry point parses its attribute token stream into this type and then
/// iterates over `expand()` to generate one descriptor per resulting signature.
#[derive(Debug, Clone, Default)]
struct FunctionAttr {
    /// Function name
    name: String,
    /// Input argument types
    args: Vec<String>,
    /// Return type
    ret: String,
    /// Whether it is a table function
    is_table_function: bool,
    /// Whether it is an append-only aggregate function
    append_only: bool,
    /// Optional function for batch evaluation.
    batch_fn: Option<String>,
    /// State type for aggregate function.
    /// If not specified, it will be the same as return type.
    state: Option<String>,
    /// Initial state value for aggregate function.
    /// If not specified, it will be NULL.
    init_state: Option<String>,
    /// Prebuild function for arguments.
    /// This could be any Rust expression.
    prebuild: Option<String>,
    /// Type inference function.
    type_infer: Option<String>,
    /// Generic type.
    generic: Option<String>,
    /// Whether the function is volatile.
    volatile: bool,
    /// If true, the function is unavailable on the frontend.
    deprecated: bool,
    /// If true, the function is not implemented on the backend, but its signature is defined.
    rewritten: bool,
}
541
/// Attributes from function signature `fn(..)`
///
/// The `#[function]` and `#[build_function]` entry points parse the annotated item into
/// this type; aggregate items hold it inside `AggregateFnOrImpl`.
#[derive(Debug, Clone)]
struct UserFunctionAttr {
    /// Function name
    name: String,
    /// Whether the function is async.
    async_: bool,
    /// Whether contains argument `&Context`.
    context: bool,
    /// The writer type kind, if any, such as `impl std::fmt::Write`.
    writer_type_kind: Option<WriterTypeKind>,
    /// Whether the last argument type is `retract: bool`.
    retract: bool,
    /// Whether each argument type is `Option<T>`.
    args_option: Vec<bool>,
    /// If the first argument type is `&mut T`, then `Some(T)`.
    first_mut_ref_arg: Option<String>,
    /// The return type kind.
    return_type_kind: ReturnTypeKind,
    /// The kind of inner type `T` in `impl Iterator<Item = T>`
    iterator_item_kind: Option<ReturnTypeKind>,
    /// The core return type without `Option` or `Result`.
    core_return_type: String,
    /// The number of generic types.
    generic: usize,
    /// The span of return type.
    return_type_span: proc_macro2::Span,
}
570
/// Signatures collected from a full aggregate `impl` block
/// (the `AggregateFnOrImpl::Impl` case).
///
/// Only `accumulate` is mandatory; every other method is optional.
#[derive(Debug, Clone)]
struct AggregateImpl {
    // NOTE(review): presumably the name of the type the `impl` block is written for —
    // confirm against the parser.
    struct_name: String,
    // Signature of the `accumulate` method (always present).
    accumulate: UserFunctionAttr,
    // Signature of the `retract` method, if any; its presence drives `has_retract()`.
    retract: Option<UserFunctionAttr>,
    #[allow(dead_code)] // TODO(wrj): add merge to trait
    merge: Option<UserFunctionAttr>,
    // Signature of the `finalize` method, if any.
    finalize: Option<UserFunctionAttr>,
    // Signature of the `create_state` method, if any.
    create_state: Option<UserFunctionAttr>,
    #[allow(dead_code)] // TODO(wrj): support encode
    encode_state: Option<UserFunctionAttr>,
    #[allow(dead_code)] // TODO(wrj): support decode
    decode_state: Option<UserFunctionAttr>,
}
585
/// The item annotated with `#[aggregate]` / `#[build_aggregate]`: either a single
/// function or a whole `impl` block.
#[derive(Debug, Clone)]
// The two variants differ in size (`AggregateImpl` holds several `UserFunctionAttr`s),
// hence the lint is silenced rather than boxing the larger variant.
#[allow(clippy::large_enum_variant)]
enum AggregateFnOrImpl {
    /// A simple accumulate/retract function.
    Fn(UserFunctionAttr),
    /// A full impl block.
    Impl(AggregateImpl),
}
594
595impl AggregateFnOrImpl {
596    fn as_fn(&self) -> &UserFunctionAttr {
597        match self {
598            AggregateFnOrImpl::Fn(attr) => attr,
599            _ => panic!("expect fn"),
600        }
601    }
602
603    fn accumulate(&self) -> &UserFunctionAttr {
604        match self {
605            AggregateFnOrImpl::Fn(attr) => attr,
606            AggregateFnOrImpl::Impl(impl_) => &impl_.accumulate,
607        }
608    }
609
610    fn has_retract(&self) -> bool {
611        match self {
612            AggregateFnOrImpl::Fn(fn_) => fn_.retract,
613            AggregateFnOrImpl::Impl(impl_) => impl_.retract.is_some(),
614        }
615    }
616}
617
/// The kind of writer parameter found in a writer-style function signature
/// (see `UserFunctionAttr::writer_type_kind`).
#[derive(Debug, Clone, Copy)]
enum WriterTypeKind {
    FmtWrite,      // std::fmt::Write (documented writer for `varchar`)
    IoWrite,       // std::io::Write (documented writer for `bytea`)
    JsonbbBuilder, // jsonbb::Builder (documented writer for `jsonb`)
    ListWrite,     // risingwave_common::array::ListWrite (documented writer for `anyarray`)
}
625
/// How the core return type `T` of a user function is wrapped.
///
/// The derived `PartialOrd`/`Ord` follow the declaration order of the variants, so the
/// variants must not be reordered without checking every comparison site.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
enum ReturnTypeKind {
    // Plain `T`: no null and no error.
    T,
    // `Option<T>`: may be null, never errors.
    Option,
    // `Result<T>`: may error, never null.
    Result,
    // `Result<Option<T>>`: may be null and may error.
    ResultOption,
}
633
634impl FunctionAttr {
635    /// Return a unique name that can be used as an identifier.
636    fn ident_name(&self) -> String {
637        format!("{}_{}_{}", self.name, self.args.join("_"), self.ret)
638            .replace("[]", "array")
639            .replace("...", "variadic")
640            .replace(['<', '>', ' ', ','], "_")
641            .replace("__", "_")
642    }
643}
644
645impl UserFunctionAttr {
646    /// Returns true if the function is like `fn(T1, T2, .., Tn) -> T`.
647    fn is_pure(&self) -> bool {
648        !self.async_
649            && self.writer_type_kind.is_none()
650            && !self.context
651            && self.args_option.iter().all(|b| !b)
652            && self.return_type_kind == ReturnTypeKind::T
653    }
654}
655
656/// Define the context variables which can be used by risingwave expressions.
657#[proc_macro]
658pub fn define_context(def: TokenStream) -> TokenStream {
659    fn inner(def: TokenStream) -> Result<TokenStream2> {
660        let attr: DefineContextAttr = syn::parse(def)?;
661        attr.r#gen()
662    }
663
664    match inner(def) {
665        Ok(tokens) => tokens.into(),
666        Err(e) => e.to_compile_error().into(),
667    }
668}
669
670/// Capture the context from the local context to the function impl.
671/// TODO: The macro will be merged to [`#[function(.., capture_context(..))]`](macro@function) later.
672///
673/// Currently, we should use the macro separately with a simple wrapper.
674#[proc_macro_attribute]
675pub fn capture_context(attr: TokenStream, item: TokenStream) -> TokenStream {
676    fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
677        let attr: CaptureContextAttr = syn::parse(attr)?;
678        let user_fn: ItemFn = syn::parse(item)?;
679
680        // Generate captured function
681        generate_captured_function(attr, user_fn)
682    }
683    match inner(attr, item) {
684        Ok(tokens) => tokens.into(),
685        Err(e) => e.to_compile_error().into(),
686    }
687}