risingwave_expr_macro/
gen.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Generate code for the functions.
16
17use itertools::Itertools;
18use proc_macro2::{Ident, Span};
19use quote::{format_ident, quote};
20
21use super::*;
22
23impl FunctionAttr {
24    /// Expands the wildcard in function arguments or return type.
25    pub fn expand(&self) -> Vec<Self> {
26        // handle variadic argument
27        if self
28            .args
29            .last()
30            .is_some_and(|arg| arg.starts_with("variadic"))
31        {
32            // expand:  foo(a, b, variadic anyarray)
33            // to:      foo(a, b, ...)
34            //        + foo_variadic(a, b, anyarray)
35            let mut attrs = Vec::new();
36            attrs.extend(
37                FunctionAttr {
38                    args: {
39                        let mut args = self.args.clone();
40                        *args.last_mut().unwrap() = "...".to_owned();
41                        args
42                    },
43                    ..self.clone()
44                }
45                .expand(),
46            );
47            attrs.extend(
48                FunctionAttr {
49                    name: format!("{}_variadic", self.name),
50                    args: {
51                        let mut args = self.args.clone();
52                        let last = args.last_mut().unwrap();
53                        *last = last.strip_prefix("variadic ").unwrap().into();
54                        args
55                    },
56                    ..self.clone()
57                }
58                .expand(),
59            );
60            return attrs;
61        }
62        let args = self.args.iter().map(|ty| types::expand_type_wildcard(ty));
63        let ret = types::expand_type_wildcard(&self.ret);
64        let mut attrs = Vec::new();
65        for (args, mut ret) in args.multi_cartesian_product().cartesian_product(ret) {
66            if ret == "auto" {
67                ret = types::min_compatible_type(&args);
68            }
69            let attr = FunctionAttr {
70                args: args.iter().map(|s| s.to_string()).collect(),
71                ret: ret.to_owned(),
72                ..self.clone()
73            };
74            attrs.push(attr);
75        }
76        attrs
77    }
78
79    /// Generate the type infer function: `fn(&[DataType]) -> Result<DataType>`
80    fn generate_type_infer_fn(&self) -> Result<TokenStream2> {
81        if let Some(func) = &self.type_infer {
82            if func == "unreachable" {
83                return Ok(
84                    quote! { |_| unreachable!("type inference for this function should be specially handled in frontend, and should not call sig.type_infer") },
85                );
86            }
87            // use the user defined type inference function
88            return Ok(func.parse().unwrap());
89        } else if self.ret == "any" {
90            // TODO: if there are multiple "any", they should be the same type
91            if let Some(i) = self.args.iter().position(|t| t == "any") {
92                // infer as the type of "any" argument
93                return Ok(quote! { |args| Ok(args[#i].clone()) });
94            }
95            if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
96                // infer as the element type of "anyarray" argument
97                return Ok(quote! { |args| Ok(args[#i].as_list_element_type().clone()) });
98            }
99        } else if self.ret == "anyarray" {
100            if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
101                // infer as the type of "anyarray" argument
102                return Ok(quote! { |args| Ok(args[#i].clone()) });
103            }
104            if let Some(i) = self.args.iter().position(|t| t == "any") {
105                // infer as the array type of "any" argument
106                return Ok(quote! { |args| Ok(DataType::List(Box::new(args[#i].clone()))) });
107            }
108        } else if self.ret == "struct" {
109            if let Some(i) = self.args.iter().position(|t| t == "struct") {
110                // infer as the type of "struct" argument
111                return Ok(quote! { |args| Ok(args[#i].clone()) });
112            }
113        } else if self.ret == "anymap" {
114            if let Some(i) = self.args.iter().position(|t| t == "anymap") {
115                // infer as the type of "anymap" argument
116                return Ok(quote! { |args| Ok(args[#i].clone()) });
117            }
118        } else if self.ret == "vector" {
119            if let Some(i) = self.args.iter().position(|t| t == "vector") {
120                // infer as the type of "vector" argument
121                // Example usage: `last_value(*) -> auto`
122                return Ok(quote! { |args| Ok(args[#i].clone()) });
123            }
124        } else {
125            // the return type is fixed
126            let ty = data_type(&self.ret);
127            return Ok(quote! { |_| Ok(#ty) });
128        }
129        Err(Error::new(
130            Span::call_site(),
131            "type inference function cannot be automatically derived. You should provide: `type_infer = \"|args| Ok(...)\"`",
132        ))
133    }
134
135    /// Generate a descriptor (`FuncSign`) of the scalar or table function.
136    ///
137    /// The types of arguments and return value should not contain wildcard.
138    ///
139    /// # Arguments
140    /// `build_fn`: whether the user provided a function is a build function.
141    /// (from the `#[build_function]` macro)
142    pub fn generate_function_descriptor(
143        &self,
144        user_fn: &UserFunctionAttr,
145        build_fn: bool,
146    ) -> Result<TokenStream2> {
147        if self.is_table_function {
148            return self.generate_table_function_descriptor(user_fn, build_fn);
149        }
150        let name = self.name.clone();
151        let variadic = matches!(self.args.last(), Some(t) if t == "...");
152        let args = match variadic {
153            true => &self.args[..self.args.len() - 1],
154            false => &self.args[..],
155        }
156        .iter()
157        .map(|ty| sig_data_type(ty))
158        .collect_vec();
159        let ret = sig_data_type(&self.ret);
160
161        let pb_type = format_ident!("{}", utils::to_camel_case(&name));
162        let ctor_name = format_ident!("{}", self.ident_name());
163        let build_fn = if build_fn {
164            let name = format_ident!("{}", user_fn.name);
165            quote! { #name }
166        } else if self.rewritten {
167            quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) }
168        } else {
169            // This is the core logic for `#[function]`
170            self.generate_build_scalar_function(user_fn, true)?
171        };
172        let type_infer_fn = self.generate_type_infer_fn()?;
173        let deprecated = self.deprecated;
174
175        Ok(quote! {
176            #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
177            fn #ctor_name() -> risingwave_expr::sig::FuncSign {
178                use risingwave_common::types::{DataType, DataTypeName};
179                use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
180
181                FuncSign {
182                    name: risingwave_pb::expr::expr_node::Type::#pb_type.into(),
183                    inputs_type: vec![#(#args),*],
184                    variadic: #variadic,
185                    ret_type: #ret,
186                    build: FuncBuilder::Scalar(#build_fn),
187                    type_infer: #type_infer_fn,
188                    deprecated: #deprecated,
189                }
190            }
191        })
192    }
193
194    /// Generate a build function for the scalar function.
195    ///
196    /// If `optimize_const` is true, the function will be optimized for constant arguments,
197    /// and fallback to the general version if any argument is not constant.
198    fn generate_build_scalar_function(
199        &self,
200        user_fn: &UserFunctionAttr,
201        optimize_const: bool,
202    ) -> Result<TokenStream2> {
203        let variadic = matches!(self.args.last(), Some(t) if t == "...");
204        let num_args = self.args.len() - if variadic { 1 } else { 0 };
205        let fn_name = format_ident!("{}", user_fn.name);
206        let struct_name = match optimize_const {
207            true => format_ident!("{}OptimizeConst", utils::to_camel_case(&self.ident_name())),
208            false => format_ident!("{}", utils::to_camel_case(&self.ident_name())),
209        };
210
211        // we divide all arguments into two groups: prebuilt and non-prebuilt.
212        // prebuilt arguments are collected from the "prebuild" field.
213        // let's say we have a function with 3 arguments: [0, 1, 2]
214        // and the prebuild field contains "$1".
215        // then we have:
216        //     prebuilt_indices = [1]
217        //     non_prebuilt_indices = [0, 2]
218        //
219        // if the const argument optimization is enabled, prebuilt arguments are
220        // evaluated at build time, thus the children only contain non-prebuilt arguments:
221        //     children_indices = [0, 2]
222        // otherwise, the children contain all arguments:
223        //     children_indices = [0, 1, 2]
224
225        let prebuilt_indices = match &self.prebuild {
226            Some(s) => (0..num_args)
227                .filter(|i| s.contains(&format!("${i}")))
228                .collect_vec(),
229            None => vec![],
230        };
231        let non_prebuilt_indices = match &self.prebuild {
232            Some(s) => (0..num_args)
233                .filter(|i| !s.contains(&format!("${i}")))
234                .collect_vec(),
235            _ => (0..num_args).collect_vec(),
236        };
237        let children_indices = match optimize_const {
238            #[allow(clippy::redundant_clone)] // false-positive
239            true => non_prebuilt_indices.clone(),
240            false => (0..num_args).collect_vec(),
241        };
242
243        /// Return a list of identifiers with the given prefix and indices.
244        fn idents(prefix: &str, indices: &[usize]) -> Vec<Ident> {
245            indices
246                .iter()
247                .map(|i| format_ident!("{prefix}{i}"))
248                .collect()
249        }
250        let inputs = idents("i", &children_indices);
251        let prebuilt_inputs = idents("i", &prebuilt_indices);
252        let non_prebuilt_inputs = idents("i", &non_prebuilt_indices);
253        let array_refs = idents("array", &children_indices);
254        let arrays = idents("a", &children_indices);
255        let datums = idents("v", &children_indices);
256        let arg_arrays = children_indices
257            .iter()
258            .map(|i| format_ident!("{}", types::array_type(&self.args[*i])));
259        let arg_types = children_indices.iter().map(|i| {
260            types::ref_type(&self.args[*i])
261                .parse::<TokenStream2>()
262                .unwrap()
263        });
264        let annotation: TokenStream2 = match user_fn.core_return_type.as_str() {
265            // add type annotation for functions that return generic types
266            "T" | "T1" | "T2" | "T3" => format!(": Option<{}>", types::owned_type(&self.ret))
267                .parse()
268                .unwrap(),
269            _ => quote! {},
270        };
271        let ret_array_type = format_ident!("{}", types::array_type(&self.ret));
272        let builder_type = format_ident!("{}Builder", types::array_type(&self.ret));
273        let prebuilt_arg_type = match &self.prebuild {
274            Some(s) if optimize_const => s.split("::").next().unwrap().parse().unwrap(),
275            _ => quote! { () },
276        };
277        let prebuilt_arg_value = match &self.prebuild {
278            // example:
279            // prebuild = "RegexContext::new($1)"
280            // return = "RegexContext::new(i1)"
281            Some(s) => s
282                .replace('$', "i")
283                .parse()
284                .expect("invalid prebuild syntax"),
285            None => quote! { () },
286        };
287        let prebuild_const = if self.prebuild.is_some() && optimize_const {
288            let build_general = self.generate_build_scalar_function(user_fn, false)?;
289            quote! {{
290                let build_general = #build_general;
291                #(
292                    // try to evaluate constant for prebuilt arguments
293                    let #prebuilt_inputs = match children[#prebuilt_indices].eval_const() {
294                        Ok(s) => s,
295                        // prebuilt argument is not constant, fallback to general
296                        Err(_) => return build_general(return_type, children),
297                    };
298                    // get reference to the constant value
299                    let #prebuilt_inputs = match &#prebuilt_inputs {
300                        Some(s) => s.as_scalar_ref_impl().try_into()?,
301                        // the function should always return null if any const argument is null
302                        None => return Ok(Box::new(risingwave_expr::expr::LiteralExpression::new(
303                            return_type,
304                            None,
305                        ))),
306                    };
307                )*
308                #prebuilt_arg_value
309            }}
310        } else {
311            quote! { () }
312        };
313
314        // ensure the number of children matches the number of arguments
315        let check_children = match variadic {
316            true => quote! { risingwave_expr::ensure!(children.len() >= #num_args); },
317            false => quote! { risingwave_expr::ensure!(children.len() == #num_args); },
318        };
319
320        // evaluate variadic arguments in `eval`
321        let eval_variadic = variadic.then(|| {
322            quote! {
323                let mut columns = Vec::with_capacity(self.children.len() - #num_args);
324                for child in &self.children[#num_args..] {
325                    columns.push(child.eval(input).await?);
326                }
327                let variadic_input = DataChunk::new(columns, input.visibility().clone());
328            }
329        });
330        // evaluate variadic arguments in `eval_row`
331        let eval_row_variadic = variadic.then(|| {
332            quote! {
333                let mut row = Vec::with_capacity(self.children.len() - #num_args);
334                for child in &self.children[#num_args..] {
335                    row.push(child.eval_row(input).await?);
336                }
337                let variadic_row = OwnedRow::new(row);
338            }
339        });
340
341        let generic = (self.ret == "boolean" && user_fn.generic == 3).then(|| {
342            // XXX: for generic compare functions, we need to specify the compatible type
343            let compatible_type = types::ref_type(types::min_compatible_type(&self.args))
344                .parse::<TokenStream2>()
345                .unwrap();
346            quote! { ::<_, _, #compatible_type> }
347        });
348        let prebuilt_arg = match (&self.prebuild, optimize_const) {
349            // use the prebuilt argument
350            (Some(_), true) => quote! { &self.prebuilt_arg, },
351            // build the argument on site
352            (Some(_), false) => quote! { &#prebuilt_arg_value, },
353            // no prebuilt argument
354            (None, _) => quote! {},
355        };
356        let variadic_args = variadic.then(|| quote! { &variadic_row, });
357        let context = user_fn.context.then(|| quote! { &self.context, });
358        let writer = user_fn.write.then(|| quote! { &mut writer, });
359        let await_ = user_fn.async_.then(|| quote! { .await });
360
361        let record_error = {
362            // Uniform arguments into `DatumRef`.
363            #[allow(clippy::disallowed_methods)] // allow zip
364            let inputs_args = inputs
365                .iter()
366                .zip(user_fn.args_option.iter())
367                .map(|(input, opt)| {
368                    if *opt {
369                        quote! { #input.map(|s| ScalarRefImpl::from(s)) }
370                    } else {
371                        quote! { Some(ScalarRefImpl::from(#input)) }
372                    }
373                });
374            let inputs_args = quote! {
375                let args: &[DatumRef<'_>] = &[#(#inputs_args),*];
376                let args = args.iter().copied();
377            };
378            let var_args = variadic.then(|| {
379                quote! {
380                    let args = args.chain(variadic_row.iter());
381                }
382            });
383
384            quote! {
385                #inputs_args
386                #var_args
387                errors.push(ExprError::function(
388                    stringify!(#fn_name),
389                    args,
390                    e,
391                ));
392            }
393        };
394
395        // call the user defined function
396        // inputs: [ Option<impl ScalarRef> ]
397        let mut output = quote! { #fn_name #generic(
398            #(#non_prebuilt_inputs,)*
399            #variadic_args
400            #prebuilt_arg
401            #context
402            #writer
403        ) #await_ };
404        // handle error if the function returns `Result`
405        // wrap a `Some` if the function doesn't return `Option`
406        output = match user_fn.return_type_kind {
407            // XXX: we don't support void type yet. return null::int for now.
408            _ if self.ret == "void" => quote! { { #output; Option::<i32>::None } },
409            ReturnTypeKind::T => quote! { Some(#output) },
410            ReturnTypeKind::Option => output,
411            ReturnTypeKind::Result => quote! {
412                match #output {
413                    Ok(x) => Some(x),
414                    Err(e) => {
415                        #record_error
416                        None
417                    }
418                }
419            },
420            ReturnTypeKind::ResultOption => quote! {
421                match #output {
422                    Ok(x) => x,
423                    Err(e) => {
424                        #record_error
425                        None
426                    }
427                }
428            },
429        };
430        // if user function accepts non-option arguments, we assume the function
431        // returns null on null input, so we need to unwrap the inputs before calling.
432        if self.prebuild.is_some() {
433            output = quote! {
434                match (#(#inputs,)*) {
435                    (#(Some(#inputs),)*) => #output,
436                    _ => None,
437                }
438            };
439        } else {
440            #[allow(clippy::disallowed_methods)] // allow zip
441            let some_inputs = inputs
442                .iter()
443                .zip(user_fn.args_option.iter())
444                .map(|(input, opt)| {
445                    if *opt {
446                        quote! { #input }
447                    } else {
448                        quote! { Some(#input) }
449                    }
450                });
451            output = quote! {
452                match (#(#inputs,)*) {
453                    (#(#some_inputs,)*) => #output,
454                    _ => None,
455                }
456            };
457        };
458        // now the `output` is: Option<impl ScalarRef or Scalar>
459        let append_output = match user_fn.write {
460            true => quote! {{
461                let mut writer = builder.writer().begin();
462                if #output.is_some() {
463                    writer.finish();
464                } else {
465                    drop(writer);
466                    builder.append_null();
467                }
468            }},
469            false if user_fn.core_return_type == "impl AsRef < [u8] >" => quote! {
470                builder.append(#output.as_ref().map(|s| s.as_ref()));
471            },
472            false => quote! {
473                let output #annotation = #output;
474                builder.append(output.as_ref().map(|s| s.as_scalar_ref()));
475            },
476        };
477        // the output expression in `eval_row`
478        let row_output = match user_fn.write {
479            true => quote! {{
480                let mut writer = String::new();
481                #output.map(|_| writer.into())
482            }},
483            false if user_fn.core_return_type == "impl AsRef < [u8] >" => quote! {
484                #output.map(|s| s.as_ref().into())
485            },
486            false => quote! {{
487                let output #annotation = #output;
488                output.map(|s| s.into())
489            }},
490        };
491        // the main body in `eval`
492        let eval = if let Some(batch_fn) = &self.batch_fn {
493            assert!(
494                !variadic,
495                "customized batch function is not supported for variadic functions"
496            );
497            // user defined batch function
498            let fn_name = format_ident!("{}", batch_fn);
499            quote! {
500                let c = #fn_name(#(#arrays),*);
501                Arc::new(c.into())
502            }
503        } else if (types::is_primitive(&self.ret) || self.ret == "boolean")
504            && user_fn.is_pure()
505            && !variadic
506            && self.prebuild.is_none()
507        {
508            // SIMD optimization for primitive types
509            match self.args.len() {
510                0 => quote! {
511                    let c = #ret_array_type::from_iter_bitmap(
512                        std::iter::repeat_with(|| #fn_name()).take(input.capacity())
513                        Bitmap::ones(input.capacity()),
514                    );
515                    Arc::new(c.into())
516                },
517                1 => quote! {
518                    let c = #ret_array_type::from_iter_bitmap(
519                        a0.raw_iter().map(|a| #fn_name(a)),
520                        a0.null_bitmap().clone()
521                    );
522                    Arc::new(c.into())
523                },
524                2 => quote! {
525                    // allow using `zip` for performance
526                    #[allow(clippy::disallowed_methods)]
527                    let c = #ret_array_type::from_iter_bitmap(
528                        a0.raw_iter()
529                            .zip(a1.raw_iter())
530                            .map(|(a, b)| #fn_name #generic(a, b)),
531                        a0.null_bitmap() & a1.null_bitmap(),
532                    );
533                    Arc::new(c.into())
534                },
535                n => todo!("SIMD optimization for {n} arguments"),
536            }
537        } else {
538            // no optimization
539            let let_variadic = variadic.then(|| {
540                quote! {
541                    let variadic_row = variadic_input.row_at_unchecked_vis(i);
542                }
543            });
544            quote! {
545                let mut builder = #builder_type::with_type(input.capacity(), self.context.return_type.clone());
546
547                if input.is_compacted() {
548                    for i in 0..input.capacity() {
549                        #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
550                        #let_variadic
551                        #append_output
552                    }
553                } else {
554                    for i in 0..input.capacity() {
555                        if unsafe { !input.visibility().is_set_unchecked(i) } {
556                            builder.append_null();
557                            continue;
558                        }
559                        #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
560                        #let_variadic
561                        #append_output
562                    }
563                }
564                Arc::new(builder.finish().into())
565            }
566        };
567
568        Ok(quote! {
569            |return_type: DataType, children: Vec<risingwave_expr::expr::BoxedExpression>|
570                -> risingwave_expr::Result<risingwave_expr::expr::BoxedExpression>
571            {
572                use std::sync::Arc;
573                use risingwave_common::array::*;
574                use risingwave_common::types::*;
575                use risingwave_common::bitmap::Bitmap;
576                use risingwave_common::row::OwnedRow;
577                use risingwave_common::util::iter_util::ZipEqFast;
578
579                use risingwave_expr::expr::{Context, BoxedExpression};
580                use risingwave_expr::{ExprError, Result};
581                use risingwave_expr::codegen::*;
582
583                #check_children
584                let prebuilt_arg = #prebuild_const;
585                let context = Context {
586                    return_type,
587                    arg_types: children.iter().map(|c| c.return_type()).collect(),
588                    variadic: #variadic,
589                };
590
591                #[derive(Debug)]
592                struct #struct_name {
593                    context: Context,
594                    children: Vec<BoxedExpression>,
595                    prebuilt_arg: #prebuilt_arg_type,
596                }
597                #[async_trait]
598                impl risingwave_expr::expr::Expression for #struct_name {
599                    fn return_type(&self) -> DataType {
600                        self.context.return_type.clone()
601                    }
602                    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
603                        #(
604                            let #array_refs = self.children[#children_indices].eval(input).await?;
605                            let #arrays: &#arg_arrays = #array_refs.as_ref().into();
606                        )*
607                        #eval_variadic
608                        let mut errors = vec![];
609                        let array = { #eval };
610                        if errors.is_empty() {
611                            Ok(array)
612                        } else {
613                            Err(ExprError::Multiple(array, errors.into()))
614                        }
615                    }
616                    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
617                        #(
618                            let #datums = self.children[#children_indices].eval_row(input).await?;
619                            let #inputs: Option<#arg_types> = #datums.as_ref().map(|s| s.as_scalar_ref_impl().try_into().unwrap());
620                        )*
621                        #eval_row_variadic
622                        let mut errors: Vec<ExprError> = vec![];
623                        let output = #row_output;
624                        if let Some(err) = errors.into_iter().next() {
625                            Err(err.into())
626                        } else {
627                            Ok(output)
628                        }
629                    }
630                }
631
632                Ok(Box::new(#struct_name {
633                    context,
634                    children,
635                    prebuilt_arg,
636                }))
637            }
638        })
639    }
640
641    /// Generate a descriptor of the aggregate function.
642    ///
643    /// The types of arguments and return value should not contain wildcard.
644    /// `user_fn` could be either `fn` or `impl`.
645    /// If `build_fn` is true, `user_fn` must be a `fn` that builds the aggregate function.
646    pub fn generate_aggregate_descriptor(
647        &self,
648        user_fn: &AggregateFnOrImpl,
649        build_fn: bool,
650    ) -> Result<TokenStream2> {
651        let name = self.name.clone();
652
653        let mut args = Vec::with_capacity(self.args.len());
654        for ty in &self.args {
655            args.push(sig_data_type(ty));
656        }
657        let ret = sig_data_type(&self.ret);
658        let state_type = match &self.state {
659            Some(ty) if ty != "ref" => {
660                let ty = data_type(ty);
661                quote! { Some(#ty) }
662            }
663            _ => quote! { None },
664        };
665        let append_only = match build_fn {
666            false => !user_fn.has_retract(),
667            true => self.append_only,
668        };
669
670        let pb_kind = format_ident!("{}", utils::to_camel_case(&name));
671        let ctor_name = match append_only {
672            false => format_ident!("{}", self.ident_name()),
673            true => format_ident!("{}_append_only", self.ident_name()),
674        };
675        let build_fn = if build_fn {
676            let name = format_ident!("{}", user_fn.as_fn().name);
677            quote! { #name }
678        } else if self.rewritten {
679            quote! { |_| Err(ExprError::UnsupportedFunction(#name.into())) }
680        } else {
681            self.generate_agg_build_fn(user_fn)?
682        };
683        let build_retractable = match append_only {
684            true => quote! { None },
685            false => quote! { Some(#build_fn) },
686        };
687        let build_append_only = match append_only {
688            false => quote! { None },
689            true => quote! { Some(#build_fn) },
690        };
691        let retractable_state_type = match append_only {
692            true => quote! { None },
693            false => state_type.clone(),
694        };
695        let append_only_state_type = match append_only {
696            false => quote! { None },
697            true => state_type,
698        };
699        let type_infer_fn = self.generate_type_infer_fn()?;
700        let deprecated = self.deprecated;
701
702        Ok(quote! {
703            #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
704            fn #ctor_name() -> risingwave_expr::sig::FuncSign {
705                use risingwave_common::types::{DataType, DataTypeName};
706                use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
707
708                FuncSign {
709                    name: risingwave_pb::expr::agg_call::PbKind::#pb_kind.into(),
710                    inputs_type: vec![#(#args),*],
711                    variadic: false,
712                    ret_type: #ret,
713                    build: FuncBuilder::Aggregate {
714                        retractable: #build_retractable,
715                        append_only: #build_append_only,
716                        retractable_state_type: #retractable_state_type,
717                        append_only_state_type: #append_only_state_type,
718                    },
719                    type_infer: #type_infer_fn,
720                    deprecated: #deprecated,
721                }
722            }
723        })
724    }
725
726    /// Generate build function for aggregate function.
727    fn generate_agg_build_fn(&self, user_fn: &AggregateFnOrImpl) -> Result<TokenStream2> {
728        // If the first argument of the aggregate function is of type `&mut T`,
729        // we assume it is a user defined state type.
730        let custom_state = user_fn.accumulate().first_mut_ref_arg.as_ref();
731        let state_type: TokenStream2 = match (custom_state, &self.state) {
732            (Some(s), _) => s.parse().unwrap(),
733            (_, Some(state)) if state == "ref" => types::ref_type(&self.ret).parse().unwrap(),
734            (_, Some(state)) if state != "ref" => types::owned_type(state).parse().unwrap(),
735            _ => types::owned_type(&self.ret).parse().unwrap(),
736        };
737        let let_arrays = self
738            .args
739            .iter()
740            .enumerate()
741            .map(|(i, arg)| {
742                let array = format_ident!("a{i}");
743                let array_type: TokenStream2 = types::array_type(arg).parse().unwrap();
744                quote! {
745                    let #array: &#array_type = input.column_at(#i).as_ref().into();
746                }
747            })
748            .collect_vec();
749        let let_values = (0..self.args.len())
750            .map(|i| {
751                let v = format_ident!("v{i}");
752                let a = format_ident!("a{i}");
753                quote! { let #v = unsafe { #a.value_at_unchecked(row_id) }; }
754            })
755            .collect_vec();
756        let downcast_state = if custom_state.is_some() {
757            quote! { let mut state: &mut #state_type = state0.downcast_mut(); }
758        } else if let Some(s) = &self.state
759            && s == "ref"
760        {
761            quote! { let mut state: Option<#state_type> = state0.as_datum_mut().as_ref().map(|x| x.as_scalar_ref_impl().try_into().unwrap()); }
762        } else {
763            quote! { let mut state: Option<#state_type> = state0.as_datum_mut().take().map(|s| s.try_into().unwrap()); }
764        };
765        let restore_state = if custom_state.is_some() {
766            quote! {}
767        } else if let Some(s) = &self.state
768            && s == "ref"
769        {
770            quote! { *state0.as_datum_mut() = state.map(|x| x.to_owned_scalar().into()); }
771        } else {
772            quote! { *state0.as_datum_mut() = state.map(|s| s.into()); }
773        };
774        let create_state = if custom_state.is_some() {
775            quote! {
776                fn create_state(&self) -> Result<AggregateState> {
777                    Ok(AggregateState::Any(Box::<#state_type>::default()))
778                }
779            }
780        } else if let Some(state) = &self.init_state {
781            let state: TokenStream2 = state.parse().unwrap();
782            quote! {
783                fn create_state(&self) -> Result<AggregateState> {
784                    Ok(AggregateState::Datum(Some(#state.into())))
785                }
786            }
787        } else {
788            // by default: `AggregateState::Datum(None)`
789            quote! {}
790        };
791        let args = (0..self.args.len()).map(|i| format_ident!("v{i}"));
792        let args = quote! { #(#args,)* };
793        let panic_on_retract = {
794            let msg = format!(
795                "attempt to retract on aggregate function {}, but it is append-only",
796                self.name
797            );
798            quote! { assert_eq!(op, Op::Insert, #msg); }
799        };
800        let mut next_state = match user_fn {
801            AggregateFnOrImpl::Fn(f) => {
802                let context = f.context.then(|| quote! { &self.context, });
803                let fn_name = format_ident!("{}", f.name);
804                match f.retract {
805                    true => {
806                        quote! { #fn_name(state, #args matches!(op, Op::Delete | Op::UpdateDelete) #context) }
807                    }
808                    false => quote! {{
809                        #panic_on_retract
810                        #fn_name(state, #args #context)
811                    }},
812                }
813            }
814            AggregateFnOrImpl::Impl(i) => {
815                let retract = match i.retract {
816                    Some(_) => quote! { self.function.retract(state, #args) },
817                    None => panic_on_retract,
818                };
819                quote! {
820                    if matches!(op, Op::Delete | Op::UpdateDelete) {
821                        #retract
822                    } else {
823                        self.function.accumulate(state, #args)
824                    }
825                }
826            }
827        };
828        next_state = match user_fn.accumulate().return_type_kind {
829            ReturnTypeKind::T => quote! { Some(#next_state) },
830            ReturnTypeKind::Option => next_state,
831            ReturnTypeKind::Result => quote! { Some(#next_state?) },
832            ReturnTypeKind::ResultOption => quote! { #next_state? },
833        };
834        if user_fn.accumulate().args_option.iter().all(|b| !b) {
835            match self.args.len() {
836                0 => {
837                    next_state = quote! {
838                        match state {
839                            Some(state) => #next_state,
840                            None => state,
841                        }
842                    };
843                }
844                1 => {
845                    let first_state = if self.init_state.is_some() {
846                        // for count, the state will never be None
847                        quote! { unreachable!() }
848                    } else if let Some(s) = &self.state
849                        && s == "ref"
850                    {
851                        // for min/max/first/last, the state is the first value
852                        quote! { Some(v0) }
853                    } else if let AggregateFnOrImpl::Impl(impl_) = user_fn
854                        && impl_.create_state.is_some()
855                    {
856                        // use user-defined create_state function
857                        quote! {{
858                            let state = self.function.create_state();
859                            #next_state
860                        }}
861                    } else {
862                        quote! {{
863                            let state = #state_type::default();
864                            #next_state
865                        }}
866                    };
867                    next_state = quote! {
868                        match (state, v0) {
869                            (Some(state), Some(v0)) => #next_state,
870                            (None, Some(v0)) => #first_state,
871                            (state, None) => state,
872                        }
873                    };
874                }
875                _ => todo!("multiple arguments are not supported for non-option function"),
876            }
877        }
878        let update_state = if custom_state.is_some() {
879            quote! { _ = #next_state; }
880        } else {
881            quote! { state = #next_state; }
882        };
883        let get_result = if custom_state.is_some() {
884            quote! { Ok(state.downcast_ref::<#state_type>().into()) }
885        } else if let AggregateFnOrImpl::Impl(impl_) = user_fn
886            && impl_.finalize.is_some()
887        {
888            quote! {
889                let state = match state.as_datum() {
890                    Some(s) => s.as_scalar_ref_impl().try_into().unwrap(),
891                    None => return Ok(None),
892                };
893                Ok(Some(self.function.finalize(state).into()))
894            }
895        } else {
896            quote! { Ok(state.as_datum().clone()) }
897        };
898        let function_field = match user_fn {
899            AggregateFnOrImpl::Fn(_) => quote! {},
900            AggregateFnOrImpl::Impl(i) => {
901                let struct_name = format_ident!("{}", i.struct_name);
902                let generic = self.generic.as_ref().map(|g| {
903                    let g = format_ident!("{g}");
904                    quote! { <#g> }
905                });
906                quote! { function: #struct_name #generic, }
907            }
908        };
909        let function_new = match user_fn {
910            AggregateFnOrImpl::Fn(_) => quote! {},
911            AggregateFnOrImpl::Impl(i) => {
912                let struct_name = format_ident!("{}", i.struct_name);
913                let generic = self.generic.as_ref().map(|g| {
914                    let g = format_ident!("{g}");
915                    quote! { ::<#g> }
916                });
917                quote! { function: #struct_name #generic :: default(), }
918            }
919        };
920
921        Ok(quote! {
922            |agg| {
923                use std::collections::HashSet;
924                use std::ops::Range;
925                use risingwave_common::array::*;
926                use risingwave_common::types::*;
927                use risingwave_common::bail;
928                use risingwave_common::bitmap::Bitmap;
929                use risingwave_common_estimate_size::EstimateSize;
930
931                use risingwave_expr::expr::Context;
932                use risingwave_expr::Result;
933                use risingwave_expr::aggregate::AggregateState;
934                use risingwave_expr::codegen::async_trait;
935
936                let context = Context {
937                    return_type: agg.return_type.clone(),
938                    arg_types: agg.args.arg_types().to_owned(),
939                    variadic: false,
940                };
941
942                struct Agg {
943                    context: Context,
944                    #function_field
945                }
946
947                #[async_trait]
948                impl risingwave_expr::aggregate::AggregateFunction for Agg {
949                    fn return_type(&self) -> DataType {
950                        self.context.return_type.clone()
951                    }
952
953                    #create_state
954
955                    async fn update(&self, state0: &mut AggregateState, input: &StreamChunk) -> Result<()> {
956                        #(#let_arrays)*
957                        #downcast_state
958                        for row_id in input.visibility().iter_ones() {
959                            let op = unsafe { *input.ops().get_unchecked(row_id) };
960                            #(#let_values)*
961                            #update_state
962                        }
963                        #restore_state
964                        Ok(())
965                    }
966
967                    async fn update_range(&self, state0: &mut AggregateState, input: &StreamChunk, range: Range<usize>) -> Result<()> {
968                        assert!(range.end <= input.capacity());
969                        #(#let_arrays)*
970                        #downcast_state
971                        if input.is_compacted() {
972                            for row_id in range {
973                                let op = unsafe { *input.ops().get_unchecked(row_id) };
974                                #(#let_values)*
975                                #update_state
976                            }
977                        } else {
978                            for row_id in input.visibility().iter_ones() {
979                                if row_id < range.start {
980                                    continue;
981                                } else if row_id >= range.end {
982                                    break;
983                                }
984                                let op = unsafe { *input.ops().get_unchecked(row_id) };
985                                #(#let_values)*
986                                #update_state
987                            }
988                        }
989                        #restore_state
990                        Ok(())
991                    }
992
993                    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
994                        #get_result
995                    }
996                }
997
998                Ok(Box::new(Agg {
999                    context,
1000                    #function_new
1001                }))
1002            }
1003        })
1004    }
1005
1006    /// Generate a descriptor of the table function.
1007    ///
1008    /// The types of arguments and return value should not contain wildcard.
1009    fn generate_table_function_descriptor(
1010        &self,
1011        user_fn: &UserFunctionAttr,
1012        build_fn: bool,
1013    ) -> Result<TokenStream2> {
1014        let name = self.name.clone();
1015        let mut args = Vec::with_capacity(self.args.len());
1016        for ty in &self.args {
1017            args.push(sig_data_type(ty));
1018        }
1019        let ret = sig_data_type(&self.ret);
1020
1021        let pb_type = format_ident!("{}", utils::to_camel_case(&name));
1022        let ctor_name = format_ident!("{}", self.ident_name());
1023        let build_fn = if build_fn {
1024            let name = format_ident!("{}", user_fn.name);
1025            quote! { #name }
1026        } else if self.rewritten {
1027            quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) }
1028        } else {
1029            self.generate_build_table_function(user_fn)?
1030        };
1031        let type_infer_fn = self.generate_type_infer_fn()?;
1032        let deprecated = self.deprecated;
1033
1034        Ok(quote! {
1035            #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
1036            fn #ctor_name() -> risingwave_expr::sig::FuncSign {
1037                use risingwave_common::types::{DataType, DataTypeName};
1038                use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
1039
1040                FuncSign {
1041                    name: risingwave_pb::expr::table_function::Type::#pb_type.into(),
1042                    inputs_type: vec![#(#args),*],
1043                    variadic: false,
1044                    ret_type: #ret,
1045                    build: FuncBuilder::Table(#build_fn),
1046                    type_infer: #type_infer_fn,
1047                    deprecated: #deprecated,
1048                }
1049            }
1050        })
1051    }
1052
1053    fn generate_build_table_function(&self, user_fn: &UserFunctionAttr) -> Result<TokenStream2> {
1054        let num_args = self.args.len();
1055        let return_types = output_types(&self.ret);
1056        let fn_name = format_ident!("{}", user_fn.name);
1057        let struct_name = format_ident!("{}", utils::to_camel_case(&self.ident_name()));
1058        let arg_ids = (0..num_args)
1059            .filter(|i| match &self.prebuild {
1060                Some(s) => !s.contains(&format!("${i}")),
1061                None => true,
1062            })
1063            .collect_vec();
1064        let const_ids = (0..num_args).filter(|i| match &self.prebuild {
1065            Some(s) => s.contains(&format!("${i}")),
1066            None => false,
1067        });
1068        let inputs: Vec<_> = arg_ids.iter().map(|i| format_ident!("i{i}")).collect();
1069        let all_child: Vec<_> = (0..num_args).map(|i| format_ident!("child{i}")).collect();
1070        let const_child: Vec<_> = const_ids.map(|i| format_ident!("child{i}")).collect();
1071        let child: Vec<_> = arg_ids.iter().map(|i| format_ident!("child{i}")).collect();
1072        let array_refs: Vec<_> = arg_ids.iter().map(|i| format_ident!("array{i}")).collect();
1073        let arrays: Vec<_> = arg_ids.iter().map(|i| format_ident!("a{i}")).collect();
1074        let arg_arrays = arg_ids
1075            .iter()
1076            .map(|i| format_ident!("{}", types::array_type(&self.args[*i])));
1077        let outputs = (0..return_types.len())
1078            .map(|i| format_ident!("o{i}"))
1079            .collect_vec();
1080        let builders = (0..return_types.len())
1081            .map(|i| format_ident!("builder{i}"))
1082            .collect_vec();
1083        let builder_types = return_types
1084            .iter()
1085            .map(|ty| format_ident!("{}Builder", types::array_type(ty)))
1086            .collect_vec();
1087        let return_types = if return_types.len() == 1 {
1088            vec![quote! { self.context.return_type.clone() }]
1089        } else {
1090            (0..return_types.len())
1091                .map(|i| quote! { self.context.return_type.as_struct().types().nth(#i).unwrap().clone() })
1092                .collect()
1093        };
1094        #[allow(clippy::disallowed_methods)]
1095        let optioned_outputs = user_fn
1096            .core_return_type
1097            .split(',')
1098            .map(|t| t.contains("Option"))
1099            // example: "(Option<&str>, i32)" => [true, false]
1100            .zip(&outputs)
1101            .map(|(optional, o)| match optional {
1102                false => quote! { Some(#o.as_scalar_ref()) },
1103                true => quote! { #o.map(|o| o.as_scalar_ref()) },
1104            })
1105            .collect_vec();
1106        let build_value_array = if return_types.len() == 1 {
1107            quote! { let [value_array] = value_arrays; }
1108        } else {
1109            quote! {
1110                let value_array = StructArray::new(
1111                    self.context.return_type.as_struct().clone(),
1112                    value_arrays.to_vec(),
1113                    Bitmap::ones(len),
1114                ).into_ref();
1115            }
1116        };
1117        let context = user_fn.context.then(|| quote! { &self.context, });
1118        let prebuilt_arg = match &self.prebuild {
1119            Some(_) => quote! { &self.prebuilt_arg, },
1120            None => quote! {},
1121        };
1122        let prebuilt_arg_type = match &self.prebuild {
1123            Some(s) => s.split("::").next().unwrap().parse().unwrap(),
1124            None => quote! { () },
1125        };
1126        let prebuilt_arg_value = match &self.prebuild {
1127            Some(s) => s
1128                .replace('$', "child")
1129                .parse()
1130                .expect("invalid prebuild syntax"),
1131            None => quote! { () },
1132        };
1133        let iter = quote! { #fn_name(#(#inputs,)* #prebuilt_arg #context) };
1134        let mut iter = match user_fn.return_type_kind {
1135            ReturnTypeKind::T => quote! { #iter },
1136            ReturnTypeKind::Option => quote! { match #iter {
1137                Some(it) => it,
1138                None => continue,
1139            } },
1140            ReturnTypeKind::Result => quote! { match #iter {
1141                Ok(it) => it,
1142                Err(e) => {
1143                    index_builder.append(Some(i as i32));
1144                    #(#builders.append_null();)*
1145                    error_builder.append_display(Some(e.as_report()));
1146                    continue;
1147                }
1148            } },
1149            ReturnTypeKind::ResultOption => quote! { match #iter {
1150                Ok(Some(it)) => it,
1151                Ok(None) => continue,
1152                Err(e) => {
1153                    index_builder.append(Some(i as i32));
1154                    #(#builders.append_null();)*
1155                    error_builder.append_display(Some(e.as_report()));
1156                    continue;
1157                }
1158            } },
1159        };
1160        // if user function accepts non-option arguments, we assume the function
1161        // returns empty on null input, so we need to unwrap the inputs before calling.
1162        #[allow(clippy::disallowed_methods)] // allow zip
1163        let some_inputs = inputs
1164            .iter()
1165            .zip(user_fn.args_option.iter())
1166            .map(|(input, opt)| {
1167                if *opt {
1168                    quote! { #input }
1169                } else {
1170                    quote! { Some(#input) }
1171                }
1172            });
1173        iter = quote! {
1174            match (#(#inputs,)*) {
1175                (#(#some_inputs,)*) => #iter,
1176                _ => continue,
1177            }
1178        };
1179        let iterator_item_type = user_fn.iterator_item_kind.clone().ok_or_else(|| {
1180            Error::new(
1181                user_fn.return_type_span,
1182                "expect `impl Iterator` in return type",
1183            )
1184        })?;
1185        let append_output = match iterator_item_type {
1186            ReturnTypeKind::T => quote! {
1187                let (#(#outputs),*) = output;
1188                #(#builders.append(#optioned_outputs);)* error_builder.append_null();
1189            },
1190            ReturnTypeKind::Option => quote! { match output {
1191                Some((#(#outputs),*)) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1192                None => { #(#builders.append_null();)* error_builder.append_null(); }
1193            } },
1194            ReturnTypeKind::Result => quote! { match output {
1195                Ok((#(#outputs),*)) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1196                Err(e) => { #(#builders.append_null();)* error_builder.append_display(Some(e.as_report())); }
1197            } },
1198            ReturnTypeKind::ResultOption => quote! { match output {
1199                Ok(Some((#(#outputs),*))) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1200                Ok(None) => { #(#builders.append_null();)* error_builder.append_null(); }
1201                Err(e) => { #(#builders.append_null();)* error_builder.append_display(Some(e.as_report())); }
1202            } },
1203        };
1204
1205        Ok(quote! {
1206            |return_type, chunk_size, children| {
1207                use risingwave_common::array::*;
1208                use risingwave_common::types::*;
1209                use risingwave_common::bitmap::Bitmap;
1210                use risingwave_common::util::iter_util::ZipEqFast;
1211                use risingwave_expr::expr::{BoxedExpression, Context};
1212                use risingwave_expr::{Result, ExprError};
1213                use risingwave_expr::codegen::*;
1214
1215                risingwave_expr::ensure!(children.len() == #num_args);
1216
1217                let context = Context {
1218                    return_type: return_type.clone(),
1219                    arg_types: children.iter().map(|c| c.return_type()).collect(),
1220                    variadic: false,
1221                };
1222
1223                let mut iter = children.into_iter();
1224                #(let #all_child = iter.next().unwrap();)*
1225                #(
1226                    let #const_child = #const_child.eval_const()?;
1227                    let #const_child = match &#const_child {
1228                        Some(s) => s.as_scalar_ref_impl().try_into()?,
1229                        // the function should always return empty if any const argument is null
1230                        None => return Ok(risingwave_expr::table_function::empty(return_type)),
1231                    };
1232                )*
1233
1234                #[derive(Debug)]
1235                struct #struct_name {
1236                    context: Context,
1237                    chunk_size: usize,
1238                    #(#child: BoxedExpression,)*
1239                    prebuilt_arg: #prebuilt_arg_type,
1240                }
1241                #[async_trait]
1242                impl risingwave_expr::table_function::TableFunction for #struct_name {
1243                    fn return_type(&self) -> DataType {
1244                        self.context.return_type.clone()
1245                    }
1246                    async fn eval<'a>(&'a self, input: &'a DataChunk) -> BoxStream<'a, Result<DataChunk>> {
1247                        self.eval_inner(input)
1248                    }
1249                }
1250                impl #struct_name {
1251                    #[try_stream(boxed, ok = DataChunk, error = ExprError)]
1252                    async fn eval_inner<'a>(&'a self, input: &'a DataChunk) {
1253                        #(
1254                        let #array_refs = self.#child.eval(input).await?;
1255                        let #arrays: &#arg_arrays = #array_refs.as_ref().into();
1256                        )*
1257
1258                        let mut index_builder = I32ArrayBuilder::new(self.chunk_size);
1259                        #(let mut #builders = #builder_types::with_type(self.chunk_size, #return_types);)*
1260                        let mut error_builder = Utf8ArrayBuilder::new(self.chunk_size);
1261
1262                        for i in 0..input.capacity() {
1263                            if unsafe { !input.visibility().is_set_unchecked(i) } {
1264                                continue;
1265                            }
1266                            #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
1267                            for output in #iter {
1268                                index_builder.append(Some(i as i32));
1269                                #append_output
1270
1271                                if index_builder.len() == self.chunk_size {
1272                                    let len = index_builder.len();
1273                                    let index_array = std::mem::replace(&mut index_builder, I32ArrayBuilder::new(self.chunk_size)).finish().into_ref();
1274                                    let value_arrays = [#(std::mem::replace(&mut #builders, #builder_types::with_type(self.chunk_size, #return_types)).finish().into_ref()),*];
1275                                    #build_value_array
1276                                    let error_array = std::mem::replace(&mut error_builder, Utf8ArrayBuilder::new(self.chunk_size)).finish().into_ref();
1277                                    if error_array.null_bitmap().any() {
1278                                        yield DataChunk::new(vec![index_array, value_array, error_array], self.chunk_size);
1279                                    } else {
1280                                        yield DataChunk::new(vec![index_array, value_array], self.chunk_size);
1281                                    }
1282                                }
1283                            }
1284                        }
1285
1286                        if index_builder.len() > 0 {
1287                            let len = index_builder.len();
1288                            let index_array = index_builder.finish().into_ref();
1289                            let value_arrays = [#(#builders.finish().into_ref()),*];
1290                            #build_value_array
1291                            let error_array = error_builder.finish().into_ref();
1292                            if error_array.null_bitmap().any() {
1293                                yield DataChunk::new(vec![index_array, value_array, error_array], len);
1294                            } else {
1295                                yield DataChunk::new(vec![index_array, value_array], len);
1296                            }
1297                        }
1298                    }
1299                }
1300
1301                Ok(Box::new(#struct_name {
1302                    context,
1303                    chunk_size,
1304                    #(#child,)*
1305                    prebuilt_arg: #prebuilt_arg_value,
1306                }))
1307            }
1308        })
1309    }
1310}
1311
1312fn sig_data_type(ty: &str) -> TokenStream2 {
1313    match ty {
1314        "any" => quote! { SigDataType::Any },
1315        "anyarray" => quote! { SigDataType::AnyArray },
1316        "anymap" => quote! { SigDataType::AnyMap },
1317        "vector" => quote! { SigDataType::Vector },
1318        "struct" => quote! { SigDataType::AnyStruct },
1319        _ if ty.starts_with("struct") && ty.contains("any") => quote! { SigDataType::AnyStruct },
1320        _ => {
1321            let datatype = data_type(ty);
1322            quote! { SigDataType::Exact(#datatype) }
1323        }
1324    }
1325}
1326
1327fn data_type(ty: &str) -> TokenStream2 {
1328    if let Some(ty) = ty.strip_suffix("[]") {
1329        let inner_type = data_type(ty);
1330        return quote! { DataType::List(Box::new(#inner_type)) };
1331    }
1332    if ty.starts_with("struct<") {
1333        return quote! { DataType::Struct(#ty.parse().expect("invalid struct type")) };
1334    }
1335    let variant = format_ident!("{}", types::data_type(ty));
1336    // TODO: enable the check
1337    // assert!(
1338    //     !matches!(ty, "any" | "anyarray" | "anymap" | "struct"),
1339    //     "{ty}, {variant}"
1340    // );
1341
1342    quote! { DataType::#variant }
1343}
1344
/// Extract multiple output types.
///
/// A `struct<...>` type yields the type of each of its fields, in order.
/// Fields are separated by top-level commas only, so a field may itself be a
/// struct. Any other type is returned as the single output type.
///
/// ```ignore
/// output_types("int4") -> ["int4"]
/// output_types("struct<key varchar, value jsonb>") -> ["varchar", "jsonb"]
/// output_types("struct<a int4, b struct<x int4, y varchar>>") -> ["int4", "struct<x int4, y varchar>"]
/// ```
///
/// # Panics
///
/// Panics if a struct field is not written as `<name> <type>`.
fn output_types(ty: &str) -> Vec<&str> {
    /// Returns the type part of a `<name> <type>` field definition.
    fn field_type<'a>(field: &'a str, whole: &str) -> &'a str {
        match field.trim().split_once(char::is_whitespace) {
            Some((_name, field_ty)) => field_ty.trim_start(),
            None => panic!(
                "invalid struct field `{}` in `{}`: expect `<name> <type>`",
                field, whole
            ),
        }
    }

    let Some(fields) = ty
        .strip_prefix("struct<")
        .and_then(|rest| rest.strip_suffix('>'))
    else {
        return vec![ty];
    };

    let mut types = Vec::new();
    // nesting level of `<...>`; commas inside a nested type do not separate fields
    let mut depth = 0usize;
    let mut start = 0;
    for (pos, ch) in fields.char_indices() {
        match ch {
            '<' => depth += 1,
            '>' => depth = depth.saturating_sub(1),
            ',' if depth == 0 => {
                types.push(field_type(&fields[start..pos], ty));
                start = pos + 1;
            }
            _ => {}
        }
    }
    types.push(field_type(&fields[start..], ty));
    types
}