1use itertools::Itertools;
18use proc_macro2::{Ident, Span};
19use quote::{format_ident, quote};
20
21use super::*;
22
23impl FunctionAttr {
24 pub fn expand(&self) -> Vec<Self> {
26 if self
28 .args
29 .last()
30 .is_some_and(|arg| arg.starts_with("variadic"))
31 {
32 let mut attrs = Vec::new();
36 attrs.extend(
37 FunctionAttr {
38 args: {
39 let mut args = self.args.clone();
40 *args.last_mut().unwrap() = "...".to_owned();
41 args
42 },
43 ..self.clone()
44 }
45 .expand(),
46 );
47 attrs.extend(
48 FunctionAttr {
49 name: format!("{}_variadic", self.name),
50 args: {
51 let mut args = self.args.clone();
52 let last = args.last_mut().unwrap();
53 *last = last.strip_prefix("variadic ").unwrap().into();
54 args
55 },
56 ..self.clone()
57 }
58 .expand(),
59 );
60 return attrs;
61 }
62 let args = self.args.iter().map(|ty| types::expand_type_wildcard(ty));
63 let ret = types::expand_type_wildcard(&self.ret);
64 let mut attrs = Vec::new();
65 for (args, mut ret) in args.multi_cartesian_product().cartesian_product(ret) {
66 if ret == "auto" {
67 ret = types::min_compatible_type(&args);
68 }
69 let attr = FunctionAttr {
70 args: args.iter().map(|s| s.to_string()).collect(),
71 ret: ret.to_owned(),
72 ..self.clone()
73 };
74 attrs.push(attr);
75 }
76 attrs
77 }
78
79 fn generate_type_infer_fn(&self) -> Result<TokenStream2> {
81 if let Some(func) = &self.type_infer {
82 if func == "unreachable" {
83 return Ok(
84 quote! { |_| unreachable!("type inference for this function should be specially handled in frontend, and should not call sig.type_infer") },
85 );
86 }
87 return Ok(func.parse().unwrap());
89 } else if self.ret == "any" {
90 if let Some(i) = self.args.iter().position(|t| t == "any") {
92 return Ok(quote! { |args| Ok(args[#i].clone()) });
94 }
95 if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
96 return Ok(quote! { |args| Ok(args[#i].as_list_elem().clone()) });
98 }
99 } else if self.ret == "anyarray" {
100 if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
101 return Ok(quote! { |args| Ok(args[#i].clone()) });
103 }
104 if let Some(i) = self.args.iter().position(|t| t == "any") {
105 return Ok(quote! { |args| Ok(DataType::list(args[#i].clone())) });
107 }
108 } else if self.ret == "struct" {
109 if let Some(i) = self.args.iter().position(|t| t == "struct") {
110 return Ok(quote! { |args| Ok(args[#i].clone()) });
112 }
113 } else if self.ret == "anymap" {
114 if let Some(i) = self.args.iter().position(|t| t == "anymap") {
115 return Ok(quote! { |args| Ok(args[#i].clone()) });
117 }
118 } else {
119 let ty = data_type(&self.ret);
121 return Ok(quote! { |_| Ok(#ty) });
122 }
123 Err(Error::new(
124 Span::call_site(),
125 "type inference function cannot be automatically derived. You should provide: `type_infer = \"|args| Ok(...)\"`",
126 ))
127 }
128
129 pub fn generate_function_descriptor(
137 &self,
138 user_fn: &UserFunctionAttr,
139 build_fn: bool,
140 ) -> Result<TokenStream2> {
141 if self.is_table_function {
142 return self.generate_table_function_descriptor(user_fn, build_fn);
143 }
144 let name = self.name.clone();
145 let variadic = matches!(self.args.last(), Some(t) if t == "...");
146 let args = match variadic {
147 true => &self.args[..self.args.len() - 1],
148 false => &self.args[..],
149 }
150 .iter()
151 .map(|ty| sig_data_type(ty))
152 .collect_vec();
153 let ret = sig_data_type(&self.ret);
154
155 let pb_type = format_ident!("{}", utils::to_camel_case(&name));
156 let ctor_name = format_ident!("{}", self.ident_name());
157 let build_fn = if build_fn {
158 let name = format_ident!("{}", user_fn.name);
159 quote! { #name }
160 } else if self.rewritten {
161 quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) }
162 } else {
163 self.generate_build_scalar_function(user_fn, true)?
165 };
166 let type_infer_fn = self.generate_type_infer_fn()?;
167 let deprecated = self.deprecated;
168
169 Ok(quote! {
170 #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
171 fn #ctor_name() -> risingwave_expr::sig::FuncSign {
172 use risingwave_common::types::{DataType, DataTypeName};
173 use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
174
175 FuncSign {
176 name: risingwave_pb::expr::expr_node::Type::#pb_type.into(),
177 inputs_type: vec![#(#args),*],
178 variadic: #variadic,
179 ret_type: #ret,
180 build: FuncBuilder::Scalar(#build_fn),
181 type_infer: #type_infer_fn,
182 deprecated: #deprecated,
183 }
184 }
185 })
186 }
187
188 fn generate_build_scalar_function(
193 &self,
194 user_fn: &UserFunctionAttr,
195 optimize_const: bool,
196 ) -> Result<TokenStream2> {
197 let variadic = matches!(self.args.last(), Some(t) if t == "...");
198 let num_args = self.args.len() - if variadic { 1 } else { 0 };
199 let fn_name = format_ident!("{}", user_fn.name);
200 let struct_name = match optimize_const {
201 true => format_ident!("{}OptimizeConst", utils::to_camel_case(&self.ident_name())),
202 false => format_ident!("{}", utils::to_camel_case(&self.ident_name())),
203 };
204
205 let prebuilt_indices = match &self.prebuild {
220 Some(s) => (0..num_args)
221 .filter(|i| s.contains(&format!("${i}")))
222 .collect_vec(),
223 None => vec![],
224 };
225 let non_prebuilt_indices = match &self.prebuild {
226 Some(s) => (0..num_args)
227 .filter(|i| !s.contains(&format!("${i}")))
228 .collect_vec(),
229 _ => (0..num_args).collect_vec(),
230 };
231 let children_indices = match optimize_const {
232 #[allow(clippy::redundant_clone)] true => non_prebuilt_indices.clone(),
234 false => (0..num_args).collect_vec(),
235 };
236
237 fn idents(prefix: &str, indices: &[usize]) -> Vec<Ident> {
239 indices
240 .iter()
241 .map(|i| format_ident!("{prefix}{i}"))
242 .collect()
243 }
244 let inputs = idents("i", &children_indices);
245 let prebuilt_inputs = idents("i", &prebuilt_indices);
246 let non_prebuilt_inputs = idents("i", &non_prebuilt_indices);
247 let array_refs = idents("array", &children_indices);
248 let arrays = idents("a", &children_indices);
249 let datums = idents("v", &children_indices);
250 let arg_arrays = children_indices
251 .iter()
252 .map(|i| format_ident!("{}", types::array_type(&self.args[*i])));
253 let arg_types = children_indices.iter().map(|i| {
254 types::ref_type(&self.args[*i])
255 .parse::<TokenStream2>()
256 .unwrap()
257 });
258 let annotation: TokenStream2 = match user_fn.core_return_type.as_str() {
259 "T" | "T1" | "T2" | "T3" => format!(": Option<{}>", types::owned_type(&self.ret))
261 .parse()
262 .unwrap(),
263 _ => quote! {},
264 };
265 let ret_array_type = format_ident!("{}", types::array_type(&self.ret));
266 let builder_type = format_ident!("{}Builder", types::array_type(&self.ret));
267 let prebuilt_arg_type = match &self.prebuild {
268 Some(s) if optimize_const => s.split("::").next().unwrap().parse().unwrap(),
269 _ => quote! { () },
270 };
271 let prebuilt_arg_value = match &self.prebuild {
272 Some(s) => s
276 .replace('$', "i")
277 .parse()
278 .expect("invalid prebuild syntax"),
279 None => quote! { () },
280 };
281 let prebuild_const = if self.prebuild.is_some() && optimize_const {
282 let build_general = self.generate_build_scalar_function(user_fn, false)?;
283 quote! {{
284 let build_general = #build_general;
285 #(
286 let #prebuilt_inputs = match children[#prebuilt_indices].eval_const() {
288 Ok(s) => s,
289 Err(_) => return build_general(return_type, children),
291 };
292 let #prebuilt_inputs = match &#prebuilt_inputs {
294 Some(s) => s.as_scalar_ref_impl().try_into()?,
295 None => return Ok(Box::new(risingwave_expr::expr::LiteralExpression::new(
297 return_type,
298 None,
299 ))),
300 };
301 )*
302 #prebuilt_arg_value
303 }}
304 } else {
305 quote! { () }
306 };
307
308 let check_children = match variadic {
310 true => quote! { risingwave_expr::ensure!(children.len() >= #num_args); },
311 false => quote! { risingwave_expr::ensure!(children.len() == #num_args); },
312 };
313
314 let eval_variadic = variadic.then(|| {
316 quote! {
317 let mut columns = Vec::with_capacity(self.children.len() - #num_args);
318 for child in &self.children[#num_args..] {
319 columns.push(child.eval(input).await?);
320 }
321 let variadic_input = DataChunk::new(columns, input.visibility().clone());
322 }
323 });
324 let eval_row_variadic = variadic.then(|| {
326 quote! {
327 let mut row = Vec::with_capacity(self.children.len() - #num_args);
328 for child in &self.children[#num_args..] {
329 row.push(child.eval_row(input).await?);
330 }
331 let variadic_row = OwnedRow::new(row);
332 }
333 });
334
335 let generic = (self.ret == "boolean" && user_fn.generic == 3).then(|| {
336 let compatible_type = types::ref_type(types::min_compatible_type(&self.args))
338 .parse::<TokenStream2>()
339 .unwrap();
340 quote! { ::<_, _, #compatible_type> }
341 });
342 let prebuilt_arg = match (&self.prebuild, optimize_const) {
343 (Some(_), true) => quote! { &self.prebuilt_arg, },
345 (Some(_), false) => quote! { &#prebuilt_arg_value, },
347 (None, _) => quote! {},
349 };
350 let variadic_args = variadic.then(|| quote! { &variadic_row, });
351 let context = user_fn.context.then(|| quote! { &self.context, });
352 let writer = user_fn
353 .writer_type_kind
354 .is_some()
355 .then(|| quote! { &mut writer, });
356 let await_ = user_fn.async_.then(|| quote! { .await });
357
358 let record_error = {
359 #[allow(clippy::disallowed_methods)] let inputs_args = inputs
362 .iter()
363 .zip(user_fn.args_option.iter())
364 .map(|(input, opt)| {
365 if *opt {
366 quote! { #input.map(|s| ScalarRefImpl::from(s)) }
367 } else {
368 quote! { Some(ScalarRefImpl::from(#input)) }
369 }
370 });
371 let inputs_args = quote! {
372 let args: &[DatumRef<'_>] = &[#(#inputs_args),*];
373 let args = args.iter().copied();
374 };
375 let var_args = variadic.then(|| {
376 quote! {
377 let args = args.chain(variadic_row.iter());
378 }
379 });
380
381 quote! {
382 #inputs_args
383 #var_args
384 errors.push(ExprError::function(
385 stringify!(#fn_name),
386 args,
387 e,
388 ));
389 }
390 };
391
392 let mut output = quote! { #fn_name #generic(
395 #(#non_prebuilt_inputs,)*
396 #variadic_args
397 #prebuilt_arg
398 #context
399 #writer
400 ) #await_ };
401 output = match user_fn.return_type_kind {
404 _ if self.ret == "void" => quote! { { #output; Option::<i32>::None } },
406 ReturnTypeKind::T => quote! { Some(#output) },
407 ReturnTypeKind::Option => output,
408 ReturnTypeKind::Result => quote! {
409 match #output {
410 Ok(x) => Some(x),
411 Err(e) => {
412 #record_error
413 None
414 }
415 }
416 },
417 ReturnTypeKind::ResultOption => quote! {
418 match #output {
419 Ok(x) => x,
420 Err(e) => {
421 #record_error
422 None
423 }
424 }
425 },
426 };
427 if self.prebuild.is_some() {
430 output = quote! {
431 match (#(#inputs,)*) {
432 (#(Some(#inputs),)*) => #output,
433 _ => None,
434 }
435 };
436 } else {
437 #[allow(clippy::disallowed_methods)] let some_inputs = inputs
439 .iter()
440 .zip(user_fn.args_option.iter())
441 .map(|(input, opt)| {
442 if *opt {
443 quote! { #input }
444 } else {
445 quote! { Some(#input) }
446 }
447 });
448 output = quote! {
449 match (#(#inputs,)*) {
450 (#(#some_inputs,)*) => #output,
451 _ => None,
452 }
453 };
454 };
455 let append_output = match user_fn.writer_type_kind {
457 Some(WriterTypeKind::FmtWrite)
458 | Some(WriterTypeKind::IoWrite)
459 | Some(WriterTypeKind::ListWrite) => quote! {{
460 let mut writer = builder.writer();
461 if #output.is_some() {
462 writer.finish();
463 } else {
464 writer.rollback();
465 builder.append_null();
466 }
467 }},
468 Some(WriterTypeKind::JsonbbBuilder) => quote! {{
469 let mut writer_wrapper = builder.writer();
470 let mut writer = writer_wrapper.inner();
471 if #output.is_some() {
472 writer_wrapper.finish();
473 } else {
474 writer_wrapper.rollback();
475 builder.append_null();
476 }
477 }},
478 None if user_fn.core_return_type == "impl AsRef < [u8] >" => quote! {
479 builder.append(#output.as_ref().map(|s| s.as_ref()));
480 },
481 None => quote! {
482 let output #annotation = #output;
483 builder.append(output.as_ref().map(|s| s.as_scalar_ref()));
484 },
485 };
486 let row_output = match user_fn.writer_type_kind {
488 Some(WriterTypeKind::FmtWrite) => quote! {{
489 let mut writer = String::new();
490 #output.map(|_| writer.into())
491 }},
492 Some(WriterTypeKind::IoWrite) => quote! {{
493 let mut writer = Vec::new();
494 #output.map(|_| writer.into())
495 }},
496 Some(WriterTypeKind::JsonbbBuilder) => quote! {{
497 let mut writer = jsonbb::Builder::<Vec<u8>>::new();
498 #output.map(|_| JsonbVal::from(writer.finish()).into())
499 }},
500 Some(WriterTypeKind::ListWrite) => quote! {{
501 let mut writer = {
502 let DataType::List(list_ty) = &self.context.return_type else {
503 panic!("data type must be DataType::List");
504 };
505 list_ty.elem().create_array_builder(1)
506 };
507 #output.map(|_| ListValue::new(writer.finish()).into())
508 }},
509 None if user_fn.core_return_type == "impl AsRef < [u8] >" => quote! {
510 #output.map(|s| s.as_ref().into())
511 },
512 None => quote! {{
513 let output #annotation = #output;
514 output.map(|s| s.into())
515 }},
516 };
517 let eval = if let Some(batch_fn) = &self.batch_fn {
519 assert!(
520 !variadic,
521 "customized batch function is not supported for variadic functions"
522 );
523 let fn_name = format_ident!("{}", batch_fn);
525 quote! {
526 let c = #fn_name(#(#arrays),*);
527 Arc::new(c.into())
528 }
529 } else if (types::is_primitive(&self.ret) || self.ret == "boolean")
530 && user_fn.is_pure()
531 && !variadic
532 && self.prebuild.is_none()
533 {
534 match self.args.len() {
536 0 => quote! {
537 let c = #ret_array_type::from_iter_bitmap(
538 std::iter::repeat_with(|| #fn_name()).take(input.capacity()),
539 Bitmap::ones(input.capacity()),
540 );
541 Arc::new(c.into())
542 },
543 1 => quote! {
544 let c = #ret_array_type::from_iter_bitmap(
545 a0.raw_iter().map(|a| #fn_name(a)),
546 a0.null_bitmap().clone()
547 );
548 Arc::new(c.into())
549 },
550 2 => quote! {
551 #[allow(clippy::disallowed_methods)]
553 let c = #ret_array_type::from_iter_bitmap(
554 a0.raw_iter()
555 .zip(a1.raw_iter())
556 .map(|(a, b)| #fn_name #generic(a, b)),
557 a0.null_bitmap() & a1.null_bitmap(),
558 );
559 Arc::new(c.into())
560 },
561 n => todo!("SIMD optimization for {n} arguments"),
562 }
563 } else {
564 let let_variadic = variadic.then(|| {
566 quote! {
567 let variadic_row = variadic_input.row_at_unchecked_vis(i);
568 }
569 });
570 quote! {
571 let mut builder = #builder_type::with_type(input.capacity(), self.context.return_type.clone());
572
573 if input.is_vis_compacted() {
574 for i in 0..input.capacity() {
575 #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
576 #let_variadic
577 #append_output
578 }
579 } else {
580 for i in 0..input.capacity() {
581 if unsafe { !input.visibility().is_set_unchecked(i) } {
582 builder.append_null();
583 continue;
584 }
585 #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
586 #let_variadic
587 #append_output
588 }
589 }
590 Arc::new(builder.finish().into())
591 }
592 };
593
594 Ok(quote! {
595 |return_type: DataType, children: Vec<risingwave_expr::expr::BoxedExpression>|
596 -> risingwave_expr::Result<risingwave_expr::expr::BoxedExpression>
597 {
598 use std::sync::Arc;
599 use risingwave_common::array::*;
600 use risingwave_common::types::*;
601 use risingwave_common::bitmap::Bitmap;
602 use risingwave_common::row::OwnedRow;
603 use risingwave_common::util::iter_util::ZipEqFast;
604
605 use risingwave_expr::expr::{Context, BoxedExpression};
606 use risingwave_expr::{ExprError, Result};
607 use risingwave_expr::codegen::*;
608
609 #check_children
610 let prebuilt_arg = #prebuild_const;
611 let context = Context {
612 return_type,
613 arg_types: children.iter().map(|c| c.return_type()).collect(),
614 variadic: #variadic,
615 };
616
617 #[derive(Debug)]
618 struct #struct_name {
619 context: Context,
620 children: Vec<BoxedExpression>,
621 prebuilt_arg: #prebuilt_arg_type,
622 }
623 #[async_trait]
624 impl risingwave_expr::expr::Expression for #struct_name {
625 fn return_type(&self) -> DataType {
626 self.context.return_type.clone()
627 }
628 async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
629 #(
630 let #array_refs = self.children[#children_indices].eval(input).await?;
631 let #arrays: &#arg_arrays = #array_refs.as_ref().into();
632 )*
633 #eval_variadic
634 let mut errors = vec![];
635 let array = { #eval };
636 if errors.is_empty() {
637 Ok(array)
638 } else {
639 Err(ExprError::Multiple(array, errors.into()))
640 }
641 }
642 async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
643 #(
644 let #datums = self.children[#children_indices].eval_row(input).await?;
645 let #inputs: Option<#arg_types> = #datums.as_ref().map(|s| s.as_scalar_ref_impl().try_into().unwrap());
646 )*
647 #eval_row_variadic
648 let mut errors: Vec<ExprError> = vec![];
649 let output = #row_output;
650 if let Some(err) = errors.into_iter().next() {
651 Err(err.into())
652 } else {
653 Ok(output)
654 }
655 }
656 }
657
658 Ok(Box::new(#struct_name {
659 context,
660 children,
661 prebuilt_arg,
662 }))
663 }
664 })
665 }
666
667 pub fn generate_aggregate_descriptor(
673 &self,
674 user_fn: &AggregateFnOrImpl,
675 build_fn: bool,
676 ) -> Result<TokenStream2> {
677 let name = self.name.clone();
678
679 let mut args = Vec::with_capacity(self.args.len());
680 for ty in &self.args {
681 args.push(sig_data_type(ty));
682 }
683 let ret = sig_data_type(&self.ret);
684 let state_type = match &self.state {
685 Some(ty) if ty != "ref" => {
686 let ty = data_type(ty);
687 quote! { Some(#ty) }
688 }
689 _ => quote! { None },
690 };
691 let append_only = match build_fn {
692 false => !user_fn.has_retract(),
693 true => self.append_only,
694 };
695
696 let pb_kind = format_ident!("{}", utils::to_camel_case(&name));
697 let ctor_name = match append_only {
698 false => format_ident!("{}", self.ident_name()),
699 true => format_ident!("{}_append_only", self.ident_name()),
700 };
701 let build_fn = if build_fn {
702 let name = format_ident!("{}", user_fn.as_fn().name);
703 quote! { #name }
704 } else if self.rewritten {
705 quote! { |_| Err(ExprError::UnsupportedFunction(#name.into())) }
706 } else {
707 self.generate_agg_build_fn(user_fn)?
708 };
709 let build_retractable = match append_only {
710 true => quote! { None },
711 false => quote! { Some(#build_fn) },
712 };
713 let build_append_only = match append_only {
714 false => quote! { None },
715 true => quote! { Some(#build_fn) },
716 };
717 let retractable_state_type = match append_only {
718 true => quote! { None },
719 false => state_type.clone(),
720 };
721 let append_only_state_type = match append_only {
722 false => quote! { None },
723 true => state_type,
724 };
725 let type_infer_fn = self.generate_type_infer_fn()?;
726 let deprecated = self.deprecated;
727
728 Ok(quote! {
729 #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
730 fn #ctor_name() -> risingwave_expr::sig::FuncSign {
731 use risingwave_common::types::{DataType, DataTypeName};
732 use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
733
734 FuncSign {
735 name: risingwave_pb::expr::agg_call::PbKind::#pb_kind.into(),
736 inputs_type: vec![#(#args),*],
737 variadic: false,
738 ret_type: #ret,
739 build: FuncBuilder::Aggregate {
740 retractable: #build_retractable,
741 append_only: #build_append_only,
742 retractable_state_type: #retractable_state_type,
743 append_only_state_type: #append_only_state_type,
744 },
745 type_infer: #type_infer_fn,
746 deprecated: #deprecated,
747 }
748 }
749 })
750 }
751
752 fn generate_agg_build_fn(&self, user_fn: &AggregateFnOrImpl) -> Result<TokenStream2> {
754 let custom_state = user_fn.accumulate().first_mut_ref_arg.as_ref();
757 let state_type: TokenStream2 = match (custom_state, &self.state) {
758 (Some(s), _) => s.parse().unwrap(),
759 (_, Some(state)) if state == "ref" => types::ref_type(&self.ret).parse().unwrap(),
760 (_, Some(state)) if state != "ref" => types::owned_type(state).parse().unwrap(),
761 _ => types::owned_type(&self.ret).parse().unwrap(),
762 };
763 let let_arrays = self
764 .args
765 .iter()
766 .enumerate()
767 .map(|(i, arg)| {
768 let array = format_ident!("a{i}");
769 let array_type: TokenStream2 = types::array_type(arg).parse().unwrap();
770 quote! {
771 let #array: &#array_type = input.column_at(#i).as_ref().into();
772 }
773 })
774 .collect_vec();
775 let let_values = (0..self.args.len())
776 .map(|i| {
777 let v = format_ident!("v{i}");
778 let a = format_ident!("a{i}");
779 quote! { let #v = unsafe { #a.value_at_unchecked(row_id) }; }
780 })
781 .collect_vec();
782 let downcast_state = if custom_state.is_some() {
783 quote! { let mut state: &mut #state_type = state0.downcast_mut(); }
784 } else if let Some(s) = &self.state
785 && s == "ref"
786 {
787 quote! { let mut state: Option<#state_type> = state0.as_datum_mut().as_ref().map(|x| x.as_scalar_ref_impl().try_into().unwrap()); }
788 } else {
789 quote! { let mut state: Option<#state_type> = state0.as_datum_mut().take().map(|s| s.try_into().unwrap()); }
790 };
791 let restore_state = if custom_state.is_some() {
792 quote! {}
793 } else if let Some(s) = &self.state
794 && s == "ref"
795 {
796 quote! { *state0.as_datum_mut() = state.map(|x| x.to_owned_scalar().into()); }
797 } else {
798 quote! { *state0.as_datum_mut() = state.map(|s| s.into()); }
799 };
800 let create_state = if custom_state.is_some() {
801 quote! {
802 fn create_state(&self) -> Result<AggregateState> {
803 Ok(AggregateState::Any(Box::<#state_type>::default()))
804 }
805 }
806 } else if let Some(state) = &self.init_state {
807 let state: TokenStream2 = state.parse().unwrap();
808 quote! {
809 fn create_state(&self) -> Result<AggregateState> {
810 Ok(AggregateState::Datum(Some(#state.into())))
811 }
812 }
813 } else {
814 quote! {}
816 };
817 let args = (0..self.args.len()).map(|i| format_ident!("v{i}"));
818 let args = quote! { #(#args,)* };
819 let panic_on_retract = {
820 let msg = format!(
821 "attempt to retract on aggregate function {}, but it is append-only",
822 self.name
823 );
824 quote! { assert_eq!(op, Op::Insert, #msg); }
825 };
826 let mut next_state = match user_fn {
827 AggregateFnOrImpl::Fn(f) => {
828 let context = f.context.then(|| quote! { &self.context, });
829 let fn_name = format_ident!("{}", f.name);
830 match f.retract {
831 true => {
832 quote! { #fn_name(state, #args matches!(op, Op::Delete | Op::UpdateDelete) #context) }
833 }
834 false => quote! {{
835 #panic_on_retract
836 #fn_name(state, #args #context)
837 }},
838 }
839 }
840 AggregateFnOrImpl::Impl(i) => {
841 let retract = match i.retract {
842 Some(_) => quote! { self.function.retract(state, #args) },
843 None => panic_on_retract,
844 };
845 quote! {
846 if matches!(op, Op::Delete | Op::UpdateDelete) {
847 #retract
848 } else {
849 self.function.accumulate(state, #args)
850 }
851 }
852 }
853 };
854 next_state = match user_fn.accumulate().return_type_kind {
855 ReturnTypeKind::T => quote! { Some(#next_state) },
856 ReturnTypeKind::Option => next_state,
857 ReturnTypeKind::Result => quote! { Some(#next_state?) },
858 ReturnTypeKind::ResultOption => quote! { #next_state? },
859 };
860 if user_fn.accumulate().args_option.iter().all(|b| !b) {
861 match self.args.len() {
862 0 => {
863 next_state = quote! {
864 match state {
865 Some(state) => #next_state,
866 None => state,
867 }
868 };
869 }
870 1 => {
871 let first_state = if self.init_state.is_some() {
872 quote! { unreachable!() }
874 } else if let Some(s) = &self.state
875 && s == "ref"
876 {
877 quote! { Some(v0) }
879 } else if let AggregateFnOrImpl::Impl(impl_) = user_fn
880 && impl_.create_state.is_some()
881 {
882 quote! {{
884 let state = self.function.create_state();
885 #next_state
886 }}
887 } else {
888 quote! {{
889 let state = #state_type::default();
890 #next_state
891 }}
892 };
893 next_state = quote! {
894 match (state, v0) {
895 (Some(state), Some(v0)) => #next_state,
896 (None, Some(v0)) => #first_state,
897 (state, None) => state,
898 }
899 };
900 }
901 _ => todo!("multiple arguments are not supported for non-option function"),
902 }
903 }
904 let update_state = if custom_state.is_some() {
905 quote! { _ = #next_state; }
906 } else {
907 quote! { state = #next_state; }
908 };
909 let get_result = if custom_state.is_some() {
910 quote! { Ok(state.downcast_ref::<#state_type>().into()) }
911 } else if let AggregateFnOrImpl::Impl(impl_) = user_fn
912 && impl_.finalize.is_some()
913 {
914 quote! {
915 let state = match state.as_datum() {
916 Some(s) => s.as_scalar_ref_impl().try_into().unwrap(),
917 None => return Ok(None),
918 };
919 Ok(Some(self.function.finalize(state).into()))
920 }
921 } else {
922 quote! { Ok(state.as_datum().clone()) }
923 };
924 let function_field = match user_fn {
925 AggregateFnOrImpl::Fn(_) => quote! {},
926 AggregateFnOrImpl::Impl(i) => {
927 let struct_name = format_ident!("{}", i.struct_name);
928 let generic = self.generic.as_ref().map(|g| {
929 let g = format_ident!("{g}");
930 quote! { <#g> }
931 });
932 quote! { function: #struct_name #generic, }
933 }
934 };
935 let function_new = match user_fn {
936 AggregateFnOrImpl::Fn(_) => quote! {},
937 AggregateFnOrImpl::Impl(i) => {
938 let struct_name = format_ident!("{}", i.struct_name);
939 let generic = self.generic.as_ref().map(|g| {
940 let g = format_ident!("{g}");
941 quote! { ::<#g> }
942 });
943 quote! { function: #struct_name #generic :: default(), }
944 }
945 };
946
947 Ok(quote! {
948 |agg| {
949 use std::collections::HashSet;
950 use std::ops::Range;
951 use risingwave_common::array::*;
952 use risingwave_common::types::*;
953 use risingwave_common::bail;
954 use risingwave_common::bitmap::Bitmap;
955 use risingwave_common_estimate_size::EstimateSize;
956
957 use risingwave_expr::expr::Context;
958 use risingwave_expr::Result;
959 use risingwave_expr::aggregate::AggregateState;
960 use risingwave_expr::codegen::async_trait;
961
962 let context = Context {
963 return_type: agg.return_type.clone(),
964 arg_types: agg.args.arg_types().to_owned(),
965 variadic: false,
966 };
967
968 struct Agg {
969 context: Context,
970 #function_field
971 }
972
973 #[async_trait]
974 impl risingwave_expr::aggregate::AggregateFunction for Agg {
975 fn return_type(&self) -> DataType {
976 self.context.return_type.clone()
977 }
978
979 #create_state
980
981 async fn update(&self, state0: &mut AggregateState, input: &StreamChunk) -> Result<()> {
982 #(#let_arrays)*
983 #downcast_state
984 for row_id in input.visibility().iter_ones() {
985 let op = unsafe { *input.ops().get_unchecked(row_id) };
986 #(#let_values)*
987 #update_state
988 }
989 #restore_state
990 Ok(())
991 }
992
993 async fn update_range(&self, state0: &mut AggregateState, input: &StreamChunk, range: Range<usize>) -> Result<()> {
994 assert!(range.end <= input.capacity());
995 #(#let_arrays)*
996 #downcast_state
997 if input.is_vis_compacted() {
998 for row_id in range {
999 let op = unsafe { *input.ops().get_unchecked(row_id) };
1000 #(#let_values)*
1001 #update_state
1002 }
1003 } else {
1004 for row_id in input.visibility().iter_ones() {
1005 if row_id < range.start {
1006 continue;
1007 } else if row_id >= range.end {
1008 break;
1009 }
1010 let op = unsafe { *input.ops().get_unchecked(row_id) };
1011 #(#let_values)*
1012 #update_state
1013 }
1014 }
1015 #restore_state
1016 Ok(())
1017 }
1018
1019 async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
1020 #get_result
1021 }
1022 }
1023
1024 Ok(Box::new(Agg {
1025 context,
1026 #function_new
1027 }))
1028 }
1029 })
1030 }
1031
1032 fn generate_table_function_descriptor(
1036 &self,
1037 user_fn: &UserFunctionAttr,
1038 build_fn: bool,
1039 ) -> Result<TokenStream2> {
1040 let name = self.name.clone();
1041 let mut args = Vec::with_capacity(self.args.len());
1042 for ty in &self.args {
1043 args.push(sig_data_type(ty));
1044 }
1045 let ret = sig_data_type(&self.ret);
1046
1047 let pb_type = format_ident!("{}", utils::to_camel_case(&name));
1048 let ctor_name = format_ident!("{}", self.ident_name());
1049 let build_fn = if build_fn {
1050 let name = format_ident!("{}", user_fn.name);
1051 quote! { #name }
1052 } else if self.rewritten {
1053 quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) }
1054 } else {
1055 self.generate_build_table_function(user_fn)?
1056 };
1057 let type_infer_fn = self.generate_type_infer_fn()?;
1058 let deprecated = self.deprecated;
1059
1060 Ok(quote! {
1061 #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
1062 fn #ctor_name() -> risingwave_expr::sig::FuncSign {
1063 use risingwave_common::types::{DataType, DataTypeName};
1064 use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
1065
1066 FuncSign {
1067 name: risingwave_pb::expr::table_function::Type::#pb_type.into(),
1068 inputs_type: vec![#(#args),*],
1069 variadic: false,
1070 ret_type: #ret,
1071 build: FuncBuilder::Table(#build_fn),
1072 type_infer: #type_infer_fn,
1073 deprecated: #deprecated,
1074 }
1075 }
1076 })
1077 }
1078
1079 fn generate_build_table_function(&self, user_fn: &UserFunctionAttr) -> Result<TokenStream2> {
1080 let num_args = self.args.len();
1081 let return_types = output_types(&self.ret);
1082 let fn_name = format_ident!("{}", user_fn.name);
1083 let struct_name = format_ident!("{}", utils::to_camel_case(&self.ident_name()));
1084 let arg_ids = (0..num_args)
1085 .filter(|i| match &self.prebuild {
1086 Some(s) => !s.contains(&format!("${i}")),
1087 None => true,
1088 })
1089 .collect_vec();
1090 let const_ids = (0..num_args).filter(|i| match &self.prebuild {
1091 Some(s) => s.contains(&format!("${i}")),
1092 None => false,
1093 });
1094 let inputs: Vec<_> = arg_ids.iter().map(|i| format_ident!("i{i}")).collect();
1095 let all_child: Vec<_> = (0..num_args).map(|i| format_ident!("child{i}")).collect();
1096 let const_child: Vec<_> = const_ids.map(|i| format_ident!("child{i}")).collect();
1097 let child: Vec<_> = arg_ids.iter().map(|i| format_ident!("child{i}")).collect();
1098 let array_refs: Vec<_> = arg_ids.iter().map(|i| format_ident!("array{i}")).collect();
1099 let arrays: Vec<_> = arg_ids.iter().map(|i| format_ident!("a{i}")).collect();
1100 let arg_arrays = arg_ids
1101 .iter()
1102 .map(|i| format_ident!("{}", types::array_type(&self.args[*i])));
1103 let outputs = (0..return_types.len())
1104 .map(|i| format_ident!("o{i}"))
1105 .collect_vec();
1106 let builders = (0..return_types.len())
1107 .map(|i| format_ident!("builder{i}"))
1108 .collect_vec();
1109 let builder_types = return_types
1110 .iter()
1111 .map(|ty| format_ident!("{}Builder", types::array_type(ty)))
1112 .collect_vec();
1113 let return_types = if return_types.len() == 1 {
1114 vec![quote! { self.context.return_type.clone() }]
1115 } else {
1116 (0..return_types.len())
1117 .map(|i| quote! { self.context.return_type.as_struct().types().nth(#i).unwrap().clone() })
1118 .collect()
1119 };
1120 #[allow(clippy::disallowed_methods)]
1121 let optioned_outputs = user_fn
1122 .core_return_type
1123 .split(',')
1124 .map(|t| t.contains("Option"))
1125 .zip(&outputs)
1127 .map(|(optional, o)| match optional {
1128 false => quote! { Some(#o.as_scalar_ref()) },
1129 true => quote! { #o.map(|o| o.as_scalar_ref()) },
1130 })
1131 .collect_vec();
1132 let build_value_array = if return_types.len() == 1 {
1133 quote! { let [value_array] = value_arrays; }
1134 } else {
1135 quote! {
1136 let value_array = StructArray::new(
1137 self.context.return_type.as_struct().clone(),
1138 value_arrays.to_vec(),
1139 Bitmap::ones(len),
1140 ).into_ref();
1141 }
1142 };
1143 let context = user_fn.context.then(|| quote! { &self.context, });
1144 let prebuilt_arg = match &self.prebuild {
1145 Some(_) => quote! { &self.prebuilt_arg, },
1146 None => quote! {},
1147 };
1148 let prebuilt_arg_type = match &self.prebuild {
1149 Some(s) => s.split("::").next().unwrap().parse().unwrap(),
1150 None => quote! { () },
1151 };
1152 let prebuilt_arg_value = match &self.prebuild {
1153 Some(s) => s
1154 .replace('$', "child")
1155 .parse()
1156 .expect("invalid prebuild syntax"),
1157 None => quote! { () },
1158 };
1159 let iter = quote! { #fn_name(#(#inputs,)* #prebuilt_arg #context) };
1160 let mut iter = match user_fn.return_type_kind {
1161 ReturnTypeKind::T => quote! { #iter },
1162 ReturnTypeKind::Option => quote! { match #iter {
1163 Some(it) => it,
1164 None => continue,
1165 } },
1166 ReturnTypeKind::Result => quote! { match #iter {
1167 Ok(it) => it,
1168 Err(e) => {
1169 index_builder.append(Some(i as i32));
1170 #(#builders.append_null();)*
1171 error_builder.append_display(Some(e.as_report()));
1172 continue;
1173 }
1174 } },
1175 ReturnTypeKind::ResultOption => quote! { match #iter {
1176 Ok(Some(it)) => it,
1177 Ok(None) => continue,
1178 Err(e) => {
1179 index_builder.append(Some(i as i32));
1180 #(#builders.append_null();)*
1181 error_builder.append_display(Some(e.as_report()));
1182 continue;
1183 }
1184 } },
1185 };
1186 #[allow(clippy::disallowed_methods)] let some_inputs = inputs
1190 .iter()
1191 .zip(user_fn.args_option.iter())
1192 .map(|(input, opt)| {
1193 if *opt {
1194 quote! { #input }
1195 } else {
1196 quote! { Some(#input) }
1197 }
1198 });
1199 iter = quote! {
1200 match (#(#inputs,)*) {
1201 (#(#some_inputs,)*) => #iter,
1202 _ => continue,
1203 }
1204 };
1205 let iterator_item_type = user_fn.iterator_item_kind.clone().ok_or_else(|| {
1206 Error::new(
1207 user_fn.return_type_span,
1208 "expect `impl Iterator` in return type",
1209 )
1210 })?;
1211 let append_output = match iterator_item_type {
1212 ReturnTypeKind::T => quote! {
1213 let (#(#outputs),*) = output;
1214 #(#builders.append(#optioned_outputs);)* error_builder.append_null();
1215 },
1216 ReturnTypeKind::Option => quote! { match output {
1217 Some((#(#outputs),*)) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1218 None => { #(#builders.append_null();)* error_builder.append_null(); }
1219 } },
1220 ReturnTypeKind::Result => quote! { match output {
1221 Ok((#(#outputs),*)) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1222 Err(e) => { #(#builders.append_null();)* error_builder.append_display(Some(e.as_report())); }
1223 } },
1224 ReturnTypeKind::ResultOption => quote! { match output {
1225 Ok(Some((#(#outputs),*))) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1226 Ok(None) => { #(#builders.append_null();)* error_builder.append_null(); }
1227 Err(e) => { #(#builders.append_null();)* error_builder.append_display(Some(e.as_report())); }
1228 } },
1229 };
1230
1231 Ok(quote! {
1232 |return_type, chunk_size, children| {
1233 use risingwave_common::array::*;
1234 use risingwave_common::types::*;
1235 use risingwave_common::bitmap::Bitmap;
1236 use risingwave_common::util::iter_util::ZipEqFast;
1237 use risingwave_expr::expr::{BoxedExpression, Context};
1238 use risingwave_expr::{Result, ExprError};
1239 use risingwave_expr::codegen::*;
1240
1241 risingwave_expr::ensure!(children.len() == #num_args);
1242
1243 let context = Context {
1244 return_type: return_type.clone(),
1245 arg_types: children.iter().map(|c| c.return_type()).collect(),
1246 variadic: false,
1247 };
1248
1249 let mut iter = children.into_iter();
1250 #(let #all_child = iter.next().unwrap();)*
1251 #(
1252 let #const_child = #const_child.eval_const()?;
1253 let #const_child = match &#const_child {
1254 Some(s) => s.as_scalar_ref_impl().try_into()?,
1255 None => return Ok(risingwave_expr::table_function::empty(return_type)),
1257 };
1258 )*
1259
1260 #[derive(Debug)]
1261 struct #struct_name {
1262 context: Context,
1263 chunk_size: usize,
1264 #(#child: BoxedExpression,)*
1265 prebuilt_arg: #prebuilt_arg_type,
1266 }
1267 #[async_trait]
1268 impl risingwave_expr::table_function::TableFunction for #struct_name {
1269 fn return_type(&self) -> DataType {
1270 self.context.return_type.clone()
1271 }
1272 async fn eval<'a>(&'a self, input: &'a DataChunk) -> BoxStream<'a, Result<DataChunk>> {
1273 self.eval_inner(input)
1274 }
1275 }
1276 impl #struct_name {
1277 #[try_stream(boxed, ok = DataChunk, error = ExprError)]
1278 async fn eval_inner<'a>(&'a self, input: &'a DataChunk) {
1279 #(
1280 let #array_refs = self.#child.eval(input).await?;
1281 let #arrays: &#arg_arrays = #array_refs.as_ref().into();
1282 )*
1283
1284 let mut index_builder = I32ArrayBuilder::new(self.chunk_size);
1285 #(let mut #builders = #builder_types::with_type(self.chunk_size, #return_types);)*
1286 let mut error_builder = Utf8ArrayBuilder::new(self.chunk_size);
1287
1288 for i in 0..input.capacity() {
1289 if unsafe { !input.visibility().is_set_unchecked(i) } {
1290 continue;
1291 }
1292 #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
1293 for output in #iter {
1294 index_builder.append(Some(i as i32));
1295 #append_output
1296
1297 if index_builder.len() == self.chunk_size {
1298 let len = index_builder.len();
1299 let index_array = std::mem::replace(&mut index_builder, I32ArrayBuilder::new(self.chunk_size)).finish().into_ref();
1300 let value_arrays = [#(std::mem::replace(&mut #builders, #builder_types::with_type(self.chunk_size, #return_types)).finish().into_ref()),*];
1301 #build_value_array
1302 let error_array = std::mem::replace(&mut error_builder, Utf8ArrayBuilder::new(self.chunk_size)).finish().into_ref();
1303 if error_array.null_bitmap().any() {
1304 yield DataChunk::new(vec![index_array, value_array, error_array], self.chunk_size);
1305 } else {
1306 yield DataChunk::new(vec![index_array, value_array], self.chunk_size);
1307 }
1308 }
1309 }
1310 }
1311
1312 if index_builder.len() > 0 {
1313 let len = index_builder.len();
1314 let index_array = index_builder.finish().into_ref();
1315 let value_arrays = [#(#builders.finish().into_ref()),*];
1316 #build_value_array
1317 let error_array = error_builder.finish().into_ref();
1318 if error_array.null_bitmap().any() {
1319 yield DataChunk::new(vec![index_array, value_array, error_array], len);
1320 } else {
1321 yield DataChunk::new(vec![index_array, value_array], len);
1322 }
1323 }
1324 }
1325 }
1326
1327 Ok(Box::new(#struct_name {
1328 context,
1329 chunk_size,
1330 #(#child,)*
1331 prebuilt_arg: #prebuilt_arg_value,
1332 }))
1333 }
1334 })
1335 }
1336}
1337
1338fn sig_data_type(ty: &str) -> TokenStream2 {
1339 match ty {
1340 "any" => quote! { SigDataType::Any },
1341 "anyarray" => quote! { SigDataType::AnyArray },
1342 "anymap" => quote! { SigDataType::AnyMap },
1343 "vector" => quote! { SigDataType::Vector },
1344 "struct" => quote! { SigDataType::AnyStruct },
1345 _ if ty.starts_with("struct") && ty.contains("any") => quote! { SigDataType::AnyStruct },
1346 _ => {
1347 let datatype = data_type(ty);
1348 quote! { SigDataType::Exact(#datatype) }
1349 }
1350 }
1351}
1352
1353fn data_type(ty: &str) -> TokenStream2 {
1354 if let Some(ty) = ty.strip_suffix("[]") {
1355 let inner_type = data_type(ty);
1356 return quote! { DataType::list(#inner_type) };
1357 }
1358 if ty.starts_with("struct<") {
1359 return quote! { DataType::Struct(#ty.parse().expect("invalid struct type")) };
1360 }
1361 let variant = format_ident!("{}", types::data_type(ty));
1362 quote! { DataType::#variant }
1369}
1370
/// Splits a declared return type into the types of its output columns.
///
/// A type written as `struct<name type, ...>` yields one entry per comma-separated field,
/// namely the second whitespace-separated token of that field. Any other type (including
/// one that starts with `struct<` but does not end with `>`) is returned unchanged as the
/// single entry.
///
/// ```ignore
/// output_types("int4") -> ["int4"]
/// output_types("struct<key varchar, value jsonb>") -> ["varchar", "jsonb"]
/// ```
///
/// # Panics
///
/// Panics with a message naming the offending field if a struct field does not consist
/// of at least a name followed by a type.
fn output_types(ty: &str) -> Vec<&str> {
    let Some(fields) = ty
        .strip_prefix("struct<")
        .and_then(|rest| rest.strip_suffix('>'))
    else {
        return vec![ty];
    };

    fields
        .split(',')
        .map(|field| {
            // The first token is the field name; the second is its type.
            field.split_whitespace().nth(1).unwrap_or_else(|| {
                panic!(
                    "invalid struct field `{}` in `{}`: expected `name type`",
                    field.trim(),
                    ty
                )
            })
        })
        .collect()
}