1use itertools::Itertools;
18use proc_macro2::{Ident, Span};
19use quote::{format_ident, quote};
20
21use super::*;
22
23impl FunctionAttr {
24 pub fn expand(&self) -> Vec<Self> {
26 if self
28 .args
29 .last()
30 .is_some_and(|arg| arg.starts_with("variadic"))
31 {
32 let mut attrs = Vec::new();
36 attrs.extend(
37 FunctionAttr {
38 args: {
39 let mut args = self.args.clone();
40 *args.last_mut().unwrap() = "...".to_owned();
41 args
42 },
43 ..self.clone()
44 }
45 .expand(),
46 );
47 attrs.extend(
48 FunctionAttr {
49 name: format!("{}_variadic", self.name),
50 args: {
51 let mut args = self.args.clone();
52 let last = args.last_mut().unwrap();
53 *last = last.strip_prefix("variadic ").unwrap().into();
54 args
55 },
56 ..self.clone()
57 }
58 .expand(),
59 );
60 return attrs;
61 }
62 let args = self.args.iter().map(|ty| types::expand_type_wildcard(ty));
63 let ret = types::expand_type_wildcard(&self.ret);
64 let mut attrs = Vec::new();
65 for (args, mut ret) in args.multi_cartesian_product().cartesian_product(ret) {
66 if ret == "auto" {
67 ret = types::min_compatible_type(&args);
68 }
69 let attr = FunctionAttr {
70 args: args.iter().map(|s| s.to_string()).collect(),
71 ret: ret.to_owned(),
72 ..self.clone()
73 };
74 attrs.push(attr);
75 }
76 attrs
77 }
78
79 fn generate_type_infer_fn(&self) -> Result<TokenStream2> {
81 if let Some(func) = &self.type_infer {
82 if func == "unreachable" {
83 return Ok(
84 quote! { |_| unreachable!("type inference for this function should be specially handled in frontend, and should not call sig.type_infer") },
85 );
86 }
87 return Ok(func.parse().unwrap());
89 } else if self.ret == "any" {
90 if let Some(i) = self.args.iter().position(|t| t == "any") {
92 return Ok(quote! { |args| Ok(args[#i].clone()) });
94 }
95 if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
96 return Ok(quote! { |args| Ok(args[#i].as_list_elem().clone()) });
98 }
99 } else if self.ret == "anyarray" {
100 if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
101 return Ok(quote! { |args| Ok(args[#i].clone()) });
103 }
104 if let Some(i) = self.args.iter().position(|t| t == "any") {
105 return Ok(quote! { |args| Ok(DataType::list(args[#i].clone())) });
107 }
108 } else if self.ret == "struct" {
109 if let Some(i) = self.args.iter().position(|t| t == "struct") {
110 return Ok(quote! { |args| Ok(args[#i].clone()) });
112 }
113 } else if self.ret == "anymap" {
114 if let Some(i) = self.args.iter().position(|t| t == "anymap") {
115 return Ok(quote! { |args| Ok(args[#i].clone()) });
117 }
118 } else {
119 let ty = data_type(&self.ret);
121 return Ok(quote! { |_| Ok(#ty) });
122 }
123 Err(Error::new(
124 Span::call_site(),
125 "type inference function cannot be automatically derived. You should provide: `type_infer = \"|args| Ok(...)\"`",
126 ))
127 }
128
129 pub fn generate_function_descriptor(
137 &self,
138 user_fn: &UserFunctionAttr,
139 build_fn: bool,
140 ) -> Result<TokenStream2> {
141 if self.is_table_function {
142 return self.generate_table_function_descriptor(user_fn, build_fn);
143 }
144 let name = self.name.clone();
145 let variadic = matches!(self.args.last(), Some(t) if t == "...");
146 let args = match variadic {
147 true => &self.args[..self.args.len() - 1],
148 false => &self.args[..],
149 }
150 .iter()
151 .map(|ty| sig_data_type(ty))
152 .collect_vec();
153 let ret = sig_data_type(&self.ret);
154
155 let pb_type = format_ident!("{}", utils::to_camel_case(&name));
156 let ctor_name = format_ident!("{}", self.ident_name());
157 let build_fn = if build_fn {
158 let name = format_ident!("{}", user_fn.name);
159 quote! { #name }
160 } else if self.rewritten {
161 quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) }
162 } else {
163 self.generate_build_scalar_function(user_fn, true)?
165 };
166 let type_infer_fn = self.generate_type_infer_fn()?;
167 let deprecated = self.deprecated;
168 let maybe_allow_deprecated = if deprecated {
169 quote! { #[allow(deprecated)] }
170 } else {
171 quote! {}
172 };
173
174 Ok(quote! {
175 #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
176 fn #ctor_name() -> risingwave_expr::sig::FuncSign {
177 use risingwave_common::types::{DataType, DataTypeName};
178 use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
179
180 #maybe_allow_deprecated
181 FuncSign {
182 name: risingwave_pb::expr::expr_node::Type::#pb_type.into(),
183 inputs_type: vec![#(#args),*],
184 variadic: #variadic,
185 ret_type: #ret,
186 build: FuncBuilder::Scalar(#build_fn),
187 type_infer: #type_infer_fn,
188 deprecated: #deprecated,
189 }
190 }
191 })
192 }
193
/// Generates the builder closure for a scalar function.
///
/// The emitted closure has the shape
/// `|return_type: DataType, children: Vec<BoxedExpression>| -> Result<BoxedExpression>`.
/// It checks the number of children, optionally pre-evaluates constant
/// arguments, and returns a boxed struct implementing
/// `risingwave_expr::expr::Expression`. That struct has an `eval` method over a
/// `DataChunk` and an `eval_row` method over an `OwnedRow`.
///
/// `optimize_const`: when `true` and `prebuild` is set, every argument referenced
/// as `$i` in the `prebuild` expression is evaluated with `eval_const()` at build
/// time:
/// - if any of them fails, the closure falls back to the builder generated with
///   `optimize_const == false`;
/// - if any of them is NULL, the whole expression becomes a NULL
///   `LiteralExpression`.
///
/// Batch `eval` uses, in order of preference:
/// 1. a user-provided `batch_fn`;
/// 2. a `from_iter_bitmap` path for pure, non-variadic functions without
///    `prebuild` that return a primitive or boolean (up to two arguments);
/// 3. a row-by-row loop that appends into an array builder.
///
/// Errors from `Result`-returning user functions are collected per row. In
/// `eval`, a non-empty list is returned as `ExprError::Multiple`; in `eval_row`,
/// the first error is returned.
///
/// # Panics
///
/// Panics at macro expansion time if `batch_fn` is combined with a variadic
/// function, if the vectorized path is selected for more than two arguments
/// (`todo!`), or if a type or `prebuild` string fails to tokenize.
fn generate_build_scalar_function(
    &self,
    user_fn: &UserFunctionAttr,
    optimize_const: bool,
) -> Result<TokenStream2> {
    // A trailing `...` marks a variadic function; it does not count as a fixed argument.
    let variadic = matches!(self.args.last(), Some(t) if t == "...");
    let num_args = self.args.len() - if variadic { 1 } else { 0 };
    let fn_name = format_ident!("{}", user_fn.name);
    // The const-optimized and general builders each get their own struct name.
    let struct_name = match optimize_const {
        true => format_ident!("{}OptimizeConst", utils::to_camel_case(&self.ident_name())),
        false => format_ident!("{}", utils::to_camel_case(&self.ident_name())),
    };

    // Arguments referenced as `$i` in the `prebuild` expression.
    let prebuilt_indices = match &self.prebuild {
        Some(s) => (0..num_args)
            .filter(|i| s.contains(&format!("${i}")))
            .collect_vec(),
        None => vec![],
    };
    // All remaining fixed arguments; these are passed to the user function directly.
    let non_prebuilt_indices = match &self.prebuild {
        Some(s) => (0..num_args)
            .filter(|i| !s.contains(&format!("${i}")))
            .collect_vec(),
        _ => (0..num_args).collect_vec(),
    };
    // Children evaluated per call. With const optimization, prebuilt arguments
    // are evaluated once at build time instead, so they are excluded here.
    let children_indices = match optimize_const {
        #[allow(clippy::redundant_clone)] true => non_prebuilt_indices.clone(),
        false => (0..num_args).collect_vec(),
    };

    // Builds identifiers `{prefix}{i}` for each index.
    fn idents(prefix: &str, indices: &[usize]) -> Vec<Ident> {
        indices
            .iter()
            .map(|i| format_ident!("{prefix}{i}"))
            .collect()
    }
    let inputs = idents("i", &children_indices);
    let prebuilt_inputs = idents("i", &prebuilt_indices);
    let non_prebuilt_inputs = idents("i", &non_prebuilt_indices);
    let array_refs = idents("array", &children_indices);
    let arrays = idents("a", &children_indices);
    let datums = idents("v", &children_indices);
    let arg_arrays = children_indices
        .iter()
        .map(|i| format_ident!("{}", types::array_type(&self.args[*i])));
    let arg_types = children_indices.iter().map(|i| {
        types::ref_type(&self.args[*i])
            .parse::<TokenStream2>()
            .unwrap()
    });
    // When the user function returns a generic `T`, annotate the local `output`
    // binding with the concrete owned type so type inference can resolve it.
    let annotation: TokenStream2 = match user_fn.core_return_type.as_str() {
        "T" | "T1" | "T2" | "T3" => format!(": Option<{}>", types::owned_type(&self.ret))
            .parse()
            .unwrap(),
        _ => quote! {},
    };
    let ret_array_type = format_ident!("{}", types::array_type(&self.ret));
    let builder_type = format_ident!("{}Builder", types::array_type(&self.ret));
    // The stored prebuilt value's type is the `prebuild` expression's text before
    // its first `::`, e.g. a type path followed by `::new(...)`. Without const
    // optimization, nothing is stored (`()`).
    let prebuilt_arg_type = match &self.prebuild {
        Some(s) if optimize_const => s.split("::").next().unwrap().parse().unwrap(),
        _ => quote! { () },
    };
    // The `prebuild` expression with `$i` rewritten to the local binding `ii`.
    let prebuilt_arg_value = match &self.prebuild {
        Some(s) => s
            .replace('$', "i")
            .parse()
            .expect("invalid prebuild syntax"),
        None => quote! { () },
    };
    // Build-time evaluation of the prebuilt argument (see the function docs for
    // the fallback and NULL behavior). This recursively generates the general
    // builder used as the fallback.
    let prebuild_const = if self.prebuild.is_some() && optimize_const {
        let build_general = self.generate_build_scalar_function(user_fn, false)?;
        quote! {{
            let build_general = #build_general;
            #(
                let #prebuilt_inputs = match children[#prebuilt_indices].eval_const() {
                    Ok(s) => s,
                    Err(_) => return build_general(return_type, children),
                };
                let #prebuilt_inputs = match &#prebuilt_inputs {
                    Some(s) => s.as_scalar_ref_impl().try_into()?,
                    None => return Ok(Box::new(risingwave_expr::expr::LiteralExpression::new(
                        return_type,
                        None,
                    ))),
                };
            )*
            #prebuilt_arg_value
        }}
    } else {
        quote! { () }
    };

    // Variadic functions accept extra trailing children; others need an exact count.
    let check_children = match variadic {
        true => quote! { risingwave_expr::ensure!(children.len() >= #num_args); },
        false => quote! { risingwave_expr::ensure!(children.len() == #num_args); },
    };

    // Batch mode: evaluate the trailing variadic children into a separate chunk.
    let eval_variadic = variadic.then(|| {
        quote! {
            let mut columns = Vec::with_capacity(self.children.len() - #num_args);
            for child in &self.children[#num_args..] {
                columns.push(child.eval(input).await?);
            }
            let variadic_input = DataChunk::new(columns, input.visibility().clone());
        }
    });
    // Row mode: evaluate the trailing variadic children into an `OwnedRow`.
    let eval_row_variadic = variadic.then(|| {
        quote! {
            let mut row = Vec::with_capacity(self.children.len() - #num_args);
            for child in &self.children[#num_args..] {
                row.push(child.eval_row(input).await?);
            }
            let variadic_row = OwnedRow::new(row);
        }
    });

    // Boolean functions with three generic parameters get a turbofish whose third
    // parameter is the reference type of the arguments' minimal compatible type.
    let generic = (self.ret == "boolean" && user_fn.generic == 3).then(|| {
        let compatible_type = types::ref_type(types::min_compatible_type(&self.args))
            .parse::<TokenStream2>()
            .unwrap();
        quote! { ::<_, _, #compatible_type> }
    });
    // Where the prebuilt argument comes from: the struct field when it was
    // computed at build time, otherwise the expression evaluated inline.
    let prebuilt_arg = match (&self.prebuild, optimize_const) {
        (Some(_), true) => quote! { &self.prebuilt_arg, },
        (Some(_), false) => quote! { &#prebuilt_arg_value, },
        (None, _) => quote! {},
    };
    let variadic_args = variadic.then(|| quote! { &variadic_row, });
    let context = user_fn.context.then(|| quote! { &self.context, });
    let writer = user_fn
        .writer_type_kind
        .is_some()
        .then(|| quote! { &mut writer, });
    let await_ = user_fn.async_.then(|| quote! { .await });

    // Code run when the user function returns `Err(e)`: record an
    // `ExprError::function` carrying the function name and the row's argument values.
    let record_error = {
        // NOTE(review): presumably `zip` is a disallowed method here (the generated
        // code imports `ZipEqFast`); lengths may legitimately differ because
        // `args_option` describes user-function parameters — confirm.
        #[allow(clippy::disallowed_methods)] let inputs_args = inputs
            .iter()
            .zip(user_fn.args_option.iter())
            .map(|(input, opt)| {
                if *opt {
                    quote! { #input.map(|s| ScalarRefImpl::from(s)) }
                } else {
                    quote! { Some(ScalarRefImpl::from(#input)) }
                }
            });
        let inputs_args = quote! {
            let args: &[DatumRef<'_>] = &[#(#inputs_args),*];
            let args = args.iter().copied();
        };
        let var_args = variadic.then(|| {
            quote! {
                let args = args.chain(variadic_row.iter());
            }
        });

        quote! {
            #inputs_args
            #var_args
            errors.push(ExprError::function(
                stringify!(#fn_name),
                args,
                e,
            ));
        }
    };

    // The user-function call expression.
    let mut output = quote! { #fn_name #generic(
        #(#non_prebuilt_inputs,)*
        #variadic_args
        #prebuilt_arg
        #context
        #writer
    ) #await_ };
    // Normalize every return kind to an `Option<_>` expression. `void` functions
    // are called for their effect and always produce `None`; errors are recorded
    // and also produce `None`.
    output = match user_fn.return_type_kind {
        _ if self.ret == "void" => quote! { { #output; Option::<i32>::None } },
        ReturnTypeKind::T => quote! { Some(#output) },
        ReturnTypeKind::Option => output,
        ReturnTypeKind::Result => quote! {
            match #output {
                Ok(x) => Some(x),
                Err(e) => {
                    #record_error
                    None
                }
            }
        },
        ReturnTypeKind::ResultOption => quote! {
            match #output {
                Ok(x) => x,
                Err(e) => {
                    #record_error
                    None
                }
            }
        },
    };
    // NULL handling. With `prebuild`, any NULL input yields NULL. Otherwise only
    // parameters the user function declares as non-`Option` are unwrapped, and a
    // NULL in one of them yields NULL.
    if self.prebuild.is_some() {
        output = quote! {
            match (#(#inputs,)*) {
                (#(Some(#inputs),)*) => #output,
                _ => None,
            }
        };
    } else {
        #[allow(clippy::disallowed_methods)] let some_inputs = inputs
            .iter()
            .zip(user_fn.args_option.iter())
            .map(|(input, opt)| {
                if *opt {
                    quote! { #input }
                } else {
                    quote! { Some(#input) }
                }
            });
        output = quote! {
            match (#(#inputs,)*) {
                (#(#some_inputs,)*) => #output,
                _ => None,
            }
        };
    };
    // Batch mode: append one row's result to `builder`. Writer-style functions
    // write into a builder-provided writer, which is finished on success and rolled
    // back (and replaced by NULL) otherwise.
    let append_output = match user_fn.writer_type_kind {
        Some(WriterTypeKind::FmtWrite)
        | Some(WriterTypeKind::IoWrite)
        | Some(WriterTypeKind::ListWrite) => quote! {{
            let mut writer = builder.writer();
            if #output.is_some() {
                writer.finish();
            } else {
                writer.rollback();
                builder.append_null();
            }
        }},
        Some(WriterTypeKind::JsonbbBuilder) => quote! {{
            let mut writer_wrapper = builder.writer();
            let mut writer = writer_wrapper.inner();
            if #output.is_some() {
                writer_wrapper.finish();
            } else {
                writer_wrapper.rollback();
                builder.append_null();
            }
        }},
        None if user_fn.core_return_type == "impl AsRef < [u8] >" => quote! {
            builder.append(#output.as_ref().map(|s| s.as_ref()));
        },
        None => quote! {
            let output #annotation = #output;
            builder.append(output.as_ref().map(|s| s.as_scalar_ref()));
        },
    };
    // Row mode: compute one `Datum`. Writer-style functions write into a fresh
    // local writer, whose contents are converted into the result.
    let row_output = match user_fn.writer_type_kind {
        Some(WriterTypeKind::FmtWrite) => quote! {{
            let mut writer = String::new();
            #output.map(|_| writer.into())
        }},
        Some(WriterTypeKind::IoWrite) => quote! {{
            let mut writer = Vec::new();
            #output.map(|_| writer.into())
        }},
        Some(WriterTypeKind::JsonbbBuilder) => quote! {{
            let mut writer = jsonbb::Builder::<Vec<u8>>::new();
            #output.map(|_| JsonbVal::from(writer.finish()).into())
        }},
        Some(WriterTypeKind::ListWrite) => quote! {{
            let mut writer = {
                let DataType::List(list_ty) = &self.context.return_type else {
                    panic!("data type must be DataType::List");
                };
                list_ty.elem().create_array_builder(1)
            };
            #output.map(|_| ListValue::new(writer.finish()).into())
        }},
        None if user_fn.core_return_type == "impl AsRef < [u8] >" => quote! {
            #output.map(|s| s.as_ref().into())
        },
        None => quote! {{
            let output #annotation = #output;
            output.map(|s| s.into())
        }},
    };
    // Body of the batch `eval` method (see the function docs for strategy order).
    let eval = if let Some(batch_fn) = &self.batch_fn {
        assert!(
            !variadic,
            "customized batch function is not supported for variadic functions"
        );
        let fn_name = format_ident!("{}", batch_fn);
        quote! {
            let c = #fn_name(#(#arrays),*);
            Arc::new(c.into())
        }
    } else if (types::is_primitive(&self.ret) || self.ret == "boolean")
        && user_fn.is_pure()
        && !variadic
        && self.prebuild.is_none()
    {
        // Vectorized path: map over raw values and derive the output null bitmap
        // from the inputs' null bitmaps (all-valid for zero arguments).
        match self.args.len() {
            0 => quote! {
                let c = #ret_array_type::from_iter_bitmap(
                    std::iter::repeat_with(|| #fn_name()).take(input.capacity()),
                    Bitmap::ones(input.capacity()),
                );
                Arc::new(c.into())
            },
            1 => quote! {
                let c = #ret_array_type::from_iter_bitmap(
                    a0.raw_iter().map(|a| #fn_name(a)),
                    a0.null_bitmap().clone()
                );
                Arc::new(c.into())
            },
            2 => quote! {
                #[allow(clippy::disallowed_methods)]
                let c = #ret_array_type::from_iter_bitmap(
                    a0.raw_iter()
                        .zip(a1.raw_iter())
                        .map(|(a, b)| #fn_name #generic(a, b)),
                    a0.null_bitmap() & a1.null_bitmap(),
                );
                Arc::new(c.into())
            },
            n => todo!("SIMD optimization for {n} arguments"),
        }
    } else {
        // General path: iterate every row. Invisible rows get NULL without
        // calling the user function.
        let let_variadic = variadic.then(|| {
            quote! {
                let variadic_row = variadic_input.row_at_unchecked_vis(i);
            }
        });
        quote! {
            let mut builder = #builder_type::with_type(input.capacity(), self.context.return_type.clone());

            if input.is_vis_compacted() {
                for i in 0..input.capacity() {
                    #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
                    #let_variadic
                    #append_output
                }
            } else {
                for i in 0..input.capacity() {
                    if unsafe { !input.visibility().is_set_unchecked(i) } {
                        builder.append_null();
                        continue;
                    }
                    #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
                    #let_variadic
                    #append_output
                }
            }
            Arc::new(builder.finish().into())
        }
    };

    Ok(quote! {
        |return_type: DataType, children: Vec<risingwave_expr::expr::BoxedExpression>|
            -> risingwave_expr::Result<risingwave_expr::expr::BoxedExpression>
        {
            use std::sync::Arc;
            use risingwave_common::array::*;
            use risingwave_common::types::*;
            use risingwave_common::bitmap::Bitmap;
            use risingwave_common::row::OwnedRow;
            use risingwave_common::util::iter_util::ZipEqFast;

            use risingwave_expr::expr::{Context, BoxedExpression};
            use risingwave_expr::{ExprError, Result};
            use risingwave_expr::codegen::*;

            #check_children
            let prebuilt_arg = #prebuild_const;
            let context = Context {
                return_type,
                arg_types: children.iter().map(|c| c.return_type()).collect(),
                variadic: #variadic,
            };

            #[derive(Debug)]
            struct #struct_name {
                context: Context,
                children: Vec<BoxedExpression>,
                prebuilt_arg: #prebuilt_arg_type,
            }
            #[async_trait]
            impl risingwave_expr::expr::Expression for #struct_name {
                fn return_type(&self) -> DataType {
                    self.context.return_type.clone()
                }
                async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
                    #(
                        let #array_refs = self.children[#children_indices].eval(input).await?;
                        let #arrays: &#arg_arrays = #array_refs.as_ref().into();
                    )*
                    #eval_variadic
                    let mut errors = vec![];
                    let array = { #eval };
                    if errors.is_empty() {
                        Ok(array)
                    } else {
                        Err(ExprError::Multiple(array, errors.into()))
                    }
                }
                async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
                    #(
                        let #datums = self.children[#children_indices].eval_row(input).await?;
                        let #inputs: Option<#arg_types> = #datums.as_ref().map(|s| s.as_scalar_ref_impl().try_into().unwrap());
                    )*
                    #eval_row_variadic
                    let mut errors: Vec<ExprError> = vec![];
                    let output = #row_output;
                    if let Some(err) = errors.into_iter().next() {
                        Err(err.into())
                    } else {
                        Ok(output)
                    }
                }
            }

            Ok(Box::new(#struct_name {
                context,
                children,
                prebuilt_arg,
            }))
        }
    })
}
672
673 pub fn generate_aggregate_descriptor(
679 &self,
680 user_fn: &AggregateFnOrImpl,
681 build_fn: bool,
682 ) -> Result<TokenStream2> {
683 let name = self.name.clone();
684
685 let mut args = Vec::with_capacity(self.args.len());
686 for ty in &self.args {
687 args.push(sig_data_type(ty));
688 }
689 let ret = sig_data_type(&self.ret);
690 let state_type = match &self.state {
691 Some(ty) if ty != "ref" => {
692 let ty = data_type(ty);
693 quote! { Some(#ty) }
694 }
695 _ => quote! { None },
696 };
697 let append_only = match build_fn {
698 false => !user_fn.has_retract(),
699 true => self.append_only,
700 };
701
702 let pb_kind = format_ident!("{}", utils::to_camel_case(&name));
703 let ctor_name = match append_only {
704 false => format_ident!("{}", self.ident_name()),
705 true => format_ident!("{}_append_only", self.ident_name()),
706 };
707 let build_fn = if build_fn {
708 let name = format_ident!("{}", user_fn.as_fn().name);
709 quote! { #name }
710 } else if self.rewritten {
711 quote! { |_| Err(ExprError::UnsupportedFunction(#name.into())) }
712 } else {
713 self.generate_agg_build_fn(user_fn)?
714 };
715 let build_retractable = match append_only {
716 true => quote! { None },
717 false => quote! { Some(#build_fn) },
718 };
719 let build_append_only = match append_only {
720 false => quote! { None },
721 true => quote! { Some(#build_fn) },
722 };
723 let retractable_state_type = match append_only {
724 true => quote! { None },
725 false => state_type.clone(),
726 };
727 let append_only_state_type = match append_only {
728 false => quote! { None },
729 true => state_type,
730 };
731 let type_infer_fn = self.generate_type_infer_fn()?;
732 let deprecated = self.deprecated;
733 let maybe_allow_deprecated = if deprecated {
734 quote! { #[allow(deprecated)] }
735 } else {
736 quote! {}
737 };
738
739 Ok(quote! {
740 #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
741 fn #ctor_name() -> risingwave_expr::sig::FuncSign {
742 use risingwave_common::types::{DataType, DataTypeName};
743 use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
744
745 #maybe_allow_deprecated
746 FuncSign {
747 name: risingwave_pb::expr::agg_call::PbKind::#pb_kind.into(),
748 inputs_type: vec![#(#args),*],
749 variadic: false,
750 ret_type: #ret,
751 build: FuncBuilder::Aggregate {
752 retractable: #build_retractable,
753 append_only: #build_append_only,
754 retractable_state_type: #retractable_state_type,
755 append_only_state_type: #append_only_state_type,
756 },
757 type_infer: #type_infer_fn,
758 deprecated: #deprecated,
759 }
760 }
761 })
762 }
763
/// Generates the builder closure `|agg| { ... }` for an aggregate function.
///
/// The closure returns a boxed local struct `Agg` that implements
/// `risingwave_expr::aggregate::AggregateFunction`.
///
/// State representation:
/// - If the accumulate function's first argument is a `&mut` custom state
///   (`first_mut_ref_arg`), the state is an `AggregateState::Any` that holds that
///   type and is updated in place.
/// - Otherwise the state is stored as a datum. For the duration of
///   `update`/`update_range` it is converted into an `Option<state_type>` and
///   written back afterwards. With `state = "ref"`, the state is a reference type
///   of the return type, and is converted back with `to_owned_scalar`.
///
/// Retraction: rows with `Op::Delete | Op::UpdateDelete` are handled in one of
/// three ways:
/// - passed as a boolean flag to a retract-capable `Fn`;
/// - routed to `retract` for an `Impl` that defines it;
/// - rejected with `assert_eq!(op, Op::Insert, ..)` for append-only functions.
///
/// NULL handling when no accumulate parameter is `Option`:
/// - with zero arguments, a `None` state stays `None`;
/// - with one argument, a NULL input leaves the state unchanged, and the first
///   non-NULL input initializes a `None` state.
///
/// # Panics
///
/// Panics at macro expansion time (`todo!`) for more than one non-`Option`
/// argument, or if a type or `init_state` string fails to tokenize.
fn generate_agg_build_fn(&self, user_fn: &AggregateFnOrImpl) -> Result<TokenStream2> {
    // `Some(type)` when the accumulate function takes `&mut CustomState` first.
    let custom_state = user_fn.accumulate().first_mut_ref_arg.as_ref();
    // Rust type used for the state while a chunk is being processed.
    let state_type: TokenStream2 = match (custom_state, &self.state) {
        (Some(s), _) => s.parse().unwrap(),
        (_, Some(state)) if state == "ref" => types::ref_type(&self.ret).parse().unwrap(),
        (_, Some(state)) if state != "ref" => types::owned_type(state).parse().unwrap(),
        _ => types::owned_type(&self.ret).parse().unwrap(),
    };
    // `let aN: &ArrayType = input.column_at(N).as_ref().into();` for each argument.
    let let_arrays = self
        .args
        .iter()
        .enumerate()
        .map(|(i, arg)| {
            let array = format_ident!("a{i}");
            let array_type: TokenStream2 = types::array_type(arg).parse().unwrap();
            quote! {
                let #array: &#array_type = input.column_at(#i).as_ref().into();
            }
        })
        .collect_vec();
    // `let vN = aN.value_at_unchecked(row_id);` for each argument.
    let let_values = (0..self.args.len())
        .map(|i| {
            let v = format_ident!("v{i}");
            let a = format_ident!("a{i}");
            quote! { let #v = unsafe { #a.value_at_unchecked(row_id) }; }
        })
        .collect_vec();
    // Load `state` from `state0` at the start of an update.
    let downcast_state = if custom_state.is_some() {
        quote! { let mut state: &mut #state_type = state0.downcast_mut(); }
    } else if let Some(s) = &self.state
        && s == "ref"
    {
        quote! { let mut state: Option<#state_type> = state0.as_datum_mut().as_ref().map(|x| x.as_scalar_ref_impl().try_into().unwrap()); }
    } else {
        quote! { let mut state: Option<#state_type> = state0.as_datum_mut().take().map(|s| s.try_into().unwrap()); }
    };
    // Write `state` back into `state0` at the end of an update. A custom state was
    // mutated in place, so it needs nothing.
    let restore_state = if custom_state.is_some() {
        quote! {}
    } else if let Some(s) = &self.state
        && s == "ref"
    {
        quote! { *state0.as_datum_mut() = state.map(|x| x.to_owned_scalar().into()); }
    } else {
        quote! { *state0.as_datum_mut() = state.map(|s| s.into()); }
    };
    // Override `create_state` for custom state (default-constructed) or an
    // explicit `init_state`. Otherwise the trait's default is used.
    let create_state = if custom_state.is_some() {
        quote! {
            fn create_state(&self) -> Result<AggregateState> {
                Ok(AggregateState::Any(Box::<#state_type>::default()))
            }
        }
    } else if let Some(state) = &self.init_state {
        let state: TokenStream2 = state.parse().unwrap();
        quote! {
            fn create_state(&self) -> Result<AggregateState> {
                Ok(AggregateState::Datum(Some(#state.into())))
            }
        }
    } else {
        quote! {}
    };
    // `v0, v1, ...,` used as the trailing value arguments of each call.
    let args = (0..self.args.len()).map(|i| format_ident!("v{i}"));
    let args = quote! { #(#args,)* };
    let panic_on_retract = {
        let msg = format!(
            "attempt to retract on aggregate function {}, but it is append-only",
            self.name
        );
        quote! { assert_eq!(op, Op::Insert, #msg); }
    };
    // Expression computing the next state for one row.
    let mut next_state = match user_fn {
        AggregateFnOrImpl::Fn(f) => {
            let context = f.context.then(|| quote! { &self.context, });
            let fn_name = format_ident!("{}", f.name);
            match f.retract {
                true => {
                    quote! { #fn_name(state, #args matches!(op, Op::Delete | Op::UpdateDelete) #context) }
                }
                false => quote! {{
                    #panic_on_retract
                    #fn_name(state, #args #context)
                }},
            }
        }
        AggregateFnOrImpl::Impl(i) => {
            let retract = match i.retract {
                Some(_) => quote! { self.function.retract(state, #args) },
                None => panic_on_retract,
            };
            quote! {
                if matches!(op, Op::Delete | Op::UpdateDelete) {
                    #retract
                } else {
                    self.function.accumulate(state, #args)
                }
            }
        }
    };
    // Normalize to `Option<state>`; errors are propagated with `?`.
    next_state = match user_fn.accumulate().return_type_kind {
        ReturnTypeKind::T => quote! { Some(#next_state) },
        ReturnTypeKind::Option => next_state,
        ReturnTypeKind::Result => quote! { Some(#next_state?) },
        ReturnTypeKind::ResultOption => quote! { #next_state? },
    };
    // When no parameter accepts `Option`, NULLs must be handled here (see the
    // function docs).
    if user_fn.accumulate().args_option.iter().all(|b| !b) {
        match self.args.len() {
            0 => {
                next_state = quote! {
                    match state {
                        Some(state) => #next_state,
                        None => state,
                    }
                };
            }
            1 => {
                // Value of the state for the first non-NULL input when there is
                // no state yet.
                let first_state = if self.init_state.is_some() {
                    quote! { unreachable!() }
                } else if let Some(s) = &self.state
                    && s == "ref"
                {
                    quote! { Some(v0) }
                } else if let AggregateFnOrImpl::Impl(impl_) = user_fn
                    && impl_.create_state.is_some()
                {
                    quote! {{
                        let state = self.function.create_state();
                        #next_state
                    }}
                } else {
                    quote! {{
                        let state = #state_type::default();
                        #next_state
                    }}
                };
                next_state = quote! {
                    match (state, v0) {
                        (Some(state), Some(v0)) => #next_state,
                        (None, Some(v0)) => #first_state,
                        (state, None) => state,
                    }
                };
            }
            _ => todo!("multiple arguments are not supported for non-option function"),
        }
    }
    // A custom state is mutated through `&mut`, so the returned value is discarded.
    let update_state = if custom_state.is_some() {
        quote! { _ = #next_state; }
    } else {
        quote! { state = #next_state; }
    };
    // Body of `get_result`: a custom state converts directly. An `Impl` with
    // `finalize` maps a non-NULL state through it. Otherwise the datum is returned
    // as-is.
    let get_result = if custom_state.is_some() {
        quote! { Ok(state.downcast_ref::<#state_type>().into()) }
    } else if let AggregateFnOrImpl::Impl(impl_) = user_fn
        && impl_.finalize.is_some()
    {
        quote! {
            let state = match state.as_datum() {
                Some(s) => s.as_scalar_ref_impl().try_into().unwrap(),
                None => return Ok(None),
            };
            Ok(Some(self.function.finalize(state).into()))
        }
    } else {
        quote! { Ok(state.as_datum().clone()) }
    };
    // `Impl` aggregates keep an instance of the user struct (with an optional
    // generic parameter) as the `function` field.
    let function_field = match user_fn {
        AggregateFnOrImpl::Fn(_) => quote! {},
        AggregateFnOrImpl::Impl(i) => {
            let struct_name = format_ident!("{}", i.struct_name);
            let generic = self.generic.as_ref().map(|g| {
                let g = format_ident!("{g}");
                quote! { <#g> }
            });
            quote! { function: #struct_name #generic, }
        }
    };
    let function_new = match user_fn {
        AggregateFnOrImpl::Fn(_) => quote! {},
        AggregateFnOrImpl::Impl(i) => {
            let struct_name = format_ident!("{}", i.struct_name);
            let generic = self.generic.as_ref().map(|g| {
                let g = format_ident!("{g}");
                quote! { ::<#g> }
            });
            quote! { function: #struct_name #generic :: default(), }
        }
    };

    Ok(quote! {
        |agg| {
            use std::collections::HashSet;
            use std::ops::Range;
            use risingwave_common::array::*;
            use risingwave_common::types::*;
            use risingwave_common::bail;
            use risingwave_common::bitmap::Bitmap;
            use risingwave_common_estimate_size::EstimateSize;

            use risingwave_expr::expr::Context;
            use risingwave_expr::Result;
            use risingwave_expr::aggregate::AggregateState;
            use risingwave_expr::codegen::async_trait;

            let context = Context {
                return_type: agg.return_type.clone(),
                arg_types: agg.args.arg_types().to_owned(),
                variadic: false,
            };

            struct Agg {
                context: Context,
                #function_field
            }

            #[async_trait]
            impl risingwave_expr::aggregate::AggregateFunction for Agg {
                fn return_type(&self) -> DataType {
                    self.context.return_type.clone()
                }

                #create_state

                async fn update(&self, state0: &mut AggregateState, input: &StreamChunk) -> Result<()> {
                    #(#let_arrays)*
                    #downcast_state
                    for row_id in input.visibility().iter_ones() {
                        let op = unsafe { *input.ops().get_unchecked(row_id) };
                        #(#let_values)*
                        #update_state
                    }
                    #restore_state
                    Ok(())
                }

                async fn update_range(&self, state0: &mut AggregateState, input: &StreamChunk, range: Range<usize>) -> Result<()> {
                    assert!(range.end <= input.capacity());
                    #(#let_arrays)*
                    #downcast_state
                    if input.is_vis_compacted() {
                        for row_id in range {
                            let op = unsafe { *input.ops().get_unchecked(row_id) };
                            #(#let_values)*
                            #update_state
                        }
                    } else {
                        for row_id in input.visibility().iter_ones() {
                            if row_id < range.start {
                                continue;
                            } else if row_id >= range.end {
                                break;
                            }
                            let op = unsafe { *input.ops().get_unchecked(row_id) };
                            #(#let_values)*
                            #update_state
                        }
                    }
                    #restore_state
                    Ok(())
                }

                async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
                    #get_result
                }
            }

            Ok(Box::new(Agg {
                context,
                #function_new
            }))
        }
    })
}
1043
1044 fn generate_table_function_descriptor(
1048 &self,
1049 user_fn: &UserFunctionAttr,
1050 build_fn: bool,
1051 ) -> Result<TokenStream2> {
1052 let name = self.name.clone();
1053 let mut args = Vec::with_capacity(self.args.len());
1054 for ty in &self.args {
1055 args.push(sig_data_type(ty));
1056 }
1057 let ret = sig_data_type(&self.ret);
1058
1059 let pb_type = format_ident!("{}", utils::to_camel_case(&name));
1060 let ctor_name = format_ident!("{}", self.ident_name());
1061 let build_fn = if build_fn {
1062 let name = format_ident!("{}", user_fn.name);
1063 quote! { #name }
1064 } else if self.rewritten {
1065 quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) }
1066 } else {
1067 self.generate_build_table_function(user_fn)?
1068 };
1069 let type_infer_fn = self.generate_type_infer_fn()?;
1070 let deprecated = self.deprecated;
1071 let maybe_allow_deprecated = if deprecated {
1072 quote! { #[allow(deprecated)] }
1073 } else {
1074 quote! {}
1075 };
1076
1077 Ok(quote! {
1078 #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
1079 fn #ctor_name() -> risingwave_expr::sig::FuncSign {
1080 use risingwave_common::types::{DataType, DataTypeName};
1081 use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
1082
1083 #maybe_allow_deprecated
1084 FuncSign {
1085 name: risingwave_pb::expr::table_function::Type::#pb_type.into(),
1086 inputs_type: vec![#(#args),*],
1087 variadic: false,
1088 ret_type: #ret,
1089 build: FuncBuilder::Table(#build_fn),
1090 type_infer: #type_infer_fn,
1091 deprecated: #deprecated,
1092 }
1093 }
1094 })
1095 }
1096
1097 fn generate_build_table_function(&self, user_fn: &UserFunctionAttr) -> Result<TokenStream2> {
1098 let num_args = self.args.len();
1099 let return_types = output_types(&self.ret);
1100 let fn_name = format_ident!("{}", user_fn.name);
1101 let struct_name = format_ident!("{}", utils::to_camel_case(&self.ident_name()));
1102 let arg_ids = (0..num_args)
1103 .filter(|i| match &self.prebuild {
1104 Some(s) => !s.contains(&format!("${i}")),
1105 None => true,
1106 })
1107 .collect_vec();
1108 let const_ids = (0..num_args).filter(|i| match &self.prebuild {
1109 Some(s) => s.contains(&format!("${i}")),
1110 None => false,
1111 });
1112 let inputs: Vec<_> = arg_ids.iter().map(|i| format_ident!("i{i}")).collect();
1113 let all_child: Vec<_> = (0..num_args).map(|i| format_ident!("child{i}")).collect();
1114 let const_child: Vec<_> = const_ids.map(|i| format_ident!("child{i}")).collect();
1115 let child: Vec<_> = arg_ids.iter().map(|i| format_ident!("child{i}")).collect();
1116 let array_refs: Vec<_> = arg_ids.iter().map(|i| format_ident!("array{i}")).collect();
1117 let arrays: Vec<_> = arg_ids.iter().map(|i| format_ident!("a{i}")).collect();
1118 let arg_arrays = arg_ids
1119 .iter()
1120 .map(|i| format_ident!("{}", types::array_type(&self.args[*i])));
1121 let outputs = (0..return_types.len())
1122 .map(|i| format_ident!("o{i}"))
1123 .collect_vec();
1124 let builders = (0..return_types.len())
1125 .map(|i| format_ident!("builder{i}"))
1126 .collect_vec();
1127 let builder_types = return_types
1128 .iter()
1129 .map(|ty| format_ident!("{}Builder", types::array_type(ty)))
1130 .collect_vec();
1131 let return_types = if return_types.len() == 1 {
1132 vec![quote! { self.context.return_type.clone() }]
1133 } else {
1134 (0..return_types.len())
1135 .map(|i| quote! { self.context.return_type.as_struct().types().nth(#i).unwrap().clone() })
1136 .collect()
1137 };
1138 #[allow(clippy::disallowed_methods)]
1139 let optioned_outputs = user_fn
1140 .core_return_type
1141 .split(',')
1142 .map(|t| t.contains("Option"))
1143 .zip(&outputs)
1145 .map(|(optional, o)| match optional {
1146 false => quote! { Some(#o.as_scalar_ref()) },
1147 true => quote! { #o.map(|o| o.as_scalar_ref()) },
1148 })
1149 .collect_vec();
1150 let build_value_array = if return_types.len() == 1 {
1151 quote! { let [value_array] = value_arrays; }
1152 } else {
1153 quote! {
1154 let value_array = StructArray::new(
1155 self.context.return_type.as_struct().clone(),
1156 value_arrays.to_vec(),
1157 Bitmap::ones(len),
1158 ).into_ref();
1159 }
1160 };
1161 let context = user_fn.context.then(|| quote! { &self.context, });
1162 let prebuilt_arg = match &self.prebuild {
1163 Some(_) => quote! { &self.prebuilt_arg, },
1164 None => quote! {},
1165 };
1166 let prebuilt_arg_type = match &self.prebuild {
1167 Some(s) => s.split("::").next().unwrap().parse().unwrap(),
1168 None => quote! { () },
1169 };
1170 let prebuilt_arg_value = match &self.prebuild {
1171 Some(s) => s
1172 .replace('$', "child")
1173 .parse()
1174 .expect("invalid prebuild syntax"),
1175 None => quote! { () },
1176 };
1177 let iter = quote! { #fn_name(#(#inputs,)* #prebuilt_arg #context) };
1178 let mut iter = match user_fn.return_type_kind {
1179 ReturnTypeKind::T => quote! { #iter },
1180 ReturnTypeKind::Option => quote! { match #iter {
1181 Some(it) => it,
1182 None => continue,
1183 } },
1184 ReturnTypeKind::Result => quote! { match #iter {
1185 Ok(it) => it,
1186 Err(e) => {
1187 index_builder.append(Some(i as i32));
1188 #(#builders.append_null();)*
1189 error_builder.append_display(Some(e.as_report()));
1190 continue;
1191 }
1192 } },
1193 ReturnTypeKind::ResultOption => quote! { match #iter {
1194 Ok(Some(it)) => it,
1195 Ok(None) => continue,
1196 Err(e) => {
1197 index_builder.append(Some(i as i32));
1198 #(#builders.append_null();)*
1199 error_builder.append_display(Some(e.as_report()));
1200 continue;
1201 }
1202 } },
1203 };
1204 #[allow(clippy::disallowed_methods)] let some_inputs = inputs
1208 .iter()
1209 .zip(user_fn.args_option.iter())
1210 .map(|(input, opt)| {
1211 if *opt {
1212 quote! { #input }
1213 } else {
1214 quote! { Some(#input) }
1215 }
1216 });
1217 iter = quote! {
1218 match (#(#inputs,)*) {
1219 (#(#some_inputs,)*) => #iter,
1220 _ => continue,
1221 }
1222 };
1223 let iterator_item_type = user_fn.iterator_item_kind.clone().ok_or_else(|| {
1224 Error::new(
1225 user_fn.return_type_span,
1226 "expect `impl Iterator` in return type",
1227 )
1228 })?;
1229 let append_output = match iterator_item_type {
1230 ReturnTypeKind::T => quote! {
1231 let (#(#outputs),*) = output;
1232 #(#builders.append(#optioned_outputs);)* error_builder.append_null();
1233 },
1234 ReturnTypeKind::Option => quote! { match output {
1235 Some((#(#outputs),*)) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1236 None => { #(#builders.append_null();)* error_builder.append_null(); }
1237 } },
1238 ReturnTypeKind::Result => quote! { match output {
1239 Ok((#(#outputs),*)) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1240 Err(e) => { #(#builders.append_null();)* error_builder.append_display(Some(e.as_report())); }
1241 } },
1242 ReturnTypeKind::ResultOption => quote! { match output {
1243 Ok(Some((#(#outputs),*))) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1244 Ok(None) => { #(#builders.append_null();)* error_builder.append_null(); }
1245 Err(e) => { #(#builders.append_null();)* error_builder.append_display(Some(e.as_report())); }
1246 } },
1247 };
1248
1249 Ok(quote! {
1250 |return_type, chunk_size, children| {
1251 use risingwave_common::array::*;
1252 use risingwave_common::types::*;
1253 use risingwave_common::bitmap::Bitmap;
1254 use risingwave_common::util::iter_util::ZipEqFast;
1255 use risingwave_expr::expr::{BoxedExpression, Context};
1256 use risingwave_expr::{Result, ExprError};
1257 use risingwave_expr::codegen::*;
1258
1259 risingwave_expr::ensure!(children.len() == #num_args);
1260
1261 let context = Context {
1262 return_type: return_type.clone(),
1263 arg_types: children.iter().map(|c| c.return_type()).collect(),
1264 variadic: false,
1265 };
1266
1267 let mut iter = children.into_iter();
1268 #(let #all_child = iter.next().unwrap();)*
1269 #(
1270 let #const_child = #const_child.eval_const()?;
1271 let #const_child = match &#const_child {
1272 Some(s) => s.as_scalar_ref_impl().try_into()?,
1273 None => return Ok(risingwave_expr::table_function::empty(return_type)),
1275 };
1276 )*
1277
1278 #[derive(Debug)]
1279 struct #struct_name {
1280 context: Context,
1281 chunk_size: usize,
1282 #(#child: BoxedExpression,)*
1283 prebuilt_arg: #prebuilt_arg_type,
1284 }
1285 #[async_trait]
1286 impl risingwave_expr::table_function::TableFunction for #struct_name {
1287 fn return_type(&self) -> DataType {
1288 self.context.return_type.clone()
1289 }
1290 async fn eval<'a>(&'a self, input: &'a DataChunk) -> BoxStream<'a, Result<DataChunk>> {
1291 self.eval_inner(input)
1292 }
1293 }
1294 impl #struct_name {
1295 #[try_stream(boxed, ok = DataChunk, error = ExprError)]
1296 async fn eval_inner<'a>(&'a self, input: &'a DataChunk) {
1297 #(
1298 let #array_refs = self.#child.eval(input).await?;
1299 let #arrays: &#arg_arrays = #array_refs.as_ref().into();
1300 )*
1301
1302 let mut index_builder = I32ArrayBuilder::new(self.chunk_size);
1303 #(let mut #builders = #builder_types::with_type(self.chunk_size, #return_types);)*
1304 let mut error_builder = Utf8ArrayBuilder::new(self.chunk_size);
1305
1306 for i in 0..input.capacity() {
1307 if unsafe { !input.visibility().is_set_unchecked(i) } {
1308 continue;
1309 }
1310 #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
1311 for output in #iter {
1312 index_builder.append(Some(i as i32));
1313 #append_output
1314
1315 if index_builder.len() == self.chunk_size {
1316 let len = index_builder.len();
1317 let index_array = std::mem::replace(&mut index_builder, I32ArrayBuilder::new(self.chunk_size)).finish().into_ref();
1318 let value_arrays = [#(std::mem::replace(&mut #builders, #builder_types::with_type(self.chunk_size, #return_types)).finish().into_ref()),*];
1319 #build_value_array
1320 let error_array = std::mem::replace(&mut error_builder, Utf8ArrayBuilder::new(self.chunk_size)).finish().into_ref();
1321 if error_array.null_bitmap().any() {
1322 yield DataChunk::new(vec![index_array, value_array, error_array], self.chunk_size);
1323 } else {
1324 yield DataChunk::new(vec![index_array, value_array], self.chunk_size);
1325 }
1326 }
1327 }
1328 }
1329
1330 if index_builder.len() > 0 {
1331 let len = index_builder.len();
1332 let index_array = index_builder.finish().into_ref();
1333 let value_arrays = [#(#builders.finish().into_ref()),*];
1334 #build_value_array
1335 let error_array = error_builder.finish().into_ref();
1336 if error_array.null_bitmap().any() {
1337 yield DataChunk::new(vec![index_array, value_array, error_array], len);
1338 } else {
1339 yield DataChunk::new(vec![index_array, value_array], len);
1340 }
1341 }
1342 }
1343 }
1344
1345 Ok(Box::new(#struct_name {
1346 context,
1347 chunk_size,
1348 #(#child,)*
1349 prebuilt_arg: #prebuilt_arg_value,
1350 }))
1351 }
1352 })
1353 }
1354}
1355
1356fn sig_data_type(ty: &str) -> TokenStream2 {
1357 match ty {
1358 "any" => quote! { SigDataType::Any },
1359 "anyarray" => quote! { SigDataType::AnyArray },
1360 "anymap" => quote! { SigDataType::AnyMap },
1361 "vector" => quote! { SigDataType::Vector },
1362 "struct" => quote! { SigDataType::AnyStruct },
1363 _ if ty.starts_with("struct") && ty.contains("any") => quote! { SigDataType::AnyStruct },
1364 _ => {
1365 let datatype = data_type(ty);
1366 quote! { SigDataType::Exact(#datatype) }
1367 }
1368 }
1369}
1370
1371fn data_type(ty: &str) -> TokenStream2 {
1372 if let Some(ty) = ty.strip_suffix("[]") {
1373 let inner_type = data_type(ty);
1374 return quote! { DataType::list(#inner_type) };
1375 }
1376 if ty.starts_with("struct<") {
1377 return quote! { DataType::Struct(#ty.parse().expect("invalid struct type")) };
1378 }
1379 let variant = format_ident!("{}", types::data_type(ty));
1380 quote! { DataType::#variant }
1387}
1388
/// Splits a return type into the types of its output columns.
///
/// For `struct<name1 type1, name2 type2, ...>`, the field types are returned
/// in declaration order. Fields are separated by `,`, and each field's type is
/// its second whitespace-separated token. Any other input, including a string
/// that starts with `struct<` but does not end with `>`, yields a
/// single-element vector containing `ty` unchanged.
///
/// # Panics
///
/// Panics if a struct field has no token after its name (e.g. `struct<>`).
fn output_types(ty: &str) -> Vec<&str> {
    let struct_fields = ty
        .strip_prefix("struct<")
        .and_then(|inner| inner.strip_suffix('>'));

    match struct_fields {
        None => vec![ty],
        Some(fields) => fields
            .split(',')
            .map(|field| field.split_whitespace().nth(1).unwrap())
            .collect(),
    }
}