risingwave_expr_macro/
gen.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Generate code for the functions.
16
17use itertools::Itertools;
18use proc_macro2::{Ident, Span};
19use quote::{format_ident, quote};
20
21use super::*;
22
23impl FunctionAttr {
24    /// Expands the wildcard in function arguments or return type.
25    pub fn expand(&self) -> Vec<Self> {
26        // handle variadic argument
27        if self
28            .args
29            .last()
30            .is_some_and(|arg| arg.starts_with("variadic"))
31        {
32            // expand:  foo(a, b, variadic anyarray)
33            // to:      foo(a, b, ...)
34            //        + foo_variadic(a, b, anyarray)
35            let mut attrs = Vec::new();
36            attrs.extend(
37                FunctionAttr {
38                    args: {
39                        let mut args = self.args.clone();
40                        *args.last_mut().unwrap() = "...".to_owned();
41                        args
42                    },
43                    ..self.clone()
44                }
45                .expand(),
46            );
47            attrs.extend(
48                FunctionAttr {
49                    name: format!("{}_variadic", self.name),
50                    args: {
51                        let mut args = self.args.clone();
52                        let last = args.last_mut().unwrap();
53                        *last = last.strip_prefix("variadic ").unwrap().into();
54                        args
55                    },
56                    ..self.clone()
57                }
58                .expand(),
59            );
60            return attrs;
61        }
62        let args = self.args.iter().map(|ty| types::expand_type_wildcard(ty));
63        let ret = types::expand_type_wildcard(&self.ret);
64        let mut attrs = Vec::new();
65        for (args, mut ret) in args.multi_cartesian_product().cartesian_product(ret) {
66            if ret == "auto" {
67                ret = types::min_compatible_type(&args);
68            }
69            let attr = FunctionAttr {
70                args: args.iter().map(|s| s.to_string()).collect(),
71                ret: ret.to_owned(),
72                ..self.clone()
73            };
74            attrs.push(attr);
75        }
76        attrs
77    }
78
79    /// Generate the type infer function: `fn(&[DataType]) -> Result<DataType>`
80    fn generate_type_infer_fn(&self) -> Result<TokenStream2> {
81        if let Some(func) = &self.type_infer {
82            if func == "unreachable" {
83                return Ok(
84                    quote! { |_| unreachable!("type inference for this function should be specially handled in frontend, and should not call sig.type_infer") },
85                );
86            }
87            // use the user defined type inference function
88            return Ok(func.parse().unwrap());
89        } else if self.ret == "any" {
90            // TODO: if there are multiple "any", they should be the same type
91            if let Some(i) = self.args.iter().position(|t| t == "any") {
92                // infer as the type of "any" argument
93                return Ok(quote! { |args| Ok(args[#i].clone()) });
94            }
95            if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
96                // infer as the element type of "anyarray" argument
97                return Ok(quote! { |args| Ok(args[#i].as_list_elem().clone()) });
98            }
99        } else if self.ret == "anyarray" {
100            if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
101                // infer as the type of "anyarray" argument
102                return Ok(quote! { |args| Ok(args[#i].clone()) });
103            }
104            if let Some(i) = self.args.iter().position(|t| t == "any") {
105                // infer as the array type of "any" argument
106                return Ok(quote! { |args| Ok(DataType::list(args[#i].clone())) });
107            }
108        } else if self.ret == "struct" {
109            if let Some(i) = self.args.iter().position(|t| t == "struct") {
110                // infer as the type of "struct" argument
111                return Ok(quote! { |args| Ok(args[#i].clone()) });
112            }
113        } else if self.ret == "anymap" {
114            if let Some(i) = self.args.iter().position(|t| t == "anymap") {
115                // infer as the type of "anymap" argument
116                return Ok(quote! { |args| Ok(args[#i].clone()) });
117            }
118        } else {
119            // the return type is fixed
120            let ty = data_type(&self.ret);
121            return Ok(quote! { |_| Ok(#ty) });
122        }
123        Err(Error::new(
124            Span::call_site(),
125            "type inference function cannot be automatically derived. You should provide: `type_infer = \"|args| Ok(...)\"`",
126        ))
127    }
128
129    /// Generate a descriptor (`FuncSign`) of the scalar or table function.
130    ///
131    /// The types of arguments and return value should not contain wildcard.
132    ///
133    /// # Arguments
134    /// `build_fn`: whether the user provided a function is a build function.
135    /// (from the `#[build_function]` macro)
136    pub fn generate_function_descriptor(
137        &self,
138        user_fn: &UserFunctionAttr,
139        build_fn: bool,
140    ) -> Result<TokenStream2> {
141        if self.is_table_function {
142            return self.generate_table_function_descriptor(user_fn, build_fn);
143        }
144        let name = self.name.clone();
145        let variadic = matches!(self.args.last(), Some(t) if t == "...");
146        let args = match variadic {
147            true => &self.args[..self.args.len() - 1],
148            false => &self.args[..],
149        }
150        .iter()
151        .map(|ty| sig_data_type(ty))
152        .collect_vec();
153        let ret = sig_data_type(&self.ret);
154
155        let pb_type = format_ident!("{}", utils::to_camel_case(&name));
156        let ctor_name = format_ident!("{}", self.ident_name());
157        let build_fn = if build_fn {
158            let name = format_ident!("{}", user_fn.name);
159            quote! { #name }
160        } else if self.rewritten {
161            quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) }
162        } else {
163            // This is the core logic for `#[function]`
164            self.generate_build_scalar_function(user_fn, true)?
165        };
166        let type_infer_fn = self.generate_type_infer_fn()?;
167        let deprecated = self.deprecated;
168        let maybe_allow_deprecated = if deprecated {
169            quote! { #[allow(deprecated)] }
170        } else {
171            quote! {}
172        };
173
174        Ok(quote! {
175            #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
176            fn #ctor_name() -> risingwave_expr::sig::FuncSign {
177                use risingwave_common::types::{DataType, DataTypeName};
178                use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
179
180                #maybe_allow_deprecated
181                FuncSign {
182                    name: risingwave_pb::expr::expr_node::Type::#pb_type.into(),
183                    inputs_type: vec![#(#args),*],
184                    variadic: #variadic,
185                    ret_type: #ret,
186                    build: FuncBuilder::Scalar(#build_fn),
187                    type_infer: #type_infer_fn,
188                    deprecated: #deprecated,
189                }
190            }
191        })
192    }
193
194    /// Generate a build function for the scalar function.
195    ///
196    /// If `optimize_const` is true, the function will be optimized for constant arguments,
197    /// and fallback to the general version if any argument is not constant.
198    fn generate_build_scalar_function(
199        &self,
200        user_fn: &UserFunctionAttr,
201        optimize_const: bool,
202    ) -> Result<TokenStream2> {
203        let variadic = matches!(self.args.last(), Some(t) if t == "...");
204        let num_args = self.args.len() - if variadic { 1 } else { 0 };
205        let fn_name = format_ident!("{}", user_fn.name);
206        let struct_name = match optimize_const {
207            true => format_ident!("{}OptimizeConst", utils::to_camel_case(&self.ident_name())),
208            false => format_ident!("{}", utils::to_camel_case(&self.ident_name())),
209        };
210
211        // we divide all arguments into two groups: prebuilt and non-prebuilt.
212        // prebuilt arguments are collected from the "prebuild" field.
213        // let's say we have a function with 3 arguments: [0, 1, 2]
214        // and the prebuild field contains "$1".
215        // then we have:
216        //     prebuilt_indices = [1]
217        //     non_prebuilt_indices = [0, 2]
218        //
219        // if the const argument optimization is enabled, prebuilt arguments are
220        // evaluated at build time, thus the children only contain non-prebuilt arguments:
221        //     children_indices = [0, 2]
222        // otherwise, the children contain all arguments:
223        //     children_indices = [0, 1, 2]
224
225        let prebuilt_indices = match &self.prebuild {
226            Some(s) => (0..num_args)
227                .filter(|i| s.contains(&format!("${i}")))
228                .collect_vec(),
229            None => vec![],
230        };
231        let non_prebuilt_indices = match &self.prebuild {
232            Some(s) => (0..num_args)
233                .filter(|i| !s.contains(&format!("${i}")))
234                .collect_vec(),
235            _ => (0..num_args).collect_vec(),
236        };
237        let children_indices = match optimize_const {
238            #[allow(clippy::redundant_clone)] // false-positive
239            true => non_prebuilt_indices.clone(),
240            false => (0..num_args).collect_vec(),
241        };
242
243        /// Return a list of identifiers with the given prefix and indices.
244        fn idents(prefix: &str, indices: &[usize]) -> Vec<Ident> {
245            indices
246                .iter()
247                .map(|i| format_ident!("{prefix}{i}"))
248                .collect()
249        }
250        let inputs = idents("i", &children_indices);
251        let prebuilt_inputs = idents("i", &prebuilt_indices);
252        let non_prebuilt_inputs = idents("i", &non_prebuilt_indices);
253        let array_refs = idents("array", &children_indices);
254        let arrays = idents("a", &children_indices);
255        let datums = idents("v", &children_indices);
256        let arg_arrays = children_indices
257            .iter()
258            .map(|i| format_ident!("{}", types::array_type(&self.args[*i])));
259        let arg_types = children_indices.iter().map(|i| {
260            types::ref_type(&self.args[*i])
261                .parse::<TokenStream2>()
262                .unwrap()
263        });
264        let annotation: TokenStream2 = match user_fn.core_return_type.as_str() {
265            // add type annotation for functions that return generic types
266            "T" | "T1" | "T2" | "T3" => format!(": Option<{}>", types::owned_type(&self.ret))
267                .parse()
268                .unwrap(),
269            _ => quote! {},
270        };
271        let ret_array_type = format_ident!("{}", types::array_type(&self.ret));
272        let builder_type = format_ident!("{}Builder", types::array_type(&self.ret));
273        let prebuilt_arg_type = match &self.prebuild {
274            Some(s) if optimize_const => s.split("::").next().unwrap().parse().unwrap(),
275            _ => quote! { () },
276        };
277        let prebuilt_arg_value = match &self.prebuild {
278            // example:
279            // prebuild = "RegexContext::new($1)"
280            // return = "RegexContext::new(i1)"
281            Some(s) => s
282                .replace('$', "i")
283                .parse()
284                .expect("invalid prebuild syntax"),
285            None => quote! { () },
286        };
287        let prebuild_const = if self.prebuild.is_some() && optimize_const {
288            let build_general = self.generate_build_scalar_function(user_fn, false)?;
289            quote! {{
290                let build_general = #build_general;
291                #(
292                    // try to evaluate constant for prebuilt arguments
293                    let #prebuilt_inputs = match children[#prebuilt_indices].eval_const() {
294                        Ok(s) => s,
295                        // prebuilt argument is not constant, fallback to general
296                        Err(_) => return build_general(return_type, children),
297                    };
298                    // get reference to the constant value
299                    let #prebuilt_inputs = match &#prebuilt_inputs {
300                        Some(s) => s.as_scalar_ref_impl().try_into()?,
301                        // the function should always return null if any const argument is null
302                        None => return Ok(Box::new(risingwave_expr::expr::LiteralExpression::new(
303                            return_type,
304                            None,
305                        ))),
306                    };
307                )*
308                #prebuilt_arg_value
309            }}
310        } else {
311            quote! { () }
312        };
313
314        // ensure the number of children matches the number of arguments
315        let check_children = match variadic {
316            true => quote! { risingwave_expr::ensure!(children.len() >= #num_args); },
317            false => quote! { risingwave_expr::ensure!(children.len() == #num_args); },
318        };
319
320        // evaluate variadic arguments in `eval`
321        let eval_variadic = variadic.then(|| {
322            quote! {
323                let mut columns = Vec::with_capacity(self.children.len() - #num_args);
324                for child in &self.children[#num_args..] {
325                    columns.push(child.eval(input).await?);
326                }
327                let variadic_input = DataChunk::new(columns, input.visibility().clone());
328            }
329        });
330        // evaluate variadic arguments in `eval_row`
331        let eval_row_variadic = variadic.then(|| {
332            quote! {
333                let mut row = Vec::with_capacity(self.children.len() - #num_args);
334                for child in &self.children[#num_args..] {
335                    row.push(child.eval_row(input).await?);
336                }
337                let variadic_row = OwnedRow::new(row);
338            }
339        });
340
341        let generic = (self.ret == "boolean" && user_fn.generic == 3).then(|| {
342            // XXX: for generic compare functions, we need to specify the compatible type
343            let compatible_type = types::ref_type(types::min_compatible_type(&self.args))
344                .parse::<TokenStream2>()
345                .unwrap();
346            quote! { ::<_, _, #compatible_type> }
347        });
348        let prebuilt_arg = match (&self.prebuild, optimize_const) {
349            // use the prebuilt argument
350            (Some(_), true) => quote! { &self.prebuilt_arg, },
351            // build the argument on site
352            (Some(_), false) => quote! { &#prebuilt_arg_value, },
353            // no prebuilt argument
354            (None, _) => quote! {},
355        };
356        let variadic_args = variadic.then(|| quote! { &variadic_row, });
357        let context = user_fn.context.then(|| quote! { &self.context, });
358        let writer = user_fn
359            .writer_type_kind
360            .is_some()
361            .then(|| quote! { &mut writer, });
362        let await_ = user_fn.async_.then(|| quote! { .await });
363
364        let record_error = {
365            // Uniform arguments into `DatumRef`.
366            #[allow(clippy::disallowed_methods)] // allow zip
367            let inputs_args = inputs
368                .iter()
369                .zip(user_fn.args_option.iter())
370                .map(|(input, opt)| {
371                    if *opt {
372                        quote! { #input.map(|s| ScalarRefImpl::from(s)) }
373                    } else {
374                        quote! { Some(ScalarRefImpl::from(#input)) }
375                    }
376                });
377            let inputs_args = quote! {
378                let args: &[DatumRef<'_>] = &[#(#inputs_args),*];
379                let args = args.iter().copied();
380            };
381            let var_args = variadic.then(|| {
382                quote! {
383                    let args = args.chain(variadic_row.iter());
384                }
385            });
386
387            quote! {
388                #inputs_args
389                #var_args
390                errors.push(ExprError::function(
391                    stringify!(#fn_name),
392                    args,
393                    e,
394                ));
395            }
396        };
397
398        // call the user defined function
399        // inputs: [ Option<impl ScalarRef> ]
400        let mut output = quote! { #fn_name #generic(
401            #(#non_prebuilt_inputs,)*
402            #variadic_args
403            #prebuilt_arg
404            #context
405            #writer
406        ) #await_ };
407        // handle error if the function returns `Result`
408        // wrap a `Some` if the function doesn't return `Option`
409        output = match user_fn.return_type_kind {
410            // XXX: we don't support void type yet. return null::int for now.
411            _ if self.ret == "void" => quote! { { #output; Option::<i32>::None } },
412            ReturnTypeKind::T => quote! { Some(#output) },
413            ReturnTypeKind::Option => output,
414            ReturnTypeKind::Result => quote! {
415                match #output {
416                    Ok(x) => Some(x),
417                    Err(e) => {
418                        #record_error
419                        None
420                    }
421                }
422            },
423            ReturnTypeKind::ResultOption => quote! {
424                match #output {
425                    Ok(x) => x,
426                    Err(e) => {
427                        #record_error
428                        None
429                    }
430                }
431            },
432        };
433        // if user function accepts non-option arguments, we assume the function
434        // returns null on null input, so we need to unwrap the inputs before calling.
435        if self.prebuild.is_some() {
436            output = quote! {
437                match (#(#inputs,)*) {
438                    (#(Some(#inputs),)*) => #output,
439                    _ => None,
440                }
441            };
442        } else {
443            #[allow(clippy::disallowed_methods)] // allow zip
444            let some_inputs = inputs
445                .iter()
446                .zip(user_fn.args_option.iter())
447                .map(|(input, opt)| {
448                    if *opt {
449                        quote! { #input }
450                    } else {
451                        quote! { Some(#input) }
452                    }
453                });
454            output = quote! {
455                match (#(#inputs,)*) {
456                    (#(#some_inputs,)*) => #output,
457                    _ => None,
458                }
459            };
460        };
461        // now the `output` is: Option<impl ScalarRef or Scalar>
462        let append_output = match user_fn.writer_type_kind {
463            Some(WriterTypeKind::FmtWrite)
464            | Some(WriterTypeKind::IoWrite)
465            | Some(WriterTypeKind::ListWrite) => quote! {{
466                let mut writer = builder.writer();
467                if #output.is_some() {
468                    writer.finish();
469                } else {
470                    writer.rollback();
471                    builder.append_null();
472                }
473            }},
474            Some(WriterTypeKind::JsonbbBuilder) => quote! {{
475                let mut writer_wrapper = builder.writer();
476                let mut writer = writer_wrapper.inner();
477                if #output.is_some() {
478                    writer_wrapper.finish();
479                } else {
480                    writer_wrapper.rollback();
481                    builder.append_null();
482                }
483            }},
484            None if user_fn.core_return_type == "impl AsRef < [u8] >" => quote! {
485                builder.append(#output.as_ref().map(|s| s.as_ref()));
486            },
487            None => quote! {
488                let output #annotation = #output;
489                builder.append(output.as_ref().map(|s| s.as_scalar_ref()));
490            },
491        };
492        // the output expression in `eval_row`
493        let row_output = match user_fn.writer_type_kind {
494            Some(WriterTypeKind::FmtWrite) => quote! {{
495                let mut writer = String::new();
496                #output.map(|_| writer.into())
497            }},
498            Some(WriterTypeKind::IoWrite) => quote! {{
499                let mut writer = Vec::new();
500                #output.map(|_| writer.into())
501            }},
502            Some(WriterTypeKind::JsonbbBuilder) => quote! {{
503                let mut writer = jsonbb::Builder::<Vec<u8>>::new();
504                #output.map(|_| JsonbVal::from(writer.finish()).into())
505            }},
506            Some(WriterTypeKind::ListWrite) => quote! {{
507                let mut writer = {
508                    let DataType::List(list_ty) = &self.context.return_type else {
509                        panic!("data type must be DataType::List");
510                    };
511                    list_ty.elem().create_array_builder(1)
512                };
513                #output.map(|_| ListValue::new(writer.finish()).into())
514            }},
515            None if user_fn.core_return_type == "impl AsRef < [u8] >" => quote! {
516                #output.map(|s| s.as_ref().into())
517            },
518            None => quote! {{
519                let output #annotation = #output;
520                output.map(|s| s.into())
521            }},
522        };
523        // the main body in `eval`
524        let eval = if let Some(batch_fn) = &self.batch_fn {
525            assert!(
526                !variadic,
527                "customized batch function is not supported for variadic functions"
528            );
529            // user defined batch function
530            let fn_name = format_ident!("{}", batch_fn);
531            quote! {
532                let c = #fn_name(#(#arrays),*);
533                Arc::new(c.into())
534            }
535        } else if (types::is_primitive(&self.ret) || self.ret == "boolean")
536            && user_fn.is_pure()
537            && !variadic
538            && self.prebuild.is_none()
539        {
540            // SIMD optimization for primitive types
541            match self.args.len() {
542                0 => quote! {
543                    let c = #ret_array_type::from_iter_bitmap(
544                        std::iter::repeat_with(|| #fn_name()).take(input.capacity()),
545                        Bitmap::ones(input.capacity()),
546                    );
547                    Arc::new(c.into())
548                },
549                1 => quote! {
550                    let c = #ret_array_type::from_iter_bitmap(
551                        a0.raw_iter().map(|a| #fn_name(a)),
552                        a0.null_bitmap().clone()
553                    );
554                    Arc::new(c.into())
555                },
556                2 => quote! {
557                    // allow using `zip` for performance
558                    #[allow(clippy::disallowed_methods)]
559                    let c = #ret_array_type::from_iter_bitmap(
560                        a0.raw_iter()
561                            .zip(a1.raw_iter())
562                            .map(|(a, b)| #fn_name #generic(a, b)),
563                        a0.null_bitmap() & a1.null_bitmap(),
564                    );
565                    Arc::new(c.into())
566                },
567                n => todo!("SIMD optimization for {n} arguments"),
568            }
569        } else {
570            // no optimization
571            let let_variadic = variadic.then(|| {
572                quote! {
573                    let variadic_row = variadic_input.row_at_unchecked_vis(i);
574                }
575            });
576            quote! {
577                let mut builder = #builder_type::with_type(input.capacity(), self.context.return_type.clone());
578
579                if input.is_vis_compacted() {
580                    for i in 0..input.capacity() {
581                        #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
582                        #let_variadic
583                        #append_output
584                    }
585                } else {
586                    for i in 0..input.capacity() {
587                        if unsafe { !input.visibility().is_set_unchecked(i) } {
588                            builder.append_null();
589                            continue;
590                        }
591                        #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
592                        #let_variadic
593                        #append_output
594                    }
595                }
596                Arc::new(builder.finish().into())
597            }
598        };
599
600        Ok(quote! {
601            |return_type: DataType, children: Vec<risingwave_expr::expr::BoxedExpression>|
602                -> risingwave_expr::Result<risingwave_expr::expr::BoxedExpression>
603            {
604                use std::sync::Arc;
605                use risingwave_common::array::*;
606                use risingwave_common::types::*;
607                use risingwave_common::bitmap::Bitmap;
608                use risingwave_common::row::OwnedRow;
609                use risingwave_common::util::iter_util::ZipEqFast;
610
611                use risingwave_expr::expr::{Context, BoxedExpression};
612                use risingwave_expr::{ExprError, Result};
613                use risingwave_expr::codegen::*;
614
615                #check_children
616                let prebuilt_arg = #prebuild_const;
617                let context = Context {
618                    return_type,
619                    arg_types: children.iter().map(|c| c.return_type()).collect(),
620                    variadic: #variadic,
621                };
622
623                #[derive(Debug)]
624                struct #struct_name {
625                    context: Context,
626                    children: Vec<BoxedExpression>,
627                    prebuilt_arg: #prebuilt_arg_type,
628                }
629                #[async_trait]
630                impl risingwave_expr::expr::Expression for #struct_name {
631                    fn return_type(&self) -> DataType {
632                        self.context.return_type.clone()
633                    }
634                    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
635                        #(
636                            let #array_refs = self.children[#children_indices].eval(input).await?;
637                            let #arrays: &#arg_arrays = #array_refs.as_ref().into();
638                        )*
639                        #eval_variadic
640                        let mut errors = vec![];
641                        let array = { #eval };
642                        if errors.is_empty() {
643                            Ok(array)
644                        } else {
645                            Err(ExprError::Multiple(array, errors.into()))
646                        }
647                    }
648                    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
649                        #(
650                            let #datums = self.children[#children_indices].eval_row(input).await?;
651                            let #inputs: Option<#arg_types> = #datums.as_ref().map(|s| s.as_scalar_ref_impl().try_into().unwrap());
652                        )*
653                        #eval_row_variadic
654                        let mut errors: Vec<ExprError> = vec![];
655                        let output = #row_output;
656                        if let Some(err) = errors.into_iter().next() {
657                            Err(err.into())
658                        } else {
659                            Ok(output)
660                        }
661                    }
662                }
663
664                Ok(Box::new(#struct_name {
665                    context,
666                    children,
667                    prebuilt_arg,
668                }))
669            }
670        })
671    }
672
673    /// Generate a descriptor of the aggregate function.
674    ///
675    /// The types of arguments and return value should not contain wildcard.
676    /// `user_fn` could be either `fn` or `impl`.
677    /// If `build_fn` is true, `user_fn` must be a `fn` that builds the aggregate function.
678    pub fn generate_aggregate_descriptor(
679        &self,
680        user_fn: &AggregateFnOrImpl,
681        build_fn: bool,
682    ) -> Result<TokenStream2> {
683        let name = self.name.clone();
684
685        let mut args = Vec::with_capacity(self.args.len());
686        for ty in &self.args {
687            args.push(sig_data_type(ty));
688        }
689        let ret = sig_data_type(&self.ret);
690        let state_type = match &self.state {
691            Some(ty) if ty != "ref" => {
692                let ty = data_type(ty);
693                quote! { Some(#ty) }
694            }
695            _ => quote! { None },
696        };
697        let append_only = match build_fn {
698            false => !user_fn.has_retract(),
699            true => self.append_only,
700        };
701
702        let pb_kind = format_ident!("{}", utils::to_camel_case(&name));
703        let ctor_name = match append_only {
704            false => format_ident!("{}", self.ident_name()),
705            true => format_ident!("{}_append_only", self.ident_name()),
706        };
707        let build_fn = if build_fn {
708            let name = format_ident!("{}", user_fn.as_fn().name);
709            quote! { #name }
710        } else if self.rewritten {
711            quote! { |_| Err(ExprError::UnsupportedFunction(#name.into())) }
712        } else {
713            self.generate_agg_build_fn(user_fn)?
714        };
715        let build_retractable = match append_only {
716            true => quote! { None },
717            false => quote! { Some(#build_fn) },
718        };
719        let build_append_only = match append_only {
720            false => quote! { None },
721            true => quote! { Some(#build_fn) },
722        };
723        let retractable_state_type = match append_only {
724            true => quote! { None },
725            false => state_type.clone(),
726        };
727        let append_only_state_type = match append_only {
728            false => quote! { None },
729            true => state_type,
730        };
731        let type_infer_fn = self.generate_type_infer_fn()?;
732        let deprecated = self.deprecated;
733        let maybe_allow_deprecated = if deprecated {
734            quote! { #[allow(deprecated)] }
735        } else {
736            quote! {}
737        };
738
739        Ok(quote! {
740            #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
741            fn #ctor_name() -> risingwave_expr::sig::FuncSign {
742                use risingwave_common::types::{DataType, DataTypeName};
743                use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
744
745                #maybe_allow_deprecated
746                FuncSign {
747                    name: risingwave_pb::expr::agg_call::PbKind::#pb_kind.into(),
748                    inputs_type: vec![#(#args),*],
749                    variadic: false,
750                    ret_type: #ret,
751                    build: FuncBuilder::Aggregate {
752                        retractable: #build_retractable,
753                        append_only: #build_append_only,
754                        retractable_state_type: #retractable_state_type,
755                        append_only_state_type: #append_only_state_type,
756                    },
757                    type_infer: #type_infer_fn,
758                    deprecated: #deprecated,
759                }
760            }
761        })
762    }
763
764    /// Generate build function for aggregate function.
765    fn generate_agg_build_fn(&self, user_fn: &AggregateFnOrImpl) -> Result<TokenStream2> {
766        // If the first argument of the aggregate function is of type `&mut T`,
767        // we assume it is a user defined state type.
768        let custom_state = user_fn.accumulate().first_mut_ref_arg.as_ref();
769        let state_type: TokenStream2 = match (custom_state, &self.state) {
770            (Some(s), _) => s.parse().unwrap(),
771            (_, Some(state)) if state == "ref" => types::ref_type(&self.ret).parse().unwrap(),
772            (_, Some(state)) if state != "ref" => types::owned_type(state).parse().unwrap(),
773            _ => types::owned_type(&self.ret).parse().unwrap(),
774        };
775        let let_arrays = self
776            .args
777            .iter()
778            .enumerate()
779            .map(|(i, arg)| {
780                let array = format_ident!("a{i}");
781                let array_type: TokenStream2 = types::array_type(arg).parse().unwrap();
782                quote! {
783                    let #array: &#array_type = input.column_at(#i).as_ref().into();
784                }
785            })
786            .collect_vec();
787        let let_values = (0..self.args.len())
788            .map(|i| {
789                let v = format_ident!("v{i}");
790                let a = format_ident!("a{i}");
791                quote! { let #v = unsafe { #a.value_at_unchecked(row_id) }; }
792            })
793            .collect_vec();
794        let downcast_state = if custom_state.is_some() {
795            quote! { let mut state: &mut #state_type = state0.downcast_mut(); }
796        } else if let Some(s) = &self.state
797            && s == "ref"
798        {
799            quote! { let mut state: Option<#state_type> = state0.as_datum_mut().as_ref().map(|x| x.as_scalar_ref_impl().try_into().unwrap()); }
800        } else {
801            quote! { let mut state: Option<#state_type> = state0.as_datum_mut().take().map(|s| s.try_into().unwrap()); }
802        };
803        let restore_state = if custom_state.is_some() {
804            quote! {}
805        } else if let Some(s) = &self.state
806            && s == "ref"
807        {
808            quote! { *state0.as_datum_mut() = state.map(|x| x.to_owned_scalar().into()); }
809        } else {
810            quote! { *state0.as_datum_mut() = state.map(|s| s.into()); }
811        };
812        let create_state = if custom_state.is_some() {
813            quote! {
814                fn create_state(&self) -> Result<AggregateState> {
815                    Ok(AggregateState::Any(Box::<#state_type>::default()))
816                }
817            }
818        } else if let Some(state) = &self.init_state {
819            let state: TokenStream2 = state.parse().unwrap();
820            quote! {
821                fn create_state(&self) -> Result<AggregateState> {
822                    Ok(AggregateState::Datum(Some(#state.into())))
823                }
824            }
825        } else {
826            // by default: `AggregateState::Datum(None)`
827            quote! {}
828        };
829        let args = (0..self.args.len()).map(|i| format_ident!("v{i}"));
830        let args = quote! { #(#args,)* };
831        let panic_on_retract = {
832            let msg = format!(
833                "attempt to retract on aggregate function {}, but it is append-only",
834                self.name
835            );
836            quote! { assert_eq!(op, Op::Insert, #msg); }
837        };
838        let mut next_state = match user_fn {
839            AggregateFnOrImpl::Fn(f) => {
840                let context = f.context.then(|| quote! { &self.context, });
841                let fn_name = format_ident!("{}", f.name);
842                match f.retract {
843                    true => {
844                        quote! { #fn_name(state, #args matches!(op, Op::Delete | Op::UpdateDelete) #context) }
845                    }
846                    false => quote! {{
847                        #panic_on_retract
848                        #fn_name(state, #args #context)
849                    }},
850                }
851            }
852            AggregateFnOrImpl::Impl(i) => {
853                let retract = match i.retract {
854                    Some(_) => quote! { self.function.retract(state, #args) },
855                    None => panic_on_retract,
856                };
857                quote! {
858                    if matches!(op, Op::Delete | Op::UpdateDelete) {
859                        #retract
860                    } else {
861                        self.function.accumulate(state, #args)
862                    }
863                }
864            }
865        };
866        next_state = match user_fn.accumulate().return_type_kind {
867            ReturnTypeKind::T => quote! { Some(#next_state) },
868            ReturnTypeKind::Option => next_state,
869            ReturnTypeKind::Result => quote! { Some(#next_state?) },
870            ReturnTypeKind::ResultOption => quote! { #next_state? },
871        };
872        if user_fn.accumulate().args_option.iter().all(|b| !b) {
873            match self.args.len() {
874                0 => {
875                    next_state = quote! {
876                        match state {
877                            Some(state) => #next_state,
878                            None => state,
879                        }
880                    };
881                }
882                1 => {
883                    let first_state = if self.init_state.is_some() {
884                        // for count, the state will never be None
885                        quote! { unreachable!() }
886                    } else if let Some(s) = &self.state
887                        && s == "ref"
888                    {
889                        // for min/max/first/last, the state is the first value
890                        quote! { Some(v0) }
891                    } else if let AggregateFnOrImpl::Impl(impl_) = user_fn
892                        && impl_.create_state.is_some()
893                    {
894                        // use user-defined create_state function
895                        quote! {{
896                            let state = self.function.create_state();
897                            #next_state
898                        }}
899                    } else {
900                        quote! {{
901                            let state = #state_type::default();
902                            #next_state
903                        }}
904                    };
905                    next_state = quote! {
906                        match (state, v0) {
907                            (Some(state), Some(v0)) => #next_state,
908                            (None, Some(v0)) => #first_state,
909                            (state, None) => state,
910                        }
911                    };
912                }
913                _ => todo!("multiple arguments are not supported for non-option function"),
914            }
915        }
916        let update_state = if custom_state.is_some() {
917            quote! { _ = #next_state; }
918        } else {
919            quote! { state = #next_state; }
920        };
921        let get_result = if custom_state.is_some() {
922            quote! { Ok(state.downcast_ref::<#state_type>().into()) }
923        } else if let AggregateFnOrImpl::Impl(impl_) = user_fn
924            && impl_.finalize.is_some()
925        {
926            quote! {
927                let state = match state.as_datum() {
928                    Some(s) => s.as_scalar_ref_impl().try_into().unwrap(),
929                    None => return Ok(None),
930                };
931                Ok(Some(self.function.finalize(state).into()))
932            }
933        } else {
934            quote! { Ok(state.as_datum().clone()) }
935        };
936        let function_field = match user_fn {
937            AggregateFnOrImpl::Fn(_) => quote! {},
938            AggregateFnOrImpl::Impl(i) => {
939                let struct_name = format_ident!("{}", i.struct_name);
940                let generic = self.generic.as_ref().map(|g| {
941                    let g = format_ident!("{g}");
942                    quote! { <#g> }
943                });
944                quote! { function: #struct_name #generic, }
945            }
946        };
947        let function_new = match user_fn {
948            AggregateFnOrImpl::Fn(_) => quote! {},
949            AggregateFnOrImpl::Impl(i) => {
950                let struct_name = format_ident!("{}", i.struct_name);
951                let generic = self.generic.as_ref().map(|g| {
952                    let g = format_ident!("{g}");
953                    quote! { ::<#g> }
954                });
955                quote! { function: #struct_name #generic :: default(), }
956            }
957        };
958
959        Ok(quote! {
960            |agg| {
961                use std::collections::HashSet;
962                use std::ops::Range;
963                use risingwave_common::array::*;
964                use risingwave_common::types::*;
965                use risingwave_common::bail;
966                use risingwave_common::bitmap::Bitmap;
967                use risingwave_common_estimate_size::EstimateSize;
968
969                use risingwave_expr::expr::Context;
970                use risingwave_expr::Result;
971                use risingwave_expr::aggregate::AggregateState;
972                use risingwave_expr::codegen::async_trait;
973
974                let context = Context {
975                    return_type: agg.return_type.clone(),
976                    arg_types: agg.args.arg_types().to_owned(),
977                    variadic: false,
978                };
979
980                struct Agg {
981                    context: Context,
982                    #function_field
983                }
984
985                #[async_trait]
986                impl risingwave_expr::aggregate::AggregateFunction for Agg {
987                    fn return_type(&self) -> DataType {
988                        self.context.return_type.clone()
989                    }
990
991                    #create_state
992
993                    async fn update(&self, state0: &mut AggregateState, input: &StreamChunk) -> Result<()> {
994                        #(#let_arrays)*
995                        #downcast_state
996                        for row_id in input.visibility().iter_ones() {
997                            let op = unsafe { *input.ops().get_unchecked(row_id) };
998                            #(#let_values)*
999                            #update_state
1000                        }
1001                        #restore_state
1002                        Ok(())
1003                    }
1004
1005                    async fn update_range(&self, state0: &mut AggregateState, input: &StreamChunk, range: Range<usize>) -> Result<()> {
1006                        assert!(range.end <= input.capacity());
1007                        #(#let_arrays)*
1008                        #downcast_state
1009                        if input.is_vis_compacted() {
1010                            for row_id in range {
1011                                let op = unsafe { *input.ops().get_unchecked(row_id) };
1012                                #(#let_values)*
1013                                #update_state
1014                            }
1015                        } else {
1016                            for row_id in input.visibility().iter_ones() {
1017                                if row_id < range.start {
1018                                    continue;
1019                                } else if row_id >= range.end {
1020                                    break;
1021                                }
1022                                let op = unsafe { *input.ops().get_unchecked(row_id) };
1023                                #(#let_values)*
1024                                #update_state
1025                            }
1026                        }
1027                        #restore_state
1028                        Ok(())
1029                    }
1030
1031                    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
1032                        #get_result
1033                    }
1034                }
1035
1036                Ok(Box::new(Agg {
1037                    context,
1038                    #function_new
1039                }))
1040            }
1041        })
1042    }
1043
1044    /// Generate a descriptor of the table function.
1045    ///
1046    /// The types of arguments and return value should not contain wildcard.
1047    fn generate_table_function_descriptor(
1048        &self,
1049        user_fn: &UserFunctionAttr,
1050        build_fn: bool,
1051    ) -> Result<TokenStream2> {
1052        let name = self.name.clone();
1053        let mut args = Vec::with_capacity(self.args.len());
1054        for ty in &self.args {
1055            args.push(sig_data_type(ty));
1056        }
1057        let ret = sig_data_type(&self.ret);
1058
1059        let pb_type = format_ident!("{}", utils::to_camel_case(&name));
1060        let ctor_name = format_ident!("{}", self.ident_name());
1061        let build_fn = if build_fn {
1062            let name = format_ident!("{}", user_fn.name);
1063            quote! { #name }
1064        } else if self.rewritten {
1065            quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) }
1066        } else {
1067            self.generate_build_table_function(user_fn)?
1068        };
1069        let type_infer_fn = self.generate_type_infer_fn()?;
1070        let deprecated = self.deprecated;
1071        let maybe_allow_deprecated = if deprecated {
1072            quote! { #[allow(deprecated)] }
1073        } else {
1074            quote! {}
1075        };
1076
1077        Ok(quote! {
1078            #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
1079            fn #ctor_name() -> risingwave_expr::sig::FuncSign {
1080                use risingwave_common::types::{DataType, DataTypeName};
1081                use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
1082
1083                #maybe_allow_deprecated
1084                FuncSign {
1085                    name: risingwave_pb::expr::table_function::Type::#pb_type.into(),
1086                    inputs_type: vec![#(#args),*],
1087                    variadic: false,
1088                    ret_type: #ret,
1089                    build: FuncBuilder::Table(#build_fn),
1090                    type_infer: #type_infer_fn,
1091                    deprecated: #deprecated,
1092                }
1093            }
1094        })
1095    }
1096
1097    fn generate_build_table_function(&self, user_fn: &UserFunctionAttr) -> Result<TokenStream2> {
1098        let num_args = self.args.len();
1099        let return_types = output_types(&self.ret);
1100        let fn_name = format_ident!("{}", user_fn.name);
1101        let struct_name = format_ident!("{}", utils::to_camel_case(&self.ident_name()));
1102        let arg_ids = (0..num_args)
1103            .filter(|i| match &self.prebuild {
1104                Some(s) => !s.contains(&format!("${i}")),
1105                None => true,
1106            })
1107            .collect_vec();
1108        let const_ids = (0..num_args).filter(|i| match &self.prebuild {
1109            Some(s) => s.contains(&format!("${i}")),
1110            None => false,
1111        });
1112        let inputs: Vec<_> = arg_ids.iter().map(|i| format_ident!("i{i}")).collect();
1113        let all_child: Vec<_> = (0..num_args).map(|i| format_ident!("child{i}")).collect();
1114        let const_child: Vec<_> = const_ids.map(|i| format_ident!("child{i}")).collect();
1115        let child: Vec<_> = arg_ids.iter().map(|i| format_ident!("child{i}")).collect();
1116        let array_refs: Vec<_> = arg_ids.iter().map(|i| format_ident!("array{i}")).collect();
1117        let arrays: Vec<_> = arg_ids.iter().map(|i| format_ident!("a{i}")).collect();
1118        let arg_arrays = arg_ids
1119            .iter()
1120            .map(|i| format_ident!("{}", types::array_type(&self.args[*i])));
1121        let outputs = (0..return_types.len())
1122            .map(|i| format_ident!("o{i}"))
1123            .collect_vec();
1124        let builders = (0..return_types.len())
1125            .map(|i| format_ident!("builder{i}"))
1126            .collect_vec();
1127        let builder_types = return_types
1128            .iter()
1129            .map(|ty| format_ident!("{}Builder", types::array_type(ty)))
1130            .collect_vec();
1131        let return_types = if return_types.len() == 1 {
1132            vec![quote! { self.context.return_type.clone() }]
1133        } else {
1134            (0..return_types.len())
1135                .map(|i| quote! { self.context.return_type.as_struct().types().nth(#i).unwrap().clone() })
1136                .collect()
1137        };
1138        #[allow(clippy::disallowed_methods)]
1139        let optioned_outputs = user_fn
1140            .core_return_type
1141            .split(',')
1142            .map(|t| t.contains("Option"))
1143            // example: "(Option<&str>, i32)" => [true, false]
1144            .zip(&outputs)
1145            .map(|(optional, o)| match optional {
1146                false => quote! { Some(#o.as_scalar_ref()) },
1147                true => quote! { #o.map(|o| o.as_scalar_ref()) },
1148            })
1149            .collect_vec();
1150        let build_value_array = if return_types.len() == 1 {
1151            quote! { let [value_array] = value_arrays; }
1152        } else {
1153            quote! {
1154                let value_array = StructArray::new(
1155                    self.context.return_type.as_struct().clone(),
1156                    value_arrays.to_vec(),
1157                    Bitmap::ones(len),
1158                ).into_ref();
1159            }
1160        };
1161        let context = user_fn.context.then(|| quote! { &self.context, });
1162        let prebuilt_arg = match &self.prebuild {
1163            Some(_) => quote! { &self.prebuilt_arg, },
1164            None => quote! {},
1165        };
1166        let prebuilt_arg_type = match &self.prebuild {
1167            Some(s) => s.split("::").next().unwrap().parse().unwrap(),
1168            None => quote! { () },
1169        };
1170        let prebuilt_arg_value = match &self.prebuild {
1171            Some(s) => s
1172                .replace('$', "child")
1173                .parse()
1174                .expect("invalid prebuild syntax"),
1175            None => quote! { () },
1176        };
1177        let iter = quote! { #fn_name(#(#inputs,)* #prebuilt_arg #context) };
1178        let mut iter = match user_fn.return_type_kind {
1179            ReturnTypeKind::T => quote! { #iter },
1180            ReturnTypeKind::Option => quote! { match #iter {
1181                Some(it) => it,
1182                None => continue,
1183            } },
1184            ReturnTypeKind::Result => quote! { match #iter {
1185                Ok(it) => it,
1186                Err(e) => {
1187                    index_builder.append(Some(i as i32));
1188                    #(#builders.append_null();)*
1189                    error_builder.append_display(Some(e.as_report()));
1190                    continue;
1191                }
1192            } },
1193            ReturnTypeKind::ResultOption => quote! { match #iter {
1194                Ok(Some(it)) => it,
1195                Ok(None) => continue,
1196                Err(e) => {
1197                    index_builder.append(Some(i as i32));
1198                    #(#builders.append_null();)*
1199                    error_builder.append_display(Some(e.as_report()));
1200                    continue;
1201                }
1202            } },
1203        };
1204        // if user function accepts non-option arguments, we assume the function
1205        // returns empty on null input, so we need to unwrap the inputs before calling.
1206        #[allow(clippy::disallowed_methods)] // allow zip
1207        let some_inputs = inputs
1208            .iter()
1209            .zip(user_fn.args_option.iter())
1210            .map(|(input, opt)| {
1211                if *opt {
1212                    quote! { #input }
1213                } else {
1214                    quote! { Some(#input) }
1215                }
1216            });
1217        iter = quote! {
1218            match (#(#inputs,)*) {
1219                (#(#some_inputs,)*) => #iter,
1220                _ => continue,
1221            }
1222        };
1223        let iterator_item_type = user_fn.iterator_item_kind.clone().ok_or_else(|| {
1224            Error::new(
1225                user_fn.return_type_span,
1226                "expect `impl Iterator` in return type",
1227            )
1228        })?;
1229        let append_output = match iterator_item_type {
1230            ReturnTypeKind::T => quote! {
1231                let (#(#outputs),*) = output;
1232                #(#builders.append(#optioned_outputs);)* error_builder.append_null();
1233            },
1234            ReturnTypeKind::Option => quote! { match output {
1235                Some((#(#outputs),*)) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1236                None => { #(#builders.append_null();)* error_builder.append_null(); }
1237            } },
1238            ReturnTypeKind::Result => quote! { match output {
1239                Ok((#(#outputs),*)) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1240                Err(e) => { #(#builders.append_null();)* error_builder.append_display(Some(e.as_report())); }
1241            } },
1242            ReturnTypeKind::ResultOption => quote! { match output {
1243                Ok(Some((#(#outputs),*))) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1244                Ok(None) => { #(#builders.append_null();)* error_builder.append_null(); }
1245                Err(e) => { #(#builders.append_null();)* error_builder.append_display(Some(e.as_report())); }
1246            } },
1247        };
1248
1249        Ok(quote! {
1250            |return_type, chunk_size, children| {
1251                use risingwave_common::array::*;
1252                use risingwave_common::types::*;
1253                use risingwave_common::bitmap::Bitmap;
1254                use risingwave_common::util::iter_util::ZipEqFast;
1255                use risingwave_expr::expr::{BoxedExpression, Context};
1256                use risingwave_expr::{Result, ExprError};
1257                use risingwave_expr::codegen::*;
1258
1259                risingwave_expr::ensure!(children.len() == #num_args);
1260
1261                let context = Context {
1262                    return_type: return_type.clone(),
1263                    arg_types: children.iter().map(|c| c.return_type()).collect(),
1264                    variadic: false,
1265                };
1266
1267                let mut iter = children.into_iter();
1268                #(let #all_child = iter.next().unwrap();)*
1269                #(
1270                    let #const_child = #const_child.eval_const()?;
1271                    let #const_child = match &#const_child {
1272                        Some(s) => s.as_scalar_ref_impl().try_into()?,
1273                        // the function should always return empty if any const argument is null
1274                        None => return Ok(risingwave_expr::table_function::empty(return_type)),
1275                    };
1276                )*
1277
1278                #[derive(Debug)]
1279                struct #struct_name {
1280                    context: Context,
1281                    chunk_size: usize,
1282                    #(#child: BoxedExpression,)*
1283                    prebuilt_arg: #prebuilt_arg_type,
1284                }
1285                #[async_trait]
1286                impl risingwave_expr::table_function::TableFunction for #struct_name {
1287                    fn return_type(&self) -> DataType {
1288                        self.context.return_type.clone()
1289                    }
1290                    async fn eval<'a>(&'a self, input: &'a DataChunk) -> BoxStream<'a, Result<DataChunk>> {
1291                        self.eval_inner(input)
1292                    }
1293                }
1294                impl #struct_name {
1295                    #[try_stream(boxed, ok = DataChunk, error = ExprError)]
1296                    async fn eval_inner<'a>(&'a self, input: &'a DataChunk) {
1297                        #(
1298                        let #array_refs = self.#child.eval(input).await?;
1299                        let #arrays: &#arg_arrays = #array_refs.as_ref().into();
1300                        )*
1301
1302                        let mut index_builder = I32ArrayBuilder::new(self.chunk_size);
1303                        #(let mut #builders = #builder_types::with_type(self.chunk_size, #return_types);)*
1304                        let mut error_builder = Utf8ArrayBuilder::new(self.chunk_size);
1305
1306                        for i in 0..input.capacity() {
1307                            if unsafe { !input.visibility().is_set_unchecked(i) } {
1308                                continue;
1309                            }
1310                            #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
1311                            for output in #iter {
1312                                index_builder.append(Some(i as i32));
1313                                #append_output
1314
1315                                if index_builder.len() == self.chunk_size {
1316                                    let len = index_builder.len();
1317                                    let index_array = std::mem::replace(&mut index_builder, I32ArrayBuilder::new(self.chunk_size)).finish().into_ref();
1318                                    let value_arrays = [#(std::mem::replace(&mut #builders, #builder_types::with_type(self.chunk_size, #return_types)).finish().into_ref()),*];
1319                                    #build_value_array
1320                                    let error_array = std::mem::replace(&mut error_builder, Utf8ArrayBuilder::new(self.chunk_size)).finish().into_ref();
1321                                    if error_array.null_bitmap().any() {
1322                                        yield DataChunk::new(vec![index_array, value_array, error_array], self.chunk_size);
1323                                    } else {
1324                                        yield DataChunk::new(vec![index_array, value_array], self.chunk_size);
1325                                    }
1326                                }
1327                            }
1328                        }
1329
1330                        if index_builder.len() > 0 {
1331                            let len = index_builder.len();
1332                            let index_array = index_builder.finish().into_ref();
1333                            let value_arrays = [#(#builders.finish().into_ref()),*];
1334                            #build_value_array
1335                            let error_array = error_builder.finish().into_ref();
1336                            if error_array.null_bitmap().any() {
1337                                yield DataChunk::new(vec![index_array, value_array, error_array], len);
1338                            } else {
1339                                yield DataChunk::new(vec![index_array, value_array], len);
1340                            }
1341                        }
1342                    }
1343                }
1344
1345                Ok(Box::new(#struct_name {
1346                    context,
1347                    chunk_size,
1348                    #(#child,)*
1349                    prebuilt_arg: #prebuilt_arg_value,
1350                }))
1351            }
1352        })
1353    }
1354}
1355
1356fn sig_data_type(ty: &str) -> TokenStream2 {
1357    match ty {
1358        "any" => quote! { SigDataType::Any },
1359        "anyarray" => quote! { SigDataType::AnyArray },
1360        "anymap" => quote! { SigDataType::AnyMap },
1361        "vector" => quote! { SigDataType::Vector },
1362        "struct" => quote! { SigDataType::AnyStruct },
1363        _ if ty.starts_with("struct") && ty.contains("any") => quote! { SigDataType::AnyStruct },
1364        _ => {
1365            let datatype = data_type(ty);
1366            quote! { SigDataType::Exact(#datatype) }
1367        }
1368    }
1369}
1370
1371fn data_type(ty: &str) -> TokenStream2 {
1372    if let Some(ty) = ty.strip_suffix("[]") {
1373        let inner_type = data_type(ty);
1374        return quote! { DataType::list(#inner_type) };
1375    }
1376    if ty.starts_with("struct<") {
1377        return quote! { DataType::Struct(#ty.parse().expect("invalid struct type")) };
1378    }
1379    let variant = format_ident!("{}", types::data_type(ty));
1380    // TODO: enable the check
1381    // assert!(
1382    //     !matches!(ty, "any" | "anyarray" | "anymap" | "struct"),
1383    //     "{ty}, {variant}"
1384    // );
1385
1386    quote! { DataType::#variant }
1387}
1388
/// Extract multiple output types.
///
/// A type of the form `struct<name type, ...>` yields the type of every field;
/// any other type is returned unchanged as the single element.
///
/// ```ignore
/// output_types("int4") -> ["int4"]
/// output_types("struct<key varchar, value jsonb>") -> ["varchar", "jsonb"]
/// ```
///
/// # Panics
///
/// Panics if a struct field is not written as `<name> <type>`. The message
/// names the offending field and the full type so that the resulting macro
/// failure can be traced back to the attribute.
fn output_types(ty: &str) -> Vec<&str> {
    let fields = ty
        .strip_prefix("struct<")
        .and_then(|rest| rest.strip_suffix('>'));
    match fields {
        Some(fields) => fields
            .split(',')
            .map(|field| {
                // The second whitespace-separated token of a field is its type.
                field.split_whitespace().nth(1).unwrap_or_else(|| {
                    panic!(
                        "invalid struct field `{}` in `{}`: expected `<name> <type>`",
                        field.trim(),
                        ty
                    )
                })
            })
            .collect(),
        None => vec![ty],
    }
}