risingwave_expr_macro/
gen.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Generate code for the functions.
16
17use itertools::Itertools;
18use proc_macro2::{Ident, Span};
19use quote::{format_ident, quote};
20
21use super::*;
22
23impl FunctionAttr {
24    /// Expands the wildcard in function arguments or return type.
25    pub fn expand(&self) -> Vec<Self> {
26        // handle variadic argument
27        if self
28            .args
29            .last()
30            .is_some_and(|arg| arg.starts_with("variadic"))
31        {
32            // expand:  foo(a, b, variadic anyarray)
33            // to:      foo(a, b, ...)
34            //        + foo_variadic(a, b, anyarray)
35            let mut attrs = Vec::new();
36            attrs.extend(
37                FunctionAttr {
38                    args: {
39                        let mut args = self.args.clone();
40                        *args.last_mut().unwrap() = "...".to_owned();
41                        args
42                    },
43                    ..self.clone()
44                }
45                .expand(),
46            );
47            attrs.extend(
48                FunctionAttr {
49                    name: format!("{}_variadic", self.name),
50                    args: {
51                        let mut args = self.args.clone();
52                        let last = args.last_mut().unwrap();
53                        *last = last.strip_prefix("variadic ").unwrap().into();
54                        args
55                    },
56                    ..self.clone()
57                }
58                .expand(),
59            );
60            return attrs;
61        }
62        let args = self.args.iter().map(|ty| types::expand_type_wildcard(ty));
63        let ret = types::expand_type_wildcard(&self.ret);
64        let mut attrs = Vec::new();
65        for (args, mut ret) in args.multi_cartesian_product().cartesian_product(ret) {
66            if ret == "auto" {
67                ret = types::min_compatible_type(&args);
68            }
69            let attr = FunctionAttr {
70                args: args.iter().map(|s| s.to_string()).collect(),
71                ret: ret.to_owned(),
72                ..self.clone()
73            };
74            attrs.push(attr);
75        }
76        attrs
77    }
78
79    /// Generate the type infer function: `fn(&[DataType]) -> Result<DataType>`
80    fn generate_type_infer_fn(&self) -> Result<TokenStream2> {
81        if let Some(func) = &self.type_infer {
82            if func == "unreachable" {
83                return Ok(
84                    quote! { |_| unreachable!("type inference for this function should be specially handled in frontend, and should not call sig.type_infer") },
85                );
86            }
87            // use the user defined type inference function
88            return Ok(func.parse().unwrap());
89        } else if self.ret == "any" {
90            // TODO: if there are multiple "any", they should be the same type
91            if let Some(i) = self.args.iter().position(|t| t == "any") {
92                // infer as the type of "any" argument
93                return Ok(quote! { |args| Ok(args[#i].clone()) });
94            }
95            if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
96                // infer as the element type of "anyarray" argument
97                return Ok(quote! { |args| Ok(args[#i].as_list().clone()) });
98            }
99        } else if self.ret == "anyarray" {
100            if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
101                // infer as the type of "anyarray" argument
102                return Ok(quote! { |args| Ok(args[#i].clone()) });
103            }
104            if let Some(i) = self.args.iter().position(|t| t == "any") {
105                // infer as the array type of "any" argument
106                return Ok(quote! { |args| Ok(DataType::List(Box::new(args[#i].clone()))) });
107            }
108        } else if self.ret == "struct" {
109            if let Some(i) = self.args.iter().position(|t| t == "struct") {
110                // infer as the type of "struct" argument
111                return Ok(quote! { |args| Ok(args[#i].clone()) });
112            }
113        } else if self.ret == "anymap" {
114            if let Some(i) = self.args.iter().position(|t| t == "anymap") {
115                // infer as the type of "anymap" argument
116                return Ok(quote! { |args| Ok(args[#i].clone()) });
117            }
118        } else {
119            // the return type is fixed
120            let ty = data_type(&self.ret);
121            return Ok(quote! { |_| Ok(#ty) });
122        }
123        Err(Error::new(
124            Span::call_site(),
125            "type inference function cannot be automatically derived. You should provide: `type_infer = \"|args| Ok(...)\"`",
126        ))
127    }
128
129    /// Generate a descriptor (`FuncSign`) of the scalar or table function.
130    ///
131    /// The types of arguments and return value should not contain wildcard.
132    ///
133    /// # Arguments
134    /// `build_fn`: whether the user provided a function is a build function.
135    /// (from the `#[build_function]` macro)
136    pub fn generate_function_descriptor(
137        &self,
138        user_fn: &UserFunctionAttr,
139        build_fn: bool,
140    ) -> Result<TokenStream2> {
141        if self.is_table_function {
142            return self.generate_table_function_descriptor(user_fn, build_fn);
143        }
144        let name = self.name.clone();
145        let variadic = matches!(self.args.last(), Some(t) if t == "...");
146        let args = match variadic {
147            true => &self.args[..self.args.len() - 1],
148            false => &self.args[..],
149        }
150        .iter()
151        .map(|ty| sig_data_type(ty))
152        .collect_vec();
153        let ret = sig_data_type(&self.ret);
154
155        let pb_type = format_ident!("{}", utils::to_camel_case(&name));
156        let ctor_name = format_ident!("{}", self.ident_name());
157        let build_fn = if build_fn {
158            let name = format_ident!("{}", user_fn.name);
159            quote! { #name }
160        } else if self.rewritten {
161            quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) }
162        } else {
163            // This is the core logic for `#[function]`
164            self.generate_build_scalar_function(user_fn, true)?
165        };
166        let type_infer_fn = self.generate_type_infer_fn()?;
167        let deprecated = self.deprecated;
168
169        Ok(quote! {
170            #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
171            fn #ctor_name() -> risingwave_expr::sig::FuncSign {
172                use risingwave_common::types::{DataType, DataTypeName};
173                use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
174
175                FuncSign {
176                    name: risingwave_pb::expr::expr_node::Type::#pb_type.into(),
177                    inputs_type: vec![#(#args),*],
178                    variadic: #variadic,
179                    ret_type: #ret,
180                    build: FuncBuilder::Scalar(#build_fn),
181                    type_infer: #type_infer_fn,
182                    deprecated: #deprecated,
183                }
184            }
185        })
186    }
187
188    /// Generate a build function for the scalar function.
189    ///
190    /// If `optimize_const` is true, the function will be optimized for constant arguments,
191    /// and fallback to the general version if any argument is not constant.
192    fn generate_build_scalar_function(
193        &self,
194        user_fn: &UserFunctionAttr,
195        optimize_const: bool,
196    ) -> Result<TokenStream2> {
197        let variadic = matches!(self.args.last(), Some(t) if t == "...");
198        let num_args = self.args.len() - if variadic { 1 } else { 0 };
199        let fn_name = format_ident!("{}", user_fn.name);
200        let struct_name = match optimize_const {
201            true => format_ident!("{}OptimizeConst", utils::to_camel_case(&self.ident_name())),
202            false => format_ident!("{}", utils::to_camel_case(&self.ident_name())),
203        };
204
205        // we divide all arguments into two groups: prebuilt and non-prebuilt.
206        // prebuilt arguments are collected from the "prebuild" field.
207        // let's say we have a function with 3 arguments: [0, 1, 2]
208        // and the prebuild field contains "$1".
209        // then we have:
210        //     prebuilt_indices = [1]
211        //     non_prebuilt_indices = [0, 2]
212        //
213        // if the const argument optimization is enabled, prebuilt arguments are
214        // evaluated at build time, thus the children only contain non-prebuilt arguments:
215        //     children_indices = [0, 2]
216        // otherwise, the children contain all arguments:
217        //     children_indices = [0, 1, 2]
218
219        let prebuilt_indices = match &self.prebuild {
220            Some(s) => (0..num_args)
221                .filter(|i| s.contains(&format!("${i}")))
222                .collect_vec(),
223            None => vec![],
224        };
225        let non_prebuilt_indices = match &self.prebuild {
226            Some(s) => (0..num_args)
227                .filter(|i| !s.contains(&format!("${i}")))
228                .collect_vec(),
229            _ => (0..num_args).collect_vec(),
230        };
231        let children_indices = match optimize_const {
232            #[allow(clippy::redundant_clone)] // false-positive
233            true => non_prebuilt_indices.clone(),
234            false => (0..num_args).collect_vec(),
235        };
236
237        /// Return a list of identifiers with the given prefix and indices.
238        fn idents(prefix: &str, indices: &[usize]) -> Vec<Ident> {
239            indices
240                .iter()
241                .map(|i| format_ident!("{prefix}{i}"))
242                .collect()
243        }
244        let inputs = idents("i", &children_indices);
245        let prebuilt_inputs = idents("i", &prebuilt_indices);
246        let non_prebuilt_inputs = idents("i", &non_prebuilt_indices);
247        let array_refs = idents("array", &children_indices);
248        let arrays = idents("a", &children_indices);
249        let datums = idents("v", &children_indices);
250        let arg_arrays = children_indices
251            .iter()
252            .map(|i| format_ident!("{}", types::array_type(&self.args[*i])));
253        let arg_types = children_indices.iter().map(|i| {
254            types::ref_type(&self.args[*i])
255                .parse::<TokenStream2>()
256                .unwrap()
257        });
258        let annotation: TokenStream2 = match user_fn.core_return_type.as_str() {
259            // add type annotation for functions that return generic types
260            "T" | "T1" | "T2" | "T3" => format!(": Option<{}>", types::owned_type(&self.ret))
261                .parse()
262                .unwrap(),
263            _ => quote! {},
264        };
265        let ret_array_type = format_ident!("{}", types::array_type(&self.ret));
266        let builder_type = format_ident!("{}Builder", types::array_type(&self.ret));
267        let prebuilt_arg_type = match &self.prebuild {
268            Some(s) if optimize_const => s.split("::").next().unwrap().parse().unwrap(),
269            _ => quote! { () },
270        };
271        let prebuilt_arg_value = match &self.prebuild {
272            // example:
273            // prebuild = "RegexContext::new($1)"
274            // return = "RegexContext::new(i1)"
275            Some(s) => s
276                .replace('$', "i")
277                .parse()
278                .expect("invalid prebuild syntax"),
279            None => quote! { () },
280        };
281        let prebuild_const = if self.prebuild.is_some() && optimize_const {
282            let build_general = self.generate_build_scalar_function(user_fn, false)?;
283            quote! {{
284                let build_general = #build_general;
285                #(
286                    // try to evaluate constant for prebuilt arguments
287                    let #prebuilt_inputs = match children[#prebuilt_indices].eval_const() {
288                        Ok(s) => s,
289                        // prebuilt argument is not constant, fallback to general
290                        Err(_) => return build_general(return_type, children),
291                    };
292                    // get reference to the constant value
293                    let #prebuilt_inputs = match &#prebuilt_inputs {
294                        Some(s) => s.as_scalar_ref_impl().try_into()?,
295                        // the function should always return null if any const argument is null
296                        None => return Ok(Box::new(risingwave_expr::expr::LiteralExpression::new(
297                            return_type,
298                            None,
299                        ))),
300                    };
301                )*
302                #prebuilt_arg_value
303            }}
304        } else {
305            quote! { () }
306        };
307
308        // ensure the number of children matches the number of arguments
309        let check_children = match variadic {
310            true => quote! { risingwave_expr::ensure!(children.len() >= #num_args); },
311            false => quote! { risingwave_expr::ensure!(children.len() == #num_args); },
312        };
313
314        // evaluate variadic arguments in `eval`
315        let eval_variadic = variadic.then(|| {
316            quote! {
317                let mut columns = Vec::with_capacity(self.children.len() - #num_args);
318                for child in &self.children[#num_args..] {
319                    columns.push(child.eval(input).await?);
320                }
321                let variadic_input = DataChunk::new(columns, input.visibility().clone());
322            }
323        });
324        // evaluate variadic arguments in `eval_row`
325        let eval_row_variadic = variadic.then(|| {
326            quote! {
327                let mut row = Vec::with_capacity(self.children.len() - #num_args);
328                for child in &self.children[#num_args..] {
329                    row.push(child.eval_row(input).await?);
330                }
331                let variadic_row = OwnedRow::new(row);
332            }
333        });
334
335        let generic = (self.ret == "boolean" && user_fn.generic == 3).then(|| {
336            // XXX: for generic compare functions, we need to specify the compatible type
337            let compatible_type = types::ref_type(types::min_compatible_type(&self.args))
338                .parse::<TokenStream2>()
339                .unwrap();
340            quote! { ::<_, _, #compatible_type> }
341        });
342        let prebuilt_arg = match (&self.prebuild, optimize_const) {
343            // use the prebuilt argument
344            (Some(_), true) => quote! { &self.prebuilt_arg, },
345            // build the argument on site
346            (Some(_), false) => quote! { &#prebuilt_arg_value, },
347            // no prebuilt argument
348            (None, _) => quote! {},
349        };
350        let variadic_args = variadic.then(|| quote! { &variadic_row, });
351        let context = user_fn.context.then(|| quote! { &self.context, });
352        let writer = user_fn.write.then(|| quote! { &mut writer, });
353        let await_ = user_fn.async_.then(|| quote! { .await });
354
355        let record_error = {
356            // Uniform arguments into `DatumRef`.
357            #[allow(clippy::disallowed_methods)] // allow zip
358            let inputs_args = inputs
359                .iter()
360                .zip(user_fn.args_option.iter())
361                .map(|(input, opt)| {
362                    if *opt {
363                        quote! { #input.map(|s| ScalarRefImpl::from(s)) }
364                    } else {
365                        quote! { Some(ScalarRefImpl::from(#input)) }
366                    }
367                });
368            let inputs_args = quote! {
369                let args: &[DatumRef<'_>] = &[#(#inputs_args),*];
370                let args = args.iter().copied();
371            };
372            let var_args = variadic.then(|| {
373                quote! {
374                    let args = args.chain(variadic_row.iter());
375                }
376            });
377
378            quote! {
379                #inputs_args
380                #var_args
381                errors.push(ExprError::function(
382                    stringify!(#fn_name),
383                    args,
384                    e,
385                ));
386            }
387        };
388
389        // call the user defined function
390        // inputs: [ Option<impl ScalarRef> ]
391        let mut output = quote! { #fn_name #generic(
392            #(#non_prebuilt_inputs,)*
393            #variadic_args
394            #prebuilt_arg
395            #context
396            #writer
397        ) #await_ };
398        // handle error if the function returns `Result`
399        // wrap a `Some` if the function doesn't return `Option`
400        output = match user_fn.return_type_kind {
401            // XXX: we don't support void type yet. return null::int for now.
402            _ if self.ret == "void" => quote! { { #output; Option::<i32>::None } },
403            ReturnTypeKind::T => quote! { Some(#output) },
404            ReturnTypeKind::Option => output,
405            ReturnTypeKind::Result => quote! {
406                match #output {
407                    Ok(x) => Some(x),
408                    Err(e) => {
409                        #record_error
410                        None
411                    }
412                }
413            },
414            ReturnTypeKind::ResultOption => quote! {
415                match #output {
416                    Ok(x) => x,
417                    Err(e) => {
418                        #record_error
419                        None
420                    }
421                }
422            },
423        };
424        // if user function accepts non-option arguments, we assume the function
425        // returns null on null input, so we need to unwrap the inputs before calling.
426        if self.prebuild.is_some() {
427            output = quote! {
428                match (#(#inputs,)*) {
429                    (#(Some(#inputs),)*) => #output,
430                    _ => None,
431                }
432            };
433        } else {
434            #[allow(clippy::disallowed_methods)] // allow zip
435            let some_inputs = inputs
436                .iter()
437                .zip(user_fn.args_option.iter())
438                .map(|(input, opt)| {
439                    if *opt {
440                        quote! { #input }
441                    } else {
442                        quote! { Some(#input) }
443                    }
444                });
445            output = quote! {
446                match (#(#inputs,)*) {
447                    (#(#some_inputs,)*) => #output,
448                    _ => None,
449                }
450            };
451        };
452        // now the `output` is: Option<impl ScalarRef or Scalar>
453        let append_output = match user_fn.write {
454            true => quote! {{
455                let mut writer = builder.writer().begin();
456                if #output.is_some() {
457                    writer.finish();
458                } else {
459                    drop(writer);
460                    builder.append_null();
461                }
462            }},
463            false if user_fn.core_return_type == "impl AsRef < [u8] >" => quote! {
464                builder.append(#output.as_ref().map(|s| s.as_ref()));
465            },
466            false => quote! {
467                let output #annotation = #output;
468                builder.append(output.as_ref().map(|s| s.as_scalar_ref()));
469            },
470        };
471        // the output expression in `eval_row`
472        let row_output = match user_fn.write {
473            true => quote! {{
474                let mut writer = String::new();
475                #output.map(|_| writer.into())
476            }},
477            false if user_fn.core_return_type == "impl AsRef < [u8] >" => quote! {
478                #output.map(|s| s.as_ref().into())
479            },
480            false => quote! {{
481                let output #annotation = #output;
482                output.map(|s| s.into())
483            }},
484        };
485        // the main body in `eval`
486        let eval = if let Some(batch_fn) = &self.batch_fn {
487            assert!(
488                !variadic,
489                "customized batch function is not supported for variadic functions"
490            );
491            // user defined batch function
492            let fn_name = format_ident!("{}", batch_fn);
493            quote! {
494                let c = #fn_name(#(#arrays),*);
495                Arc::new(c.into())
496            }
497        } else if (types::is_primitive(&self.ret) || self.ret == "boolean")
498            && user_fn.is_pure()
499            && !variadic
500            && self.prebuild.is_none()
501        {
502            // SIMD optimization for primitive types
503            match self.args.len() {
504                0 => quote! {
505                    let c = #ret_array_type::from_iter_bitmap(
506                        std::iter::repeat_with(|| #fn_name()).take(input.capacity())
507                        Bitmap::ones(input.capacity()),
508                    );
509                    Arc::new(c.into())
510                },
511                1 => quote! {
512                    let c = #ret_array_type::from_iter_bitmap(
513                        a0.raw_iter().map(|a| #fn_name(a)),
514                        a0.null_bitmap().clone()
515                    );
516                    Arc::new(c.into())
517                },
518                2 => quote! {
519                    // allow using `zip` for performance
520                    #[allow(clippy::disallowed_methods)]
521                    let c = #ret_array_type::from_iter_bitmap(
522                        a0.raw_iter()
523                            .zip(a1.raw_iter())
524                            .map(|(a, b)| #fn_name #generic(a, b)),
525                        a0.null_bitmap() & a1.null_bitmap(),
526                    );
527                    Arc::new(c.into())
528                },
529                n => todo!("SIMD optimization for {n} arguments"),
530            }
531        } else {
532            // no optimization
533            let let_variadic = variadic.then(|| {
534                quote! {
535                    let variadic_row = variadic_input.row_at_unchecked_vis(i);
536                }
537            });
538            quote! {
539                let mut builder = #builder_type::with_type(input.capacity(), self.context.return_type.clone());
540
541                if input.is_compacted() {
542                    for i in 0..input.capacity() {
543                        #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
544                        #let_variadic
545                        #append_output
546                    }
547                } else {
548                    for i in 0..input.capacity() {
549                        if unsafe { !input.visibility().is_set_unchecked(i) } {
550                            builder.append_null();
551                            continue;
552                        }
553                        #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
554                        #let_variadic
555                        #append_output
556                    }
557                }
558                Arc::new(builder.finish().into())
559            }
560        };
561
562        Ok(quote! {
563            |return_type: DataType, children: Vec<risingwave_expr::expr::BoxedExpression>|
564                -> risingwave_expr::Result<risingwave_expr::expr::BoxedExpression>
565            {
566                use std::sync::Arc;
567                use risingwave_common::array::*;
568                use risingwave_common::types::*;
569                use risingwave_common::bitmap::Bitmap;
570                use risingwave_common::row::OwnedRow;
571                use risingwave_common::util::iter_util::ZipEqFast;
572
573                use risingwave_expr::expr::{Context, BoxedExpression};
574                use risingwave_expr::{ExprError, Result};
575                use risingwave_expr::codegen::*;
576
577                #check_children
578                let prebuilt_arg = #prebuild_const;
579                let context = Context {
580                    return_type,
581                    arg_types: children.iter().map(|c| c.return_type()).collect(),
582                    variadic: #variadic,
583                };
584
585                #[derive(Debug)]
586                struct #struct_name {
587                    context: Context,
588                    children: Vec<BoxedExpression>,
589                    prebuilt_arg: #prebuilt_arg_type,
590                }
591                #[async_trait]
592                impl risingwave_expr::expr::Expression for #struct_name {
593                    fn return_type(&self) -> DataType {
594                        self.context.return_type.clone()
595                    }
596                    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
597                        #(
598                            let #array_refs = self.children[#children_indices].eval(input).await?;
599                            let #arrays: &#arg_arrays = #array_refs.as_ref().into();
600                        )*
601                        #eval_variadic
602                        let mut errors = vec![];
603                        let array = { #eval };
604                        if errors.is_empty() {
605                            Ok(array)
606                        } else {
607                            Err(ExprError::Multiple(array, errors.into()))
608                        }
609                    }
610                    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
611                        #(
612                            let #datums = self.children[#children_indices].eval_row(input).await?;
613                            let #inputs: Option<#arg_types> = #datums.as_ref().map(|s| s.as_scalar_ref_impl().try_into().unwrap());
614                        )*
615                        #eval_row_variadic
616                        let mut errors: Vec<ExprError> = vec![];
617                        let output = #row_output;
618                        if let Some(err) = errors.into_iter().next() {
619                            Err(err.into())
620                        } else {
621                            Ok(output)
622                        }
623                    }
624                }
625
626                Ok(Box::new(#struct_name {
627                    context,
628                    children,
629                    prebuilt_arg,
630                }))
631            }
632        })
633    }
634
635    /// Generate a descriptor of the aggregate function.
636    ///
637    /// The types of arguments and return value should not contain wildcard.
638    /// `user_fn` could be either `fn` or `impl`.
639    /// If `build_fn` is true, `user_fn` must be a `fn` that builds the aggregate function.
640    pub fn generate_aggregate_descriptor(
641        &self,
642        user_fn: &AggregateFnOrImpl,
643        build_fn: bool,
644    ) -> Result<TokenStream2> {
645        let name = self.name.clone();
646
647        let mut args = Vec::with_capacity(self.args.len());
648        for ty in &self.args {
649            args.push(sig_data_type(ty));
650        }
651        let ret = sig_data_type(&self.ret);
652        let state_type = match &self.state {
653            Some(ty) if ty != "ref" => {
654                let ty = data_type(ty);
655                quote! { Some(#ty) }
656            }
657            _ => quote! { None },
658        };
659        let append_only = match build_fn {
660            false => !user_fn.has_retract(),
661            true => self.append_only,
662        };
663
664        let pb_kind = format_ident!("{}", utils::to_camel_case(&name));
665        let ctor_name = match append_only {
666            false => format_ident!("{}", self.ident_name()),
667            true => format_ident!("{}_append_only", self.ident_name()),
668        };
669        let build_fn = if build_fn {
670            let name = format_ident!("{}", user_fn.as_fn().name);
671            quote! { #name }
672        } else if self.rewritten {
673            quote! { |_| Err(ExprError::UnsupportedFunction(#name.into())) }
674        } else {
675            self.generate_agg_build_fn(user_fn)?
676        };
677        let build_retractable = match append_only {
678            true => quote! { None },
679            false => quote! { Some(#build_fn) },
680        };
681        let build_append_only = match append_only {
682            false => quote! { None },
683            true => quote! { Some(#build_fn) },
684        };
685        let retractable_state_type = match append_only {
686            true => quote! { None },
687            false => state_type.clone(),
688        };
689        let append_only_state_type = match append_only {
690            false => quote! { None },
691            true => state_type,
692        };
693        let type_infer_fn = self.generate_type_infer_fn()?;
694        let deprecated = self.deprecated;
695
696        Ok(quote! {
697            #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
698            fn #ctor_name() -> risingwave_expr::sig::FuncSign {
699                use risingwave_common::types::{DataType, DataTypeName};
700                use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
701
702                FuncSign {
703                    name: risingwave_pb::expr::agg_call::PbKind::#pb_kind.into(),
704                    inputs_type: vec![#(#args),*],
705                    variadic: false,
706                    ret_type: #ret,
707                    build: FuncBuilder::Aggregate {
708                        retractable: #build_retractable,
709                        append_only: #build_append_only,
710                        retractable_state_type: #retractable_state_type,
711                        append_only_state_type: #append_only_state_type,
712                    },
713                    type_infer: #type_infer_fn,
714                    deprecated: #deprecated,
715                }
716            }
717        })
718    }
719
    /// Generate build function for aggregate function.
    ///
    /// The returned tokens form a closure `|agg| { ... }` that declares a local `Agg` struct,
    /// implements `risingwave_expr::aggregate::AggregateFunction` for it (`return_type`,
    /// optionally `create_state`, `update`, `update_range` and `get_result`) and evaluates to
    /// `Ok(Box::new(Agg { .. }))`.
    ///
    /// `user_fn` is either a plain accumulate function or an impl block; for the impl case the
    /// generated code calls `self.function.accumulate` and, when present, `retract`,
    /// `create_state` and `finalize`.
    ///
    /// # Panics
    ///
    /// Panics if one of the type strings cannot be parsed into tokens, or (via `todo!`) when
    /// none of the accumulate arguments is optional and there is more than one argument.
    fn generate_agg_build_fn(&self, user_fn: &AggregateFnOrImpl) -> Result<TokenStream2> {
        // If the first argument of the aggregate function is of type `&mut T`,
        // we assume it is a user defined state type.
        let custom_state = user_fn.accumulate().first_mut_ref_arg.as_ref();
        // Pick the Rust type of the state, in priority order:
        //   1. the user defined `&mut T` state,
        //   2. `state = "ref"`: the reference type of the return type,
        //   3. `state = "<type>"`: the owned type of that type,
        //   4. otherwise the owned type of the return type.
        let state_type: TokenStream2 = match (custom_state, &self.state) {
            (Some(s), _) => s.parse().unwrap(),
            (_, Some(state)) if state == "ref" => types::ref_type(&self.ret).parse().unwrap(),
            (_, Some(state)) if state != "ref" => types::owned_type(state).parse().unwrap(),
            _ => types::owned_type(&self.ret).parse().unwrap(),
        };
        // `let a{i}: &<ArrayType> = ...`: one typed array binding per input column.
        let let_arrays = self
            .args
            .iter()
            .enumerate()
            .map(|(i, arg)| {
                let array = format_ident!("a{i}");
                let array_type: TokenStream2 = types::array_type(arg).parse().unwrap();
                quote! {
                    let #array: &#array_type = input.column_at(#i).as_ref().into();
                }
            })
            .collect_vec();
        // `let v{i} = ...`: the value of each argument at the current `row_id` (unchecked access).
        let let_values = (0..self.args.len())
            .map(|i| {
                let v = format_ident!("v{i}");
                let a = format_ident!("a{i}");
                quote! { let #v = unsafe { #a.value_at_unchecked(row_id) }; }
            })
            .collect_vec();
        // Statement that turns the incoming `state0: &mut AggregateState` into the local `state`
        // the user function works on. For the non-"ref" datum case the datum is taken out of
        // `state0`; `restore_state` below writes the final value back.
        let downcast_state = if custom_state.is_some() {
            quote! { let mut state: &mut #state_type = state0.downcast_mut(); }
        } else if let Some(s) = &self.state
            && s == "ref"
        {
            quote! { let mut state: Option<#state_type> = state0.as_datum_mut().as_ref().map(|x| x.as_scalar_ref_impl().try_into().unwrap()); }
        } else {
            quote! { let mut state: Option<#state_type> = state0.as_datum_mut().take().map(|s| s.try_into().unwrap()); }
        };
        // Statement that stores the local `state` back into `state0` after all rows were
        // processed. Nothing to do for a custom state, which is mutated in place.
        let restore_state = if custom_state.is_some() {
            quote! {}
        } else if let Some(s) = &self.state
            && s == "ref"
        {
            quote! { *state0.as_datum_mut() = state.map(|x| x.to_owned_scalar().into()); }
        } else {
            quote! { *state0.as_datum_mut() = state.map(|s| s.into()); }
        };
        // Optional `create_state` override: a boxed default of the custom state, or a datum
        // built from the `init_state` expression.
        let create_state = if custom_state.is_some() {
            quote! {
                fn create_state(&self) -> Result<AggregateState> {
                    Ok(AggregateState::Any(Box::<#state_type>::default()))
                }
            }
        } else if let Some(state) = &self.init_state {
            let state: TokenStream2 = state.parse().unwrap();
            quote! {
                fn create_state(&self) -> Result<AggregateState> {
                    Ok(AggregateState::Datum(Some(#state.into())))
                }
            }
        } else {
            // by default: `AggregateState::Datum(None)`
            quote! {}
        };
        // Argument list `v0, v1, ...,` (with trailing comma) passed to the user function.
        let args = (0..self.args.len()).map(|i| format_ident!("v{i}"));
        let args = quote! { #(#args,)* };
        // Assertion emitted where the function cannot retract: any op other than `Op::Insert`
        // makes the generated code panic with this message.
        let panic_on_retract = {
            let msg = format!(
                "attempt to retract on aggregate function {}, but it is append-only",
                self.name
            );
            quote! { assert_eq!(op, Op::Insert, #msg); }
        };
        // Expression that calls the user code for one row and yields its raw return value.
        let mut next_state = match user_fn {
            AggregateFnOrImpl::Fn(f) => {
                let context = f.context.then(|| quote! { &self.context, });
                let fn_name = format_ident!("{}", f.name);
                match f.retract {
                    // retractable function: pass an extra `bool` telling whether the row is a delete
                    true => {
                        quote! { #fn_name(state, #args matches!(op, Op::Delete | Op::UpdateDelete) #context) }
                    }
                    false => quote! {{
                        #panic_on_retract
                        #fn_name(state, #args #context)
                    }},
                }
            }
            AggregateFnOrImpl::Impl(i) => {
                let retract = match i.retract {
                    Some(_) => quote! { self.function.retract(state, #args) },
                    None => panic_on_retract,
                };
                quote! {
                    if matches!(op, Op::Delete | Op::UpdateDelete) {
                        #retract
                    } else {
                        self.function.accumulate(state, #args)
                    }
                }
            }
        };
        // Normalize the user's return type (`T`, `Option<T>`, `Result<T>`, `Result<Option<T>>`)
        // so that `next_state` always evaluates to an `Option`.
        next_state = match user_fn.accumulate().return_type_kind {
            ReturnTypeKind::T => quote! { Some(#next_state) },
            ReturnTypeKind::Option => next_state,
            ReturnTypeKind::Result => quote! { Some(#next_state?) },
            ReturnTypeKind::ResultOption => quote! { #next_state? },
        };
        // When no argument of the user function is optional, the generated code unwraps the
        // `Option`s itself: a `None` input leaves the state untouched, and a `None` state is
        // handled by `first_state` below.
        if user_fn.accumulate().args_option.iter().all(|b| !b) {
            match self.args.len() {
                0 => {
                    next_state = quote! {
                        match state {
                            Some(state) => #next_state,
                            None => state,
                        }
                    };
                }
                1 => {
                    let first_state = if self.init_state.is_some() {
                        // for count, the state will never be None
                        quote! { unreachable!() }
                    } else if let Some(s) = &self.state
                        && s == "ref"
                    {
                        // for min/max/first/last, the state is the first value
                        quote! { Some(v0) }
                    } else if let AggregateFnOrImpl::Impl(impl_) = user_fn
                        && impl_.create_state.is_some()
                    {
                        // use user-defined create_state function
                        quote! {{
                            let state = self.function.create_state();
                            #next_state
                        }}
                    } else {
                        // start from the default value of the state type
                        quote! {{
                            let state = #state_type::default();
                            #next_state
                        }}
                    };
                    next_state = quote! {
                        match (state, v0) {
                            (Some(state), Some(v0)) => #next_state,
                            (None, Some(v0)) => #first_state,
                            (state, None) => state,
                        }
                    };
                }
                _ => todo!("multiple arguments are not supported for non-option function"),
            }
        }
        // A custom state is updated through the `&mut` reference, so the value of the call is
        // discarded; otherwise the call result becomes the new `state`.
        let update_state = if custom_state.is_some() {
            quote! { _ = #next_state; }
        } else {
            quote! { state = #next_state; }
        };
        // Body of `get_result`:
        //   - custom state: convert a reference to the state into the result,
        //   - impl with `finalize`: feed the datum (if any) through `self.function.finalize`,
        //   - otherwise: the state datum itself is the result.
        let get_result = if custom_state.is_some() {
            quote! { Ok(state.downcast_ref::<#state_type>().into()) }
        } else if let AggregateFnOrImpl::Impl(impl_) = user_fn
            && impl_.finalize.is_some()
        {
            quote! {
                let state = match state.as_datum() {
                    Some(s) => s.as_scalar_ref_impl().try_into().unwrap(),
                    None => return Ok(None),
                };
                Ok(Some(self.function.finalize(state).into()))
            }
        } else {
            quote! { Ok(state.as_datum().clone()) }
        };
        // Field declaration `function: Struct<G>,` of the `Agg` struct (impl case only).
        let function_field = match user_fn {
            AggregateFnOrImpl::Fn(_) => quote! {},
            AggregateFnOrImpl::Impl(i) => {
                let struct_name = format_ident!("{}", i.struct_name);
                let generic = self.generic.as_ref().map(|g| {
                    let g = format_ident!("{g}");
                    quote! { <#g> }
                });
                quote! { function: #struct_name #generic, }
            }
        };
        // Matching field initializer `function: Struct::<G>::default(),` (impl case only).
        let function_new = match user_fn {
            AggregateFnOrImpl::Fn(_) => quote! {},
            AggregateFnOrImpl::Impl(i) => {
                let struct_name = format_ident!("{}", i.struct_name);
                let generic = self.generic.as_ref().map(|g| {
                    let g = format_ident!("{g}");
                    quote! { ::<#g> }
                });
                quote! { function: #struct_name #generic :: default(), }
            }
        };

        Ok(quote! {
            |agg| {
                use std::collections::HashSet;
                use std::ops::Range;
                use risingwave_common::array::*;
                use risingwave_common::types::*;
                use risingwave_common::bail;
                use risingwave_common::bitmap::Bitmap;
                use risingwave_common_estimate_size::EstimateSize;

                use risingwave_expr::expr::Context;
                use risingwave_expr::Result;
                use risingwave_expr::aggregate::AggregateState;
                use risingwave_expr::codegen::async_trait;

                let context = Context {
                    return_type: agg.return_type.clone(),
                    arg_types: agg.args.arg_types().to_owned(),
                    variadic: false,
                };

                struct Agg {
                    context: Context,
                    #function_field
                }

                #[async_trait]
                impl risingwave_expr::aggregate::AggregateFunction for Agg {
                    fn return_type(&self) -> DataType {
                        self.context.return_type.clone()
                    }

                    #create_state

                    // Apply every visible row of `input` to the state.
                    async fn update(&self, state0: &mut AggregateState, input: &StreamChunk) -> Result<()> {
                        #(#let_arrays)*
                        #downcast_state
                        for row_id in input.visibility().iter_ones() {
                            let op = unsafe { *input.ops().get_unchecked(row_id) };
                            #(#let_values)*
                            #update_state
                        }
                        #restore_state
                        Ok(())
                    }

                    // Same as `update`, restricted to the rows whose index lies in `range`.
                    async fn update_range(&self, state0: &mut AggregateState, input: &StreamChunk, range: Range<usize>) -> Result<()> {
                        assert!(range.end <= input.capacity());
                        #(#let_arrays)*
                        #downcast_state
                        if input.is_compacted() {
                            // compacted chunk: iterate the range directly
                            for row_id in range {
                                let op = unsafe { *input.ops().get_unchecked(row_id) };
                                #(#let_values)*
                                #update_state
                            }
                        } else {
                            // walk the visible rows, skipping those before the range and
                            // stopping at the first one past its end
                            for row_id in input.visibility().iter_ones() {
                                if row_id < range.start {
                                    continue;
                                } else if row_id >= range.end {
                                    break;
                                }
                                let op = unsafe { *input.ops().get_unchecked(row_id) };
                                #(#let_values)*
                                #update_state
                            }
                        }
                        #restore_state
                        Ok(())
                    }

                    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
                        #get_result
                    }
                }

                Ok(Box::new(Agg {
                    context,
                    #function_new
                }))
            }
        })
    }
999
1000    /// Generate a descriptor of the table function.
1001    ///
1002    /// The types of arguments and return value should not contain wildcard.
1003    fn generate_table_function_descriptor(
1004        &self,
1005        user_fn: &UserFunctionAttr,
1006        build_fn: bool,
1007    ) -> Result<TokenStream2> {
1008        let name = self.name.clone();
1009        let mut args = Vec::with_capacity(self.args.len());
1010        for ty in &self.args {
1011            args.push(sig_data_type(ty));
1012        }
1013        let ret = sig_data_type(&self.ret);
1014
1015        let pb_type = format_ident!("{}", utils::to_camel_case(&name));
1016        let ctor_name = format_ident!("{}", self.ident_name());
1017        let build_fn = if build_fn {
1018            let name = format_ident!("{}", user_fn.name);
1019            quote! { #name }
1020        } else if self.rewritten {
1021            quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) }
1022        } else {
1023            self.generate_build_table_function(user_fn)?
1024        };
1025        let type_infer_fn = self.generate_type_infer_fn()?;
1026        let deprecated = self.deprecated;
1027
1028        Ok(quote! {
1029            #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
1030            fn #ctor_name() -> risingwave_expr::sig::FuncSign {
1031                use risingwave_common::types::{DataType, DataTypeName};
1032                use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
1033
1034                FuncSign {
1035                    name: risingwave_pb::expr::table_function::Type::#pb_type.into(),
1036                    inputs_type: vec![#(#args),*],
1037                    variadic: false,
1038                    ret_type: #ret,
1039                    build: FuncBuilder::Table(#build_fn),
1040                    type_infer: #type_infer_fn,
1041                    deprecated: #deprecated,
1042                }
1043            }
1044        })
1045    }
1046
    /// Generate the build closure of a table function.
    ///
    /// The returned tokens form a closure `|return_type, chunk_size, children| { ... }` that
    /// declares a local struct holding the non-constant child expressions, implements
    /// `risingwave_expr::table_function::TableFunction` for it and evaluates to the boxed struct.
    ///
    /// Arguments referenced as `$i` in the `prebuild` expression are evaluated once with
    /// `eval_const()` inside the closure and only reach the user function through the prebuilt
    /// value; all other arguments are evaluated per input chunk.
    ///
    /// The generated `eval_inner` emits chunks made of an `i32` index column (the input row the
    /// output row came from), a value column (a struct array when the return type has several
    /// fields) and, when `error_array.null_bitmap().any()` holds, an extra column of error
    /// messages.
    ///
    /// NOTE(review): rows appended in the `Err` arms of `iter` are not followed by the
    /// `index_builder.len() == self.chunk_size` flush check, so it looks like a chunk could grow
    /// past `chunk_size` — confirm.
    ///
    /// # Errors
    ///
    /// Returns an error pointing at the return type when the user function's iterator item
    /// kind is unknown.
    ///
    /// # Panics
    ///
    /// Panics if the `prebuild` expression cannot be parsed into tokens.
    fn generate_build_table_function(&self, user_fn: &UserFunctionAttr) -> Result<TokenStream2> {
        let num_args = self.args.len();
        let return_types = output_types(&self.ret);
        let fn_name = format_ident!("{}", user_fn.name);
        let struct_name = format_ident!("{}", utils::to_camel_case(&self.ident_name()));
        // Indices of the arguments that are NOT mentioned as `$i` in the prebuild expression.
        // NOTE(review): this is a substring test, so `$1` also matches inside `$10` — confirm
        // that no function has enough arguments for this to matter.
        let arg_ids = (0..num_args)
            .filter(|i| match &self.prebuild {
                Some(s) => !s.contains(&format!("${i}")),
                None => true,
            })
            .collect_vec();
        // Indices of the arguments that ARE mentioned in the prebuild expression.
        let const_ids = (0..num_args).filter(|i| match &self.prebuild {
            Some(s) => s.contains(&format!("${i}")),
            None => false,
        });
        // Identifiers used by the generated code, one per (non-constant) argument or output.
        let inputs: Vec<_> = arg_ids.iter().map(|i| format_ident!("i{i}")).collect();
        let all_child: Vec<_> = (0..num_args).map(|i| format_ident!("child{i}")).collect();
        let const_child: Vec<_> = const_ids.map(|i| format_ident!("child{i}")).collect();
        let child: Vec<_> = arg_ids.iter().map(|i| format_ident!("child{i}")).collect();
        let array_refs: Vec<_> = arg_ids.iter().map(|i| format_ident!("array{i}")).collect();
        let arrays: Vec<_> = arg_ids.iter().map(|i| format_ident!("a{i}")).collect();
        let arg_arrays = arg_ids
            .iter()
            .map(|i| format_ident!("{}", types::array_type(&self.args[*i])));
        let outputs = (0..return_types.len())
            .map(|i| format_ident!("o{i}"))
            .collect_vec();
        let builders = (0..return_types.len())
            .map(|i| format_ident!("builder{i}"))
            .collect_vec();
        let builder_types = return_types
            .iter()
            .map(|ty| format_ident!("{}Builder", types::array_type(ty)))
            .collect_vec();
        // From here on `return_types` holds expressions yielding the runtime `DataType` of each
        // output column: the return type itself, or the i-th field type of the returned struct.
        let return_types = if return_types.len() == 1 {
            vec![quote! { self.context.return_type.clone() }]
        } else {
            (0..return_types.len())
                .map(|i| quote! { self.context.return_type.as_struct().types().nth(#i).unwrap().clone() })
                .collect()
        };
        // Expression appended to each builder: outputs whose Rust type text contains `Option`
        // are mapped, the others are wrapped in `Some`.
        #[allow(clippy::disallowed_methods)]
        let optioned_outputs = user_fn
            .core_return_type
            .split(',')
            .map(|t| t.contains("Option"))
            // example: "(Option<&str>, i32)" => [true, false]
            .zip(&outputs)
            .map(|(optional, o)| match optional {
                false => quote! { Some(#o.as_scalar_ref()) },
                true => quote! { #o.map(|o| o.as_scalar_ref()) },
            })
            .collect_vec();
        // Statement that turns the per-field arrays into the single value column.
        let build_value_array = if return_types.len() == 1 {
            quote! { let [value_array] = value_arrays; }
        } else {
            quote! {
                let value_array = StructArray::new(
                    self.context.return_type.as_struct().clone(),
                    value_arrays.to_vec(),
                    Bitmap::ones(len),
                ).into_ref();
            }
        };
        let context = user_fn.context.then(|| quote! { &self.context, });
        // Extra argument, field type and field initializer for the prebuilt value. The type is
        // the part of the prebuild expression before the first `::`; the value is the
        // expression with every `$` replaced by `child`.
        let prebuilt_arg = match &self.prebuild {
            Some(_) => quote! { &self.prebuilt_arg, },
            None => quote! {},
        };
        let prebuilt_arg_type = match &self.prebuild {
            Some(s) => s.split("::").next().unwrap().parse().unwrap(),
            None => quote! { () },
        };
        let prebuilt_arg_value = match &self.prebuild {
            Some(s) => s
                .replace('$', "child")
                .parse()
                .expect("invalid prebuild syntax"),
            None => quote! { () },
        };
        // Call of the user function, then unwrapping of its outer return type. A `None` result
        // skips the input row; an `Err` result emits one row with null values and the error
        // message before skipping to the next input row.
        let iter = quote! { #fn_name(#(#inputs,)* #prebuilt_arg #context) };
        let mut iter = match user_fn.return_type_kind {
            ReturnTypeKind::T => quote! { #iter },
            ReturnTypeKind::Option => quote! { match #iter {
                Some(it) => it,
                None => continue,
            } },
            ReturnTypeKind::Result => quote! { match #iter {
                Ok(it) => it,
                Err(e) => {
                    index_builder.append(Some(i as i32));
                    #(#builders.append_null();)*
                    error_builder.append_display(Some(e.as_report()));
                    continue;
                }
            } },
            ReturnTypeKind::ResultOption => quote! { match #iter {
                Ok(Some(it)) => it,
                Ok(None) => continue,
                Err(e) => {
                    index_builder.append(Some(i as i32));
                    #(#builders.append_null();)*
                    error_builder.append_display(Some(e.as_report()));
                    continue;
                }
            } },
        };
        // if user function accepts non-option arguments, we assume the function
        // returns empty on null input, so we need to unwrap the inputs before calling.
        #[allow(clippy::disallowed_methods)] // allow zip
        let some_inputs = inputs
            .iter()
            .zip(user_fn.args_option.iter())
            .map(|(input, opt)| {
                if *opt {
                    quote! { #input }
                } else {
                    quote! { Some(#input) }
                }
            });
        iter = quote! {
            match (#(#inputs,)*) {
                (#(#some_inputs,)*) => #iter,
                _ => continue,
            }
        };
        let iterator_item_type = user_fn.iterator_item_kind.clone().ok_or_else(|| {
            Error::new(
                user_fn.return_type_span,
                "expect `impl Iterator` in return type",
            )
        })?;
        // Statements that append one iterator item to the value builders and the error builder,
        // depending on whether the item is `T`, `Option<T>`, `Result<T>` or `Result<Option<T>>`.
        let append_output = match iterator_item_type {
            ReturnTypeKind::T => quote! {
                let (#(#outputs),*) = output;
                #(#builders.append(#optioned_outputs);)* error_builder.append_null();
            },
            ReturnTypeKind::Option => quote! { match output {
                Some((#(#outputs),*)) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
                None => { #(#builders.append_null();)* error_builder.append_null(); }
            } },
            ReturnTypeKind::Result => quote! { match output {
                Ok((#(#outputs),*)) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
                Err(e) => { #(#builders.append_null();)* error_builder.append_display(Some(e.as_report())); }
            } },
            ReturnTypeKind::ResultOption => quote! { match output {
                Ok(Some((#(#outputs),*))) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
                Ok(None) => { #(#builders.append_null();)* error_builder.append_null(); }
                Err(e) => { #(#builders.append_null();)* error_builder.append_display(Some(e.as_report())); }
            } },
        };

        Ok(quote! {
            |return_type, chunk_size, children| {
                use risingwave_common::array::*;
                use risingwave_common::types::*;
                use risingwave_common::bitmap::Bitmap;
                use risingwave_common::util::iter_util::ZipEqFast;
                use risingwave_expr::expr::{BoxedExpression, Context};
                use risingwave_expr::{Result, ExprError};
                use risingwave_expr::codegen::*;

                risingwave_expr::ensure!(children.len() == #num_args);

                let context = Context {
                    return_type: return_type.clone(),
                    arg_types: children.iter().map(|c| c.return_type()).collect(),
                    variadic: false,
                };

                let mut iter = children.into_iter();
                #(let #all_child = iter.next().unwrap();)*
                #(
                    let #const_child = #const_child.eval_const()?;
                    let #const_child = match &#const_child {
                        Some(s) => s.as_scalar_ref_impl().try_into()?,
                        // the function should always return empty if any const argument is null
                        None => return Ok(risingwave_expr::table_function::empty(return_type)),
                    };
                )*

                #[derive(Debug)]
                struct #struct_name {
                    context: Context,
                    chunk_size: usize,
                    #(#child: BoxedExpression,)*
                    prebuilt_arg: #prebuilt_arg_type,
                }
                #[async_trait]
                impl risingwave_expr::table_function::TableFunction for #struct_name {
                    fn return_type(&self) -> DataType {
                        self.context.return_type.clone()
                    }
                    async fn eval<'a>(&'a self, input: &'a DataChunk) -> BoxStream<'a, Result<DataChunk>> {
                        self.eval_inner(input)
                    }
                }
                impl #struct_name {
                    #[try_stream(boxed, ok = DataChunk, error = ExprError)]
                    async fn eval_inner<'a>(&'a self, input: &'a DataChunk) {
                        #(
                        let #array_refs = self.#child.eval(input).await?;
                        let #arrays: &#arg_arrays = #array_refs.as_ref().into();
                        )*

                        let mut index_builder = I32ArrayBuilder::new(self.chunk_size);
                        #(let mut #builders = #builder_types::with_type(self.chunk_size, #return_types);)*
                        let mut error_builder = Utf8ArrayBuilder::new(self.chunk_size);

                        for i in 0..input.capacity() {
                            if unsafe { !input.visibility().is_set_unchecked(i) } {
                                continue;
                            }
                            #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
                            for output in #iter {
                                index_builder.append(Some(i as i32));
                                #append_output

                                // flush a full chunk and start over with fresh builders
                                if index_builder.len() == self.chunk_size {
                                    let len = index_builder.len();
                                    let index_array = std::mem::replace(&mut index_builder, I32ArrayBuilder::new(self.chunk_size)).finish().into_ref();
                                    let value_arrays = [#(std::mem::replace(&mut #builders, #builder_types::with_type(self.chunk_size, #return_types)).finish().into_ref()),*];
                                    #build_value_array
                                    let error_array = std::mem::replace(&mut error_builder, Utf8ArrayBuilder::new(self.chunk_size)).finish().into_ref();
                                    if error_array.null_bitmap().any() {
                                        yield DataChunk::new(vec![index_array, value_array, error_array], self.chunk_size);
                                    } else {
                                        yield DataChunk::new(vec![index_array, value_array], self.chunk_size);
                                    }
                                }
                            }
                        }

                        // emit the remaining rows, if any
                        if index_builder.len() > 0 {
                            let len = index_builder.len();
                            let index_array = index_builder.finish().into_ref();
                            let value_arrays = [#(#builders.finish().into_ref()),*];
                            #build_value_array
                            let error_array = error_builder.finish().into_ref();
                            if error_array.null_bitmap().any() {
                                yield DataChunk::new(vec![index_array, value_array, error_array], len);
                            } else {
                                yield DataChunk::new(vec![index_array, value_array], len);
                            }
                        }
                    }
                }

                Ok(Box::new(#struct_name {
                    context,
                    chunk_size,
                    #(#child,)*
                    prebuilt_arg: #prebuilt_arg_value,
                }))
            }
        })
    }
1304}
1305
1306fn sig_data_type(ty: &str) -> TokenStream2 {
1307    match ty {
1308        "any" => quote! { SigDataType::Any },
1309        "anyarray" => quote! { SigDataType::AnyArray },
1310        "anymap" => quote! { SigDataType::AnyMap },
1311        "struct" => quote! { SigDataType::AnyStruct },
1312        _ if ty.starts_with("struct") && ty.contains("any") => quote! { SigDataType::AnyStruct },
1313        _ => {
1314            let datatype = data_type(ty);
1315            quote! { SigDataType::Exact(#datatype) }
1316        }
1317    }
1318}
1319
1320fn data_type(ty: &str) -> TokenStream2 {
1321    if let Some(ty) = ty.strip_suffix("[]") {
1322        let inner_type = data_type(ty);
1323        return quote! { DataType::List(Box::new(#inner_type)) };
1324    }
1325    if ty.starts_with("struct<") {
1326        return quote! { DataType::Struct(#ty.parse().expect("invalid struct type")) };
1327    }
1328    let variant = format_ident!("{}", types::data_type(ty));
1329    // TODO: enable the check
1330    // assert!(
1331    //     !matches!(ty, "any" | "anyarray" | "anymap" | "struct"),
1332    //     "{ty}, {variant}"
1333    // );
1334
1335    quote! { DataType::#variant }
1336}
1337
/// Extract multiple output types.
///
/// A type of the form `struct<name type, ...>` yields the type of each field, in order;
/// any other type is returned as the only element.
///
/// ```ignore
/// output_types("int4") -> ["int4"]
/// output_types("struct<key varchar, value jsonb>") -> ["varchar", "jsonb"]
/// ```
///
/// # Panics
///
/// Panics with a message naming the offending field if a struct field is not written as
/// `<name> <type>`.
fn output_types(ty: &str) -> Vec<&str> {
    let fields = ty
        .strip_prefix("struct<")
        .and_then(|rest| rest.strip_suffix('>'));
    match fields {
        Some(fields) => fields
            .split(',')
            .map(|field| {
                // the second whitespace-separated token of `name type` is the type
                field.split_whitespace().nth(1).unwrap_or_else(|| {
                    panic!(
                        "invalid struct field `{}` in `{}`: expect `<name> <type>`",
                        field.trim(),
                        ty
                    )
                })
            })
            .collect(),
        None => vec![ty],
    }
}