risingwave_expr_macro/
gen.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Generate code for the functions.
16
17use itertools::Itertools;
18use proc_macro2::{Ident, Span};
19use quote::{format_ident, quote};
20
21use super::*;
22
23impl FunctionAttr {
24    /// Expands the wildcard in function arguments or return type.
25    pub fn expand(&self) -> Vec<Self> {
26        // handle variadic argument
27        if self
28            .args
29            .last()
30            .is_some_and(|arg| arg.starts_with("variadic"))
31        {
32            // expand:  foo(a, b, variadic anyarray)
33            // to:      foo(a, b, ...)
34            //        + foo_variadic(a, b, anyarray)
35            let mut attrs = Vec::new();
36            attrs.extend(
37                FunctionAttr {
38                    args: {
39                        let mut args = self.args.clone();
40                        *args.last_mut().unwrap() = "...".to_owned();
41                        args
42                    },
43                    ..self.clone()
44                }
45                .expand(),
46            );
47            attrs.extend(
48                FunctionAttr {
49                    name: format!("{}_variadic", self.name),
50                    args: {
51                        let mut args = self.args.clone();
52                        let last = args.last_mut().unwrap();
53                        *last = last.strip_prefix("variadic ").unwrap().into();
54                        args
55                    },
56                    ..self.clone()
57                }
58                .expand(),
59            );
60            return attrs;
61        }
62        let args = self.args.iter().map(|ty| types::expand_type_wildcard(ty));
63        let ret = types::expand_type_wildcard(&self.ret);
64        let mut attrs = Vec::new();
65        for (args, mut ret) in args.multi_cartesian_product().cartesian_product(ret) {
66            if ret == "auto" {
67                ret = types::min_compatible_type(&args);
68            }
69            let attr = FunctionAttr {
70                args: args.iter().map(|s| s.to_string()).collect(),
71                ret: ret.to_owned(),
72                ..self.clone()
73            };
74            attrs.push(attr);
75        }
76        attrs
77    }
78
79    /// Generate the type infer function: `fn(&[DataType]) -> Result<DataType>`
80    fn generate_type_infer_fn(&self) -> Result<TokenStream2> {
81        if let Some(func) = &self.type_infer {
82            if func == "unreachable" {
83                return Ok(
84                    quote! { |_| unreachable!("type inference for this function should be specially handled in frontend, and should not call sig.type_infer") },
85                );
86            }
87            // use the user defined type inference function
88            return Ok(func.parse().unwrap());
89        } else if self.ret == "any" {
90            // TODO: if there are multiple "any", they should be the same type
91            if let Some(i) = self.args.iter().position(|t| t == "any") {
92                // infer as the type of "any" argument
93                return Ok(quote! { |args| Ok(args[#i].clone()) });
94            }
95            if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
96                // infer as the element type of "anyarray" argument
97                return Ok(quote! { |args| Ok(args[#i].as_list_elem().clone()) });
98            }
99        } else if self.ret == "anyarray" {
100            if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
101                // infer as the type of "anyarray" argument
102                return Ok(quote! { |args| Ok(args[#i].clone()) });
103            }
104            if let Some(i) = self.args.iter().position(|t| t == "any") {
105                // infer as the array type of "any" argument
106                return Ok(quote! { |args| Ok(DataType::list(args[#i].clone())) });
107            }
108        } else if self.ret == "struct" {
109            if let Some(i) = self.args.iter().position(|t| t == "struct") {
110                // infer as the type of "struct" argument
111                return Ok(quote! { |args| Ok(args[#i].clone()) });
112            }
113        } else if self.ret == "anymap" {
114            if let Some(i) = self.args.iter().position(|t| t == "anymap") {
115                // infer as the type of "anymap" argument
116                return Ok(quote! { |args| Ok(args[#i].clone()) });
117            }
118        } else {
119            // the return type is fixed
120            let ty = data_type(&self.ret);
121            return Ok(quote! { |_| Ok(#ty) });
122        }
123        Err(Error::new(
124            Span::call_site(),
125            "type inference function cannot be automatically derived. You should provide: `type_infer = \"|args| Ok(...)\"`",
126        ))
127    }
128
129    /// Generate a descriptor (`FuncSign`) of the scalar or table function.
130    ///
131    /// The types of arguments and return value should not contain wildcard.
132    ///
133    /// # Arguments
134    /// `build_fn`: whether the user provided a function is a build function.
135    /// (from the `#[build_function]` macro)
136    pub fn generate_function_descriptor(
137        &self,
138        user_fn: &UserFunctionAttr,
139        build_fn: bool,
140    ) -> Result<TokenStream2> {
141        if self.is_table_function {
142            return self.generate_table_function_descriptor(user_fn, build_fn);
143        }
144        let name = self.name.clone();
145        let variadic = matches!(self.args.last(), Some(t) if t == "...");
146        let args = match variadic {
147            true => &self.args[..self.args.len() - 1],
148            false => &self.args[..],
149        }
150        .iter()
151        .map(|ty| sig_data_type(ty))
152        .collect_vec();
153        let ret = sig_data_type(&self.ret);
154
155        let pb_type = format_ident!("{}", utils::to_camel_case(&name));
156        let ctor_name = format_ident!("{}", self.ident_name());
157        let build_fn = if build_fn {
158            let name = format_ident!("{}", user_fn.name);
159            quote! { #name }
160        } else if self.rewritten {
161            quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) }
162        } else {
163            // This is the core logic for `#[function]`
164            self.generate_build_scalar_function(user_fn, true)?
165        };
166        let type_infer_fn = self.generate_type_infer_fn()?;
167        let deprecated = self.deprecated;
168
169        Ok(quote! {
170            #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
171            fn #ctor_name() -> risingwave_expr::sig::FuncSign {
172                use risingwave_common::types::{DataType, DataTypeName};
173                use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
174
175                FuncSign {
176                    name: risingwave_pb::expr::expr_node::Type::#pb_type.into(),
177                    inputs_type: vec![#(#args),*],
178                    variadic: #variadic,
179                    ret_type: #ret,
180                    build: FuncBuilder::Scalar(#build_fn),
181                    type_infer: #type_infer_fn,
182                    deprecated: #deprecated,
183                }
184            }
185        })
186    }
187
188    /// Generate a build function for the scalar function.
189    ///
190    /// If `optimize_const` is true, the function will be optimized for constant arguments,
191    /// and fallback to the general version if any argument is not constant.
192    fn generate_build_scalar_function(
193        &self,
194        user_fn: &UserFunctionAttr,
195        optimize_const: bool,
196    ) -> Result<TokenStream2> {
197        let variadic = matches!(self.args.last(), Some(t) if t == "...");
198        let num_args = self.args.len() - if variadic { 1 } else { 0 };
199        let fn_name = format_ident!("{}", user_fn.name);
200        let struct_name = match optimize_const {
201            true => format_ident!("{}OptimizeConst", utils::to_camel_case(&self.ident_name())),
202            false => format_ident!("{}", utils::to_camel_case(&self.ident_name())),
203        };
204
205        // we divide all arguments into two groups: prebuilt and non-prebuilt.
206        // prebuilt arguments are collected from the "prebuild" field.
207        // let's say we have a function with 3 arguments: [0, 1, 2]
208        // and the prebuild field contains "$1".
209        // then we have:
210        //     prebuilt_indices = [1]
211        //     non_prebuilt_indices = [0, 2]
212        //
213        // if the const argument optimization is enabled, prebuilt arguments are
214        // evaluated at build time, thus the children only contain non-prebuilt arguments:
215        //     children_indices = [0, 2]
216        // otherwise, the children contain all arguments:
217        //     children_indices = [0, 1, 2]
218
219        let prebuilt_indices = match &self.prebuild {
220            Some(s) => (0..num_args)
221                .filter(|i| s.contains(&format!("${i}")))
222                .collect_vec(),
223            None => vec![],
224        };
225        let non_prebuilt_indices = match &self.prebuild {
226            Some(s) => (0..num_args)
227                .filter(|i| !s.contains(&format!("${i}")))
228                .collect_vec(),
229            _ => (0..num_args).collect_vec(),
230        };
231        let children_indices = match optimize_const {
232            #[allow(clippy::redundant_clone)] // false-positive
233            true => non_prebuilt_indices.clone(),
234            false => (0..num_args).collect_vec(),
235        };
236
237        /// Return a list of identifiers with the given prefix and indices.
238        fn idents(prefix: &str, indices: &[usize]) -> Vec<Ident> {
239            indices
240                .iter()
241                .map(|i| format_ident!("{prefix}{i}"))
242                .collect()
243        }
244        let inputs = idents("i", &children_indices);
245        let prebuilt_inputs = idents("i", &prebuilt_indices);
246        let non_prebuilt_inputs = idents("i", &non_prebuilt_indices);
247        let array_refs = idents("array", &children_indices);
248        let arrays = idents("a", &children_indices);
249        let datums = idents("v", &children_indices);
250        let arg_arrays = children_indices
251            .iter()
252            .map(|i| format_ident!("{}", types::array_type(&self.args[*i])));
253        let arg_types = children_indices.iter().map(|i| {
254            types::ref_type(&self.args[*i])
255                .parse::<TokenStream2>()
256                .unwrap()
257        });
258        let annotation: TokenStream2 = match user_fn.core_return_type.as_str() {
259            // add type annotation for functions that return generic types
260            "T" | "T1" | "T2" | "T3" => format!(": Option<{}>", types::owned_type(&self.ret))
261                .parse()
262                .unwrap(),
263            _ => quote! {},
264        };
265        let ret_array_type = format_ident!("{}", types::array_type(&self.ret));
266        let builder_type = format_ident!("{}Builder", types::array_type(&self.ret));
267        let prebuilt_arg_type = match &self.prebuild {
268            Some(s) if optimize_const => s.split("::").next().unwrap().parse().unwrap(),
269            _ => quote! { () },
270        };
271        let prebuilt_arg_value = match &self.prebuild {
272            // example:
273            // prebuild = "RegexContext::new($1)"
274            // return = "RegexContext::new(i1)"
275            Some(s) => s
276                .replace('$', "i")
277                .parse()
278                .expect("invalid prebuild syntax"),
279            None => quote! { () },
280        };
281        let prebuild_const = if self.prebuild.is_some() && optimize_const {
282            let build_general = self.generate_build_scalar_function(user_fn, false)?;
283            quote! {{
284                let build_general = #build_general;
285                #(
286                    // try to evaluate constant for prebuilt arguments
287                    let #prebuilt_inputs = match children[#prebuilt_indices].eval_const() {
288                        Ok(s) => s,
289                        // prebuilt argument is not constant, fallback to general
290                        Err(_) => return build_general(return_type, children),
291                    };
292                    // get reference to the constant value
293                    let #prebuilt_inputs = match &#prebuilt_inputs {
294                        Some(s) => s.as_scalar_ref_impl().try_into()?,
295                        // the function should always return null if any const argument is null
296                        None => return Ok(Box::new(risingwave_expr::expr::LiteralExpression::new(
297                            return_type,
298                            None,
299                        ))),
300                    };
301                )*
302                #prebuilt_arg_value
303            }}
304        } else {
305            quote! { () }
306        };
307
308        // ensure the number of children matches the number of arguments
309        let check_children = match variadic {
310            true => quote! { risingwave_expr::ensure!(children.len() >= #num_args); },
311            false => quote! { risingwave_expr::ensure!(children.len() == #num_args); },
312        };
313
314        // evaluate variadic arguments in `eval`
315        let eval_variadic = variadic.then(|| {
316            quote! {
317                let mut columns = Vec::with_capacity(self.children.len() - #num_args);
318                for child in &self.children[#num_args..] {
319                    columns.push(child.eval(input).await?);
320                }
321                let variadic_input = DataChunk::new(columns, input.visibility().clone());
322            }
323        });
324        // evaluate variadic arguments in `eval_row`
325        let eval_row_variadic = variadic.then(|| {
326            quote! {
327                let mut row = Vec::with_capacity(self.children.len() - #num_args);
328                for child in &self.children[#num_args..] {
329                    row.push(child.eval_row(input).await?);
330                }
331                let variadic_row = OwnedRow::new(row);
332            }
333        });
334
335        let generic = (self.ret == "boolean" && user_fn.generic == 3).then(|| {
336            // XXX: for generic compare functions, we need to specify the compatible type
337            let compatible_type = types::ref_type(types::min_compatible_type(&self.args))
338                .parse::<TokenStream2>()
339                .unwrap();
340            quote! { ::<_, _, #compatible_type> }
341        });
342        let prebuilt_arg = match (&self.prebuild, optimize_const) {
343            // use the prebuilt argument
344            (Some(_), true) => quote! { &self.prebuilt_arg, },
345            // build the argument on site
346            (Some(_), false) => quote! { &#prebuilt_arg_value, },
347            // no prebuilt argument
348            (None, _) => quote! {},
349        };
350        let variadic_args = variadic.then(|| quote! { &variadic_row, });
351        let context = user_fn.context.then(|| quote! { &self.context, });
352        let writer = user_fn
353            .writer_type_kind
354            .is_some()
355            .then(|| quote! { &mut writer, });
356        let await_ = user_fn.async_.then(|| quote! { .await });
357
358        let record_error = {
359            // Uniform arguments into `DatumRef`.
360            #[allow(clippy::disallowed_methods)] // allow zip
361            let inputs_args = inputs
362                .iter()
363                .zip(user_fn.args_option.iter())
364                .map(|(input, opt)| {
365                    if *opt {
366                        quote! { #input.map(|s| ScalarRefImpl::from(s)) }
367                    } else {
368                        quote! { Some(ScalarRefImpl::from(#input)) }
369                    }
370                });
371            let inputs_args = quote! {
372                let args: &[DatumRef<'_>] = &[#(#inputs_args),*];
373                let args = args.iter().copied();
374            };
375            let var_args = variadic.then(|| {
376                quote! {
377                    let args = args.chain(variadic_row.iter());
378                }
379            });
380
381            quote! {
382                #inputs_args
383                #var_args
384                errors.push(ExprError::function(
385                    stringify!(#fn_name),
386                    args,
387                    e,
388                ));
389            }
390        };
391
392        // call the user defined function
393        // inputs: [ Option<impl ScalarRef> ]
394        let mut output = quote! { #fn_name #generic(
395            #(#non_prebuilt_inputs,)*
396            #variadic_args
397            #prebuilt_arg
398            #context
399            #writer
400        ) #await_ };
401        // handle error if the function returns `Result`
402        // wrap a `Some` if the function doesn't return `Option`
403        output = match user_fn.return_type_kind {
404            // XXX: we don't support void type yet. return null::int for now.
405            _ if self.ret == "void" => quote! { { #output; Option::<i32>::None } },
406            ReturnTypeKind::T => quote! { Some(#output) },
407            ReturnTypeKind::Option => output,
408            ReturnTypeKind::Result => quote! {
409                match #output {
410                    Ok(x) => Some(x),
411                    Err(e) => {
412                        #record_error
413                        None
414                    }
415                }
416            },
417            ReturnTypeKind::ResultOption => quote! {
418                match #output {
419                    Ok(x) => x,
420                    Err(e) => {
421                        #record_error
422                        None
423                    }
424                }
425            },
426        };
427        // if user function accepts non-option arguments, we assume the function
428        // returns null on null input, so we need to unwrap the inputs before calling.
429        if self.prebuild.is_some() {
430            output = quote! {
431                match (#(#inputs,)*) {
432                    (#(Some(#inputs),)*) => #output,
433                    _ => None,
434                }
435            };
436        } else {
437            #[allow(clippy::disallowed_methods)] // allow zip
438            let some_inputs = inputs
439                .iter()
440                .zip(user_fn.args_option.iter())
441                .map(|(input, opt)| {
442                    if *opt {
443                        quote! { #input }
444                    } else {
445                        quote! { Some(#input) }
446                    }
447                });
448            output = quote! {
449                match (#(#inputs,)*) {
450                    (#(#some_inputs,)*) => #output,
451                    _ => None,
452                }
453            };
454        };
455        // now the `output` is: Option<impl ScalarRef or Scalar>
456        let append_output = match user_fn.writer_type_kind {
457            Some(WriterTypeKind::FmtWrite)
458            | Some(WriterTypeKind::IoWrite)
459            | Some(WriterTypeKind::ListWrite) => quote! {{
460                let mut writer = builder.writer();
461                if #output.is_some() {
462                    writer.finish();
463                } else {
464                    writer.rollback();
465                    builder.append_null();
466                }
467            }},
468            Some(WriterTypeKind::JsonbbBuilder) => quote! {{
469                let mut writer_wrapper = builder.writer();
470                let mut writer = writer_wrapper.inner();
471                if #output.is_some() {
472                    writer_wrapper.finish();
473                } else {
474                    writer_wrapper.rollback();
475                    builder.append_null();
476                }
477            }},
478            None if user_fn.core_return_type == "impl AsRef < [u8] >" => quote! {
479                builder.append(#output.as_ref().map(|s| s.as_ref()));
480            },
481            None => quote! {
482                let output #annotation = #output;
483                builder.append(output.as_ref().map(|s| s.as_scalar_ref()));
484            },
485        };
486        // the output expression in `eval_row`
487        let row_output = match user_fn.writer_type_kind {
488            Some(WriterTypeKind::FmtWrite) => quote! {{
489                let mut writer = String::new();
490                #output.map(|_| writer.into())
491            }},
492            Some(WriterTypeKind::IoWrite) => quote! {{
493                let mut writer = Vec::new();
494                #output.map(|_| writer.into())
495            }},
496            Some(WriterTypeKind::JsonbbBuilder) => quote! {{
497                let mut writer = jsonbb::Builder::<Vec<u8>>::new();
498                #output.map(|_| JsonbVal::from(writer.finish()).into())
499            }},
500            Some(WriterTypeKind::ListWrite) => quote! {{
501                let mut writer = {
502                    let DataType::List(list_ty) = &self.context.return_type else {
503                        panic!("data type must be DataType::List");
504                    };
505                    list_ty.elem().create_array_builder(1)
506                };
507                #output.map(|_| ListValue::new(writer.finish()).into())
508            }},
509            None if user_fn.core_return_type == "impl AsRef < [u8] >" => quote! {
510                #output.map(|s| s.as_ref().into())
511            },
512            None => quote! {{
513                let output #annotation = #output;
514                output.map(|s| s.into())
515            }},
516        };
517        // the main body in `eval`
518        let eval = if let Some(batch_fn) = &self.batch_fn {
519            assert!(
520                !variadic,
521                "customized batch function is not supported for variadic functions"
522            );
523            // user defined batch function
524            let fn_name = format_ident!("{}", batch_fn);
525            quote! {
526                let c = #fn_name(#(#arrays),*);
527                Arc::new(c.into())
528            }
529        } else if (types::is_primitive(&self.ret) || self.ret == "boolean")
530            && user_fn.is_pure()
531            && !variadic
532            && self.prebuild.is_none()
533        {
534            // SIMD optimization for primitive types
535            match self.args.len() {
536                0 => quote! {
537                    let c = #ret_array_type::from_iter_bitmap(
538                        std::iter::repeat_with(|| #fn_name()).take(input.capacity()),
539                        Bitmap::ones(input.capacity()),
540                    );
541                    Arc::new(c.into())
542                },
543                1 => quote! {
544                    let c = #ret_array_type::from_iter_bitmap(
545                        a0.raw_iter().map(|a| #fn_name(a)),
546                        a0.null_bitmap().clone()
547                    );
548                    Arc::new(c.into())
549                },
550                2 => quote! {
551                    // allow using `zip` for performance
552                    #[allow(clippy::disallowed_methods)]
553                    let c = #ret_array_type::from_iter_bitmap(
554                        a0.raw_iter()
555                            .zip(a1.raw_iter())
556                            .map(|(a, b)| #fn_name #generic(a, b)),
557                        a0.null_bitmap() & a1.null_bitmap(),
558                    );
559                    Arc::new(c.into())
560                },
561                n => todo!("SIMD optimization for {n} arguments"),
562            }
563        } else {
564            // no optimization
565            let let_variadic = variadic.then(|| {
566                quote! {
567                    let variadic_row = variadic_input.row_at_unchecked_vis(i);
568                }
569            });
570            quote! {
571                let mut builder = #builder_type::with_type(input.capacity(), self.context.return_type.clone());
572
573                if input.is_vis_compacted() {
574                    for i in 0..input.capacity() {
575                        #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
576                        #let_variadic
577                        #append_output
578                    }
579                } else {
580                    for i in 0..input.capacity() {
581                        if unsafe { !input.visibility().is_set_unchecked(i) } {
582                            builder.append_null();
583                            continue;
584                        }
585                        #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
586                        #let_variadic
587                        #append_output
588                    }
589                }
590                Arc::new(builder.finish().into())
591            }
592        };
593
594        Ok(quote! {
595            |return_type: DataType, children: Vec<risingwave_expr::expr::BoxedExpression>|
596                -> risingwave_expr::Result<risingwave_expr::expr::BoxedExpression>
597            {
598                use std::sync::Arc;
599                use risingwave_common::array::*;
600                use risingwave_common::types::*;
601                use risingwave_common::bitmap::Bitmap;
602                use risingwave_common::row::OwnedRow;
603                use risingwave_common::util::iter_util::ZipEqFast;
604
605                use risingwave_expr::expr::{Context, BoxedExpression};
606                use risingwave_expr::{ExprError, Result};
607                use risingwave_expr::codegen::*;
608
609                #check_children
610                let prebuilt_arg = #prebuild_const;
611                let context = Context {
612                    return_type,
613                    arg_types: children.iter().map(|c| c.return_type()).collect(),
614                    variadic: #variadic,
615                };
616
617                #[derive(Debug)]
618                struct #struct_name {
619                    context: Context,
620                    children: Vec<BoxedExpression>,
621                    prebuilt_arg: #prebuilt_arg_type,
622                }
623                #[async_trait]
624                impl risingwave_expr::expr::Expression for #struct_name {
625                    fn return_type(&self) -> DataType {
626                        self.context.return_type.clone()
627                    }
628                    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
629                        #(
630                            let #array_refs = self.children[#children_indices].eval(input).await?;
631                            let #arrays: &#arg_arrays = #array_refs.as_ref().into();
632                        )*
633                        #eval_variadic
634                        let mut errors = vec![];
635                        let array = { #eval };
636                        if errors.is_empty() {
637                            Ok(array)
638                        } else {
639                            Err(ExprError::Multiple(array, errors.into()))
640                        }
641                    }
642                    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
643                        #(
644                            let #datums = self.children[#children_indices].eval_row(input).await?;
645                            let #inputs: Option<#arg_types> = #datums.as_ref().map(|s| s.as_scalar_ref_impl().try_into().unwrap());
646                        )*
647                        #eval_row_variadic
648                        let mut errors: Vec<ExprError> = vec![];
649                        let output = #row_output;
650                        if let Some(err) = errors.into_iter().next() {
651                            Err(err.into())
652                        } else {
653                            Ok(output)
654                        }
655                    }
656                }
657
658                Ok(Box::new(#struct_name {
659                    context,
660                    children,
661                    prebuilt_arg,
662                }))
663            }
664        })
665    }
666
667    /// Generate a descriptor of the aggregate function.
668    ///
669    /// The types of arguments and return value should not contain wildcard.
670    /// `user_fn` could be either `fn` or `impl`.
671    /// If `build_fn` is true, `user_fn` must be a `fn` that builds the aggregate function.
672    pub fn generate_aggregate_descriptor(
673        &self,
674        user_fn: &AggregateFnOrImpl,
675        build_fn: bool,
676    ) -> Result<TokenStream2> {
677        let name = self.name.clone();
678
679        let mut args = Vec::with_capacity(self.args.len());
680        for ty in &self.args {
681            args.push(sig_data_type(ty));
682        }
683        let ret = sig_data_type(&self.ret);
684        let state_type = match &self.state {
685            Some(ty) if ty != "ref" => {
686                let ty = data_type(ty);
687                quote! { Some(#ty) }
688            }
689            _ => quote! { None },
690        };
691        let append_only = match build_fn {
692            false => !user_fn.has_retract(),
693            true => self.append_only,
694        };
695
696        let pb_kind = format_ident!("{}", utils::to_camel_case(&name));
697        let ctor_name = match append_only {
698            false => format_ident!("{}", self.ident_name()),
699            true => format_ident!("{}_append_only", self.ident_name()),
700        };
701        let build_fn = if build_fn {
702            let name = format_ident!("{}", user_fn.as_fn().name);
703            quote! { #name }
704        } else if self.rewritten {
705            quote! { |_| Err(ExprError::UnsupportedFunction(#name.into())) }
706        } else {
707            self.generate_agg_build_fn(user_fn)?
708        };
709        let build_retractable = match append_only {
710            true => quote! { None },
711            false => quote! { Some(#build_fn) },
712        };
713        let build_append_only = match append_only {
714            false => quote! { None },
715            true => quote! { Some(#build_fn) },
716        };
717        let retractable_state_type = match append_only {
718            true => quote! { None },
719            false => state_type.clone(),
720        };
721        let append_only_state_type = match append_only {
722            false => quote! { None },
723            true => state_type,
724        };
725        let type_infer_fn = self.generate_type_infer_fn()?;
726        let deprecated = self.deprecated;
727
728        Ok(quote! {
729            #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
730            fn #ctor_name() -> risingwave_expr::sig::FuncSign {
731                use risingwave_common::types::{DataType, DataTypeName};
732                use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
733
734                FuncSign {
735                    name: risingwave_pb::expr::agg_call::PbKind::#pb_kind.into(),
736                    inputs_type: vec![#(#args),*],
737                    variadic: false,
738                    ret_type: #ret,
739                    build: FuncBuilder::Aggregate {
740                        retractable: #build_retractable,
741                        append_only: #build_append_only,
742                        retractable_state_type: #retractable_state_type,
743                        append_only_state_type: #append_only_state_type,
744                    },
745                    type_infer: #type_infer_fn,
746                    deprecated: #deprecated,
747                }
748            }
749        })
750    }
751
752    /// Generate build function for aggregate function.
753    fn generate_agg_build_fn(&self, user_fn: &AggregateFnOrImpl) -> Result<TokenStream2> {
754        // If the first argument of the aggregate function is of type `&mut T`,
755        // we assume it is a user defined state type.
756        let custom_state = user_fn.accumulate().first_mut_ref_arg.as_ref();
757        let state_type: TokenStream2 = match (custom_state, &self.state) {
758            (Some(s), _) => s.parse().unwrap(),
759            (_, Some(state)) if state == "ref" => types::ref_type(&self.ret).parse().unwrap(),
760            (_, Some(state)) if state != "ref" => types::owned_type(state).parse().unwrap(),
761            _ => types::owned_type(&self.ret).parse().unwrap(),
762        };
763        let let_arrays = self
764            .args
765            .iter()
766            .enumerate()
767            .map(|(i, arg)| {
768                let array = format_ident!("a{i}");
769                let array_type: TokenStream2 = types::array_type(arg).parse().unwrap();
770                quote! {
771                    let #array: &#array_type = input.column_at(#i).as_ref().into();
772                }
773            })
774            .collect_vec();
775        let let_values = (0..self.args.len())
776            .map(|i| {
777                let v = format_ident!("v{i}");
778                let a = format_ident!("a{i}");
779                quote! { let #v = unsafe { #a.value_at_unchecked(row_id) }; }
780            })
781            .collect_vec();
782        let downcast_state = if custom_state.is_some() {
783            quote! { let mut state: &mut #state_type = state0.downcast_mut(); }
784        } else if let Some(s) = &self.state
785            && s == "ref"
786        {
787            quote! { let mut state: Option<#state_type> = state0.as_datum_mut().as_ref().map(|x| x.as_scalar_ref_impl().try_into().unwrap()); }
788        } else {
789            quote! { let mut state: Option<#state_type> = state0.as_datum_mut().take().map(|s| s.try_into().unwrap()); }
790        };
791        let restore_state = if custom_state.is_some() {
792            quote! {}
793        } else if let Some(s) = &self.state
794            && s == "ref"
795        {
796            quote! { *state0.as_datum_mut() = state.map(|x| x.to_owned_scalar().into()); }
797        } else {
798            quote! { *state0.as_datum_mut() = state.map(|s| s.into()); }
799        };
800        let create_state = if custom_state.is_some() {
801            quote! {
802                fn create_state(&self) -> Result<AggregateState> {
803                    Ok(AggregateState::Any(Box::<#state_type>::default()))
804                }
805            }
806        } else if let Some(state) = &self.init_state {
807            let state: TokenStream2 = state.parse().unwrap();
808            quote! {
809                fn create_state(&self) -> Result<AggregateState> {
810                    Ok(AggregateState::Datum(Some(#state.into())))
811                }
812            }
813        } else {
814            // by default: `AggregateState::Datum(None)`
815            quote! {}
816        };
817        let args = (0..self.args.len()).map(|i| format_ident!("v{i}"));
818        let args = quote! { #(#args,)* };
819        let panic_on_retract = {
820            let msg = format!(
821                "attempt to retract on aggregate function {}, but it is append-only",
822                self.name
823            );
824            quote! { assert_eq!(op, Op::Insert, #msg); }
825        };
826        let mut next_state = match user_fn {
827            AggregateFnOrImpl::Fn(f) => {
828                let context = f.context.then(|| quote! { &self.context, });
829                let fn_name = format_ident!("{}", f.name);
830                match f.retract {
831                    true => {
832                        quote! { #fn_name(state, #args matches!(op, Op::Delete | Op::UpdateDelete) #context) }
833                    }
834                    false => quote! {{
835                        #panic_on_retract
836                        #fn_name(state, #args #context)
837                    }},
838                }
839            }
840            AggregateFnOrImpl::Impl(i) => {
841                let retract = match i.retract {
842                    Some(_) => quote! { self.function.retract(state, #args) },
843                    None => panic_on_retract,
844                };
845                quote! {
846                    if matches!(op, Op::Delete | Op::UpdateDelete) {
847                        #retract
848                    } else {
849                        self.function.accumulate(state, #args)
850                    }
851                }
852            }
853        };
854        next_state = match user_fn.accumulate().return_type_kind {
855            ReturnTypeKind::T => quote! { Some(#next_state) },
856            ReturnTypeKind::Option => next_state,
857            ReturnTypeKind::Result => quote! { Some(#next_state?) },
858            ReturnTypeKind::ResultOption => quote! { #next_state? },
859        };
860        if user_fn.accumulate().args_option.iter().all(|b| !b) {
861            match self.args.len() {
862                0 => {
863                    next_state = quote! {
864                        match state {
865                            Some(state) => #next_state,
866                            None => state,
867                        }
868                    };
869                }
870                1 => {
871                    let first_state = if self.init_state.is_some() {
872                        // for count, the state will never be None
873                        quote! { unreachable!() }
874                    } else if let Some(s) = &self.state
875                        && s == "ref"
876                    {
877                        // for min/max/first/last, the state is the first value
878                        quote! { Some(v0) }
879                    } else if let AggregateFnOrImpl::Impl(impl_) = user_fn
880                        && impl_.create_state.is_some()
881                    {
882                        // use user-defined create_state function
883                        quote! {{
884                            let state = self.function.create_state();
885                            #next_state
886                        }}
887                    } else {
888                        quote! {{
889                            let state = #state_type::default();
890                            #next_state
891                        }}
892                    };
893                    next_state = quote! {
894                        match (state, v0) {
895                            (Some(state), Some(v0)) => #next_state,
896                            (None, Some(v0)) => #first_state,
897                            (state, None) => state,
898                        }
899                    };
900                }
901                _ => todo!("multiple arguments are not supported for non-option function"),
902            }
903        }
904        let update_state = if custom_state.is_some() {
905            quote! { _ = #next_state; }
906        } else {
907            quote! { state = #next_state; }
908        };
909        let get_result = if custom_state.is_some() {
910            quote! { Ok(state.downcast_ref::<#state_type>().into()) }
911        } else if let AggregateFnOrImpl::Impl(impl_) = user_fn
912            && impl_.finalize.is_some()
913        {
914            quote! {
915                let state = match state.as_datum() {
916                    Some(s) => s.as_scalar_ref_impl().try_into().unwrap(),
917                    None => return Ok(None),
918                };
919                Ok(Some(self.function.finalize(state).into()))
920            }
921        } else {
922            quote! { Ok(state.as_datum().clone()) }
923        };
924        let function_field = match user_fn {
925            AggregateFnOrImpl::Fn(_) => quote! {},
926            AggregateFnOrImpl::Impl(i) => {
927                let struct_name = format_ident!("{}", i.struct_name);
928                let generic = self.generic.as_ref().map(|g| {
929                    let g = format_ident!("{g}");
930                    quote! { <#g> }
931                });
932                quote! { function: #struct_name #generic, }
933            }
934        };
935        let function_new = match user_fn {
936            AggregateFnOrImpl::Fn(_) => quote! {},
937            AggregateFnOrImpl::Impl(i) => {
938                let struct_name = format_ident!("{}", i.struct_name);
939                let generic = self.generic.as_ref().map(|g| {
940                    let g = format_ident!("{g}");
941                    quote! { ::<#g> }
942                });
943                quote! { function: #struct_name #generic :: default(), }
944            }
945        };
946
947        Ok(quote! {
948            |agg| {
949                use std::collections::HashSet;
950                use std::ops::Range;
951                use risingwave_common::array::*;
952                use risingwave_common::types::*;
953                use risingwave_common::bail;
954                use risingwave_common::bitmap::Bitmap;
955                use risingwave_common_estimate_size::EstimateSize;
956
957                use risingwave_expr::expr::Context;
958                use risingwave_expr::Result;
959                use risingwave_expr::aggregate::AggregateState;
960                use risingwave_expr::codegen::async_trait;
961
962                let context = Context {
963                    return_type: agg.return_type.clone(),
964                    arg_types: agg.args.arg_types().to_owned(),
965                    variadic: false,
966                };
967
968                struct Agg {
969                    context: Context,
970                    #function_field
971                }
972
973                #[async_trait]
974                impl risingwave_expr::aggregate::AggregateFunction for Agg {
975                    fn return_type(&self) -> DataType {
976                        self.context.return_type.clone()
977                    }
978
979                    #create_state
980
981                    async fn update(&self, state0: &mut AggregateState, input: &StreamChunk) -> Result<()> {
982                        #(#let_arrays)*
983                        #downcast_state
984                        for row_id in input.visibility().iter_ones() {
985                            let op = unsafe { *input.ops().get_unchecked(row_id) };
986                            #(#let_values)*
987                            #update_state
988                        }
989                        #restore_state
990                        Ok(())
991                    }
992
993                    async fn update_range(&self, state0: &mut AggregateState, input: &StreamChunk, range: Range<usize>) -> Result<()> {
994                        assert!(range.end <= input.capacity());
995                        #(#let_arrays)*
996                        #downcast_state
997                        if input.is_vis_compacted() {
998                            for row_id in range {
999                                let op = unsafe { *input.ops().get_unchecked(row_id) };
1000                                #(#let_values)*
1001                                #update_state
1002                            }
1003                        } else {
1004                            for row_id in input.visibility().iter_ones() {
1005                                if row_id < range.start {
1006                                    continue;
1007                                } else if row_id >= range.end {
1008                                    break;
1009                                }
1010                                let op = unsafe { *input.ops().get_unchecked(row_id) };
1011                                #(#let_values)*
1012                                #update_state
1013                            }
1014                        }
1015                        #restore_state
1016                        Ok(())
1017                    }
1018
1019                    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
1020                        #get_result
1021                    }
1022                }
1023
1024                Ok(Box::new(Agg {
1025                    context,
1026                    #function_new
1027                }))
1028            }
1029        })
1030    }
1031
1032    /// Generate a descriptor of the table function.
1033    ///
1034    /// The types of arguments and return value should not contain wildcard.
1035    fn generate_table_function_descriptor(
1036        &self,
1037        user_fn: &UserFunctionAttr,
1038        build_fn: bool,
1039    ) -> Result<TokenStream2> {
1040        let name = self.name.clone();
1041        let mut args = Vec::with_capacity(self.args.len());
1042        for ty in &self.args {
1043            args.push(sig_data_type(ty));
1044        }
1045        let ret = sig_data_type(&self.ret);
1046
1047        let pb_type = format_ident!("{}", utils::to_camel_case(&name));
1048        let ctor_name = format_ident!("{}", self.ident_name());
1049        let build_fn = if build_fn {
1050            let name = format_ident!("{}", user_fn.name);
1051            quote! { #name }
1052        } else if self.rewritten {
1053            quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) }
1054        } else {
1055            self.generate_build_table_function(user_fn)?
1056        };
1057        let type_infer_fn = self.generate_type_infer_fn()?;
1058        let deprecated = self.deprecated;
1059
1060        Ok(quote! {
1061            #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
1062            fn #ctor_name() -> risingwave_expr::sig::FuncSign {
1063                use risingwave_common::types::{DataType, DataTypeName};
1064                use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
1065
1066                FuncSign {
1067                    name: risingwave_pb::expr::table_function::Type::#pb_type.into(),
1068                    inputs_type: vec![#(#args),*],
1069                    variadic: false,
1070                    ret_type: #ret,
1071                    build: FuncBuilder::Table(#build_fn),
1072                    type_infer: #type_infer_fn,
1073                    deprecated: #deprecated,
1074                }
1075            }
1076        })
1077    }
1078
1079    fn generate_build_table_function(&self, user_fn: &UserFunctionAttr) -> Result<TokenStream2> {
1080        let num_args = self.args.len();
1081        let return_types = output_types(&self.ret);
1082        let fn_name = format_ident!("{}", user_fn.name);
1083        let struct_name = format_ident!("{}", utils::to_camel_case(&self.ident_name()));
1084        let arg_ids = (0..num_args)
1085            .filter(|i| match &self.prebuild {
1086                Some(s) => !s.contains(&format!("${i}")),
1087                None => true,
1088            })
1089            .collect_vec();
1090        let const_ids = (0..num_args).filter(|i| match &self.prebuild {
1091            Some(s) => s.contains(&format!("${i}")),
1092            None => false,
1093        });
1094        let inputs: Vec<_> = arg_ids.iter().map(|i| format_ident!("i{i}")).collect();
1095        let all_child: Vec<_> = (0..num_args).map(|i| format_ident!("child{i}")).collect();
1096        let const_child: Vec<_> = const_ids.map(|i| format_ident!("child{i}")).collect();
1097        let child: Vec<_> = arg_ids.iter().map(|i| format_ident!("child{i}")).collect();
1098        let array_refs: Vec<_> = arg_ids.iter().map(|i| format_ident!("array{i}")).collect();
1099        let arrays: Vec<_> = arg_ids.iter().map(|i| format_ident!("a{i}")).collect();
1100        let arg_arrays = arg_ids
1101            .iter()
1102            .map(|i| format_ident!("{}", types::array_type(&self.args[*i])));
1103        let outputs = (0..return_types.len())
1104            .map(|i| format_ident!("o{i}"))
1105            .collect_vec();
1106        let builders = (0..return_types.len())
1107            .map(|i| format_ident!("builder{i}"))
1108            .collect_vec();
1109        let builder_types = return_types
1110            .iter()
1111            .map(|ty| format_ident!("{}Builder", types::array_type(ty)))
1112            .collect_vec();
1113        let return_types = if return_types.len() == 1 {
1114            vec![quote! { self.context.return_type.clone() }]
1115        } else {
1116            (0..return_types.len())
1117                .map(|i| quote! { self.context.return_type.as_struct().types().nth(#i).unwrap().clone() })
1118                .collect()
1119        };
1120        #[allow(clippy::disallowed_methods)]
1121        let optioned_outputs = user_fn
1122            .core_return_type
1123            .split(',')
1124            .map(|t| t.contains("Option"))
1125            // example: "(Option<&str>, i32)" => [true, false]
1126            .zip(&outputs)
1127            .map(|(optional, o)| match optional {
1128                false => quote! { Some(#o.as_scalar_ref()) },
1129                true => quote! { #o.map(|o| o.as_scalar_ref()) },
1130            })
1131            .collect_vec();
1132        let build_value_array = if return_types.len() == 1 {
1133            quote! { let [value_array] = value_arrays; }
1134        } else {
1135            quote! {
1136                let value_array = StructArray::new(
1137                    self.context.return_type.as_struct().clone(),
1138                    value_arrays.to_vec(),
1139                    Bitmap::ones(len),
1140                ).into_ref();
1141            }
1142        };
1143        let context = user_fn.context.then(|| quote! { &self.context, });
1144        let prebuilt_arg = match &self.prebuild {
1145            Some(_) => quote! { &self.prebuilt_arg, },
1146            None => quote! {},
1147        };
1148        let prebuilt_arg_type = match &self.prebuild {
1149            Some(s) => s.split("::").next().unwrap().parse().unwrap(),
1150            None => quote! { () },
1151        };
1152        let prebuilt_arg_value = match &self.prebuild {
1153            Some(s) => s
1154                .replace('$', "child")
1155                .parse()
1156                .expect("invalid prebuild syntax"),
1157            None => quote! { () },
1158        };
1159        let iter = quote! { #fn_name(#(#inputs,)* #prebuilt_arg #context) };
1160        let mut iter = match user_fn.return_type_kind {
1161            ReturnTypeKind::T => quote! { #iter },
1162            ReturnTypeKind::Option => quote! { match #iter {
1163                Some(it) => it,
1164                None => continue,
1165            } },
1166            ReturnTypeKind::Result => quote! { match #iter {
1167                Ok(it) => it,
1168                Err(e) => {
1169                    index_builder.append(Some(i as i32));
1170                    #(#builders.append_null();)*
1171                    error_builder.append_display(Some(e.as_report()));
1172                    continue;
1173                }
1174            } },
1175            ReturnTypeKind::ResultOption => quote! { match #iter {
1176                Ok(Some(it)) => it,
1177                Ok(None) => continue,
1178                Err(e) => {
1179                    index_builder.append(Some(i as i32));
1180                    #(#builders.append_null();)*
1181                    error_builder.append_display(Some(e.as_report()));
1182                    continue;
1183                }
1184            } },
1185        };
1186        // if user function accepts non-option arguments, we assume the function
1187        // returns empty on null input, so we need to unwrap the inputs before calling.
1188        #[allow(clippy::disallowed_methods)] // allow zip
1189        let some_inputs = inputs
1190            .iter()
1191            .zip(user_fn.args_option.iter())
1192            .map(|(input, opt)| {
1193                if *opt {
1194                    quote! { #input }
1195                } else {
1196                    quote! { Some(#input) }
1197                }
1198            });
1199        iter = quote! {
1200            match (#(#inputs,)*) {
1201                (#(#some_inputs,)*) => #iter,
1202                _ => continue,
1203            }
1204        };
1205        let iterator_item_type = user_fn.iterator_item_kind.clone().ok_or_else(|| {
1206            Error::new(
1207                user_fn.return_type_span,
1208                "expect `impl Iterator` in return type",
1209            )
1210        })?;
1211        let append_output = match iterator_item_type {
1212            ReturnTypeKind::T => quote! {
1213                let (#(#outputs),*) = output;
1214                #(#builders.append(#optioned_outputs);)* error_builder.append_null();
1215            },
1216            ReturnTypeKind::Option => quote! { match output {
1217                Some((#(#outputs),*)) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1218                None => { #(#builders.append_null();)* error_builder.append_null(); }
1219            } },
1220            ReturnTypeKind::Result => quote! { match output {
1221                Ok((#(#outputs),*)) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1222                Err(e) => { #(#builders.append_null();)* error_builder.append_display(Some(e.as_report())); }
1223            } },
1224            ReturnTypeKind::ResultOption => quote! { match output {
1225                Ok(Some((#(#outputs),*))) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1226                Ok(None) => { #(#builders.append_null();)* error_builder.append_null(); }
1227                Err(e) => { #(#builders.append_null();)* error_builder.append_display(Some(e.as_report())); }
1228            } },
1229        };
1230
1231        Ok(quote! {
1232            |return_type, chunk_size, children| {
1233                use risingwave_common::array::*;
1234                use risingwave_common::types::*;
1235                use risingwave_common::bitmap::Bitmap;
1236                use risingwave_common::util::iter_util::ZipEqFast;
1237                use risingwave_expr::expr::{BoxedExpression, Context};
1238                use risingwave_expr::{Result, ExprError};
1239                use risingwave_expr::codegen::*;
1240
1241                risingwave_expr::ensure!(children.len() == #num_args);
1242
1243                let context = Context {
1244                    return_type: return_type.clone(),
1245                    arg_types: children.iter().map(|c| c.return_type()).collect(),
1246                    variadic: false,
1247                };
1248
1249                let mut iter = children.into_iter();
1250                #(let #all_child = iter.next().unwrap();)*
1251                #(
1252                    let #const_child = #const_child.eval_const()?;
1253                    let #const_child = match &#const_child {
1254                        Some(s) => s.as_scalar_ref_impl().try_into()?,
1255                        // the function should always return empty if any const argument is null
1256                        None => return Ok(risingwave_expr::table_function::empty(return_type)),
1257                    };
1258                )*
1259
1260                #[derive(Debug)]
1261                struct #struct_name {
1262                    context: Context,
1263                    chunk_size: usize,
1264                    #(#child: BoxedExpression,)*
1265                    prebuilt_arg: #prebuilt_arg_type,
1266                }
1267                #[async_trait]
1268                impl risingwave_expr::table_function::TableFunction for #struct_name {
1269                    fn return_type(&self) -> DataType {
1270                        self.context.return_type.clone()
1271                    }
1272                    async fn eval<'a>(&'a self, input: &'a DataChunk) -> BoxStream<'a, Result<DataChunk>> {
1273                        self.eval_inner(input)
1274                    }
1275                }
1276                impl #struct_name {
1277                    #[try_stream(boxed, ok = DataChunk, error = ExprError)]
1278                    async fn eval_inner<'a>(&'a self, input: &'a DataChunk) {
1279                        #(
1280                        let #array_refs = self.#child.eval(input).await?;
1281                        let #arrays: &#arg_arrays = #array_refs.as_ref().into();
1282                        )*
1283
1284                        let mut index_builder = I32ArrayBuilder::new(self.chunk_size);
1285                        #(let mut #builders = #builder_types::with_type(self.chunk_size, #return_types);)*
1286                        let mut error_builder = Utf8ArrayBuilder::new(self.chunk_size);
1287
1288                        for i in 0..input.capacity() {
1289                            if unsafe { !input.visibility().is_set_unchecked(i) } {
1290                                continue;
1291                            }
1292                            #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
1293                            for output in #iter {
1294                                index_builder.append(Some(i as i32));
1295                                #append_output
1296
1297                                if index_builder.len() == self.chunk_size {
1298                                    let len = index_builder.len();
1299                                    let index_array = std::mem::replace(&mut index_builder, I32ArrayBuilder::new(self.chunk_size)).finish().into_ref();
1300                                    let value_arrays = [#(std::mem::replace(&mut #builders, #builder_types::with_type(self.chunk_size, #return_types)).finish().into_ref()),*];
1301                                    #build_value_array
1302                                    let error_array = std::mem::replace(&mut error_builder, Utf8ArrayBuilder::new(self.chunk_size)).finish().into_ref();
1303                                    if error_array.null_bitmap().any() {
1304                                        yield DataChunk::new(vec![index_array, value_array, error_array], self.chunk_size);
1305                                    } else {
1306                                        yield DataChunk::new(vec![index_array, value_array], self.chunk_size);
1307                                    }
1308                                }
1309                            }
1310                        }
1311
1312                        if index_builder.len() > 0 {
1313                            let len = index_builder.len();
1314                            let index_array = index_builder.finish().into_ref();
1315                            let value_arrays = [#(#builders.finish().into_ref()),*];
1316                            #build_value_array
1317                            let error_array = error_builder.finish().into_ref();
1318                            if error_array.null_bitmap().any() {
1319                                yield DataChunk::new(vec![index_array, value_array, error_array], len);
1320                            } else {
1321                                yield DataChunk::new(vec![index_array, value_array], len);
1322                            }
1323                        }
1324                    }
1325                }
1326
1327                Ok(Box::new(#struct_name {
1328                    context,
1329                    chunk_size,
1330                    #(#child,)*
1331                    prebuilt_arg: #prebuilt_arg_value,
1332                }))
1333            }
1334        })
1335    }
1336}
1337
1338fn sig_data_type(ty: &str) -> TokenStream2 {
1339    match ty {
1340        "any" => quote! { SigDataType::Any },
1341        "anyarray" => quote! { SigDataType::AnyArray },
1342        "anymap" => quote! { SigDataType::AnyMap },
1343        "vector" => quote! { SigDataType::Vector },
1344        "struct" => quote! { SigDataType::AnyStruct },
1345        _ if ty.starts_with("struct") && ty.contains("any") => quote! { SigDataType::AnyStruct },
1346        _ => {
1347            let datatype = data_type(ty);
1348            quote! { SigDataType::Exact(#datatype) }
1349        }
1350    }
1351}
1352
1353fn data_type(ty: &str) -> TokenStream2 {
1354    if let Some(ty) = ty.strip_suffix("[]") {
1355        let inner_type = data_type(ty);
1356        return quote! { DataType::list(#inner_type) };
1357    }
1358    if ty.starts_with("struct<") {
1359        return quote! { DataType::Struct(#ty.parse().expect("invalid struct type")) };
1360    }
1361    let variant = format_ident!("{}", types::data_type(ty));
1362    // TODO: enable the check
1363    // assert!(
1364    //     !matches!(ty, "any" | "anyarray" | "anymap" | "struct"),
1365    //     "{ty}, {variant}"
1366    // );
1367
1368    quote! { DataType::#variant }
1369}
1370
/// Splits a return type into the types of its output columns.
///
/// A `struct<name type, ...>` yields the type of every field; any other type
/// yields itself as the single output.
///
/// ```ignore
/// output_types("int4") -> ["int4"]
/// output_types("struct<key varchar, value jsonb>") -> ["varchar", "jsonb"]
/// ```
///
/// # Panics
///
/// Panics if a struct field does not consist of at least a name and a type.
fn output_types(ty: &str) -> Vec<&str> {
    let fields = ty
        .strip_prefix("struct<")
        .and_then(|rest| rest.strip_suffix('>'));
    match fields {
        // Each field reads `name type`; keep the second word.
        Some(fields) => fields
            .split(',')
            .map(|field| field.split_whitespace().nth(1).unwrap())
            .collect(),
        None => vec![ty],
    }
}