1use itertools::Itertools;
18use proc_macro2::{Ident, Span};
19use quote::{format_ident, quote};
20
21use super::*;
22
23impl FunctionAttr {
24 pub fn expand(&self) -> Vec<Self> {
26 if self
28 .args
29 .last()
30 .is_some_and(|arg| arg.starts_with("variadic"))
31 {
32 let mut attrs = Vec::new();
36 attrs.extend(
37 FunctionAttr {
38 args: {
39 let mut args = self.args.clone();
40 *args.last_mut().unwrap() = "...".to_owned();
41 args
42 },
43 ..self.clone()
44 }
45 .expand(),
46 );
47 attrs.extend(
48 FunctionAttr {
49 name: format!("{}_variadic", self.name),
50 args: {
51 let mut args = self.args.clone();
52 let last = args.last_mut().unwrap();
53 *last = last.strip_prefix("variadic ").unwrap().into();
54 args
55 },
56 ..self.clone()
57 }
58 .expand(),
59 );
60 return attrs;
61 }
62 let args = self.args.iter().map(|ty| types::expand_type_wildcard(ty));
63 let ret = types::expand_type_wildcard(&self.ret);
64 let mut attrs = Vec::new();
65 for (args, mut ret) in args.multi_cartesian_product().cartesian_product(ret) {
66 if ret == "auto" {
67 ret = types::min_compatible_type(&args);
68 }
69 let attr = FunctionAttr {
70 args: args.iter().map(|s| s.to_string()).collect(),
71 ret: ret.to_owned(),
72 ..self.clone()
73 };
74 attrs.push(attr);
75 }
76 attrs
77 }
78
79 fn generate_type_infer_fn(&self) -> Result<TokenStream2> {
81 if let Some(func) = &self.type_infer {
82 if func == "unreachable" {
83 return Ok(
84 quote! { |_| unreachable!("type inference for this function should be specially handled in frontend, and should not call sig.type_infer") },
85 );
86 }
87 return Ok(func.parse().unwrap());
89 } else if self.ret == "any" {
90 if let Some(i) = self.args.iter().position(|t| t == "any") {
92 return Ok(quote! { |args| Ok(args[#i].clone()) });
94 }
95 if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
96 return Ok(quote! { |args| Ok(args[#i].as_list_element_type().clone()) });
98 }
99 } else if self.ret == "anyarray" {
100 if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
101 return Ok(quote! { |args| Ok(args[#i].clone()) });
103 }
104 if let Some(i) = self.args.iter().position(|t| t == "any") {
105 return Ok(quote! { |args| Ok(DataType::List(Box::new(args[#i].clone()))) });
107 }
108 } else if self.ret == "struct" {
109 if let Some(i) = self.args.iter().position(|t| t == "struct") {
110 return Ok(quote! { |args| Ok(args[#i].clone()) });
112 }
113 } else if self.ret == "anymap" {
114 if let Some(i) = self.args.iter().position(|t| t == "anymap") {
115 return Ok(quote! { |args| Ok(args[#i].clone()) });
117 }
118 } else if self.ret == "vector" {
119 if let Some(i) = self.args.iter().position(|t| t == "vector") {
120 return Ok(quote! { |args| Ok(args[#i].clone()) });
123 }
124 } else {
125 let ty = data_type(&self.ret);
127 return Ok(quote! { |_| Ok(#ty) });
128 }
129 Err(Error::new(
130 Span::call_site(),
131 "type inference function cannot be automatically derived. You should provide: `type_infer = \"|args| Ok(...)\"`",
132 ))
133 }
134
135 pub fn generate_function_descriptor(
143 &self,
144 user_fn: &UserFunctionAttr,
145 build_fn: bool,
146 ) -> Result<TokenStream2> {
147 if self.is_table_function {
148 return self.generate_table_function_descriptor(user_fn, build_fn);
149 }
150 let name = self.name.clone();
151 let variadic = matches!(self.args.last(), Some(t) if t == "...");
152 let args = match variadic {
153 true => &self.args[..self.args.len() - 1],
154 false => &self.args[..],
155 }
156 .iter()
157 .map(|ty| sig_data_type(ty))
158 .collect_vec();
159 let ret = sig_data_type(&self.ret);
160
161 let pb_type = format_ident!("{}", utils::to_camel_case(&name));
162 let ctor_name = format_ident!("{}", self.ident_name());
163 let build_fn = if build_fn {
164 let name = format_ident!("{}", user_fn.name);
165 quote! { #name }
166 } else if self.rewritten {
167 quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) }
168 } else {
169 self.generate_build_scalar_function(user_fn, true)?
171 };
172 let type_infer_fn = self.generate_type_infer_fn()?;
173 let deprecated = self.deprecated;
174
175 Ok(quote! {
176 #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
177 fn #ctor_name() -> risingwave_expr::sig::FuncSign {
178 use risingwave_common::types::{DataType, DataTypeName};
179 use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
180
181 FuncSign {
182 name: risingwave_pb::expr::expr_node::Type::#pb_type.into(),
183 inputs_type: vec![#(#args),*],
184 variadic: #variadic,
185 ret_type: #ret,
186 build: FuncBuilder::Scalar(#build_fn),
187 type_infer: #type_infer_fn,
188 deprecated: #deprecated,
189 }
190 }
191 })
192 }
193
194 fn generate_build_scalar_function(
199 &self,
200 user_fn: &UserFunctionAttr,
201 optimize_const: bool,
202 ) -> Result<TokenStream2> {
203 let variadic = matches!(self.args.last(), Some(t) if t == "...");
204 let num_args = self.args.len() - if variadic { 1 } else { 0 };
205 let fn_name = format_ident!("{}", user_fn.name);
206 let struct_name = match optimize_const {
207 true => format_ident!("{}OptimizeConst", utils::to_camel_case(&self.ident_name())),
208 false => format_ident!("{}", utils::to_camel_case(&self.ident_name())),
209 };
210
211 let prebuilt_indices = match &self.prebuild {
226 Some(s) => (0..num_args)
227 .filter(|i| s.contains(&format!("${i}")))
228 .collect_vec(),
229 None => vec![],
230 };
231 let non_prebuilt_indices = match &self.prebuild {
232 Some(s) => (0..num_args)
233 .filter(|i| !s.contains(&format!("${i}")))
234 .collect_vec(),
235 _ => (0..num_args).collect_vec(),
236 };
237 let children_indices = match optimize_const {
238 #[allow(clippy::redundant_clone)] true => non_prebuilt_indices.clone(),
240 false => (0..num_args).collect_vec(),
241 };
242
243 fn idents(prefix: &str, indices: &[usize]) -> Vec<Ident> {
245 indices
246 .iter()
247 .map(|i| format_ident!("{prefix}{i}"))
248 .collect()
249 }
250 let inputs = idents("i", &children_indices);
251 let prebuilt_inputs = idents("i", &prebuilt_indices);
252 let non_prebuilt_inputs = idents("i", &non_prebuilt_indices);
253 let array_refs = idents("array", &children_indices);
254 let arrays = idents("a", &children_indices);
255 let datums = idents("v", &children_indices);
256 let arg_arrays = children_indices
257 .iter()
258 .map(|i| format_ident!("{}", types::array_type(&self.args[*i])));
259 let arg_types = children_indices.iter().map(|i| {
260 types::ref_type(&self.args[*i])
261 .parse::<TokenStream2>()
262 .unwrap()
263 });
264 let annotation: TokenStream2 = match user_fn.core_return_type.as_str() {
265 "T" | "T1" | "T2" | "T3" => format!(": Option<{}>", types::owned_type(&self.ret))
267 .parse()
268 .unwrap(),
269 _ => quote! {},
270 };
271 let ret_array_type = format_ident!("{}", types::array_type(&self.ret));
272 let builder_type = format_ident!("{}Builder", types::array_type(&self.ret));
273 let prebuilt_arg_type = match &self.prebuild {
274 Some(s) if optimize_const => s.split("::").next().unwrap().parse().unwrap(),
275 _ => quote! { () },
276 };
277 let prebuilt_arg_value = match &self.prebuild {
278 Some(s) => s
282 .replace('$', "i")
283 .parse()
284 .expect("invalid prebuild syntax"),
285 None => quote! { () },
286 };
287 let prebuild_const = if self.prebuild.is_some() && optimize_const {
288 let build_general = self.generate_build_scalar_function(user_fn, false)?;
289 quote! {{
290 let build_general = #build_general;
291 #(
292 let #prebuilt_inputs = match children[#prebuilt_indices].eval_const() {
294 Ok(s) => s,
295 Err(_) => return build_general(return_type, children),
297 };
298 let #prebuilt_inputs = match &#prebuilt_inputs {
300 Some(s) => s.as_scalar_ref_impl().try_into()?,
301 None => return Ok(Box::new(risingwave_expr::expr::LiteralExpression::new(
303 return_type,
304 None,
305 ))),
306 };
307 )*
308 #prebuilt_arg_value
309 }}
310 } else {
311 quote! { () }
312 };
313
314 let check_children = match variadic {
316 true => quote! { risingwave_expr::ensure!(children.len() >= #num_args); },
317 false => quote! { risingwave_expr::ensure!(children.len() == #num_args); },
318 };
319
320 let eval_variadic = variadic.then(|| {
322 quote! {
323 let mut columns = Vec::with_capacity(self.children.len() - #num_args);
324 for child in &self.children[#num_args..] {
325 columns.push(child.eval(input).await?);
326 }
327 let variadic_input = DataChunk::new(columns, input.visibility().clone());
328 }
329 });
330 let eval_row_variadic = variadic.then(|| {
332 quote! {
333 let mut row = Vec::with_capacity(self.children.len() - #num_args);
334 for child in &self.children[#num_args..] {
335 row.push(child.eval_row(input).await?);
336 }
337 let variadic_row = OwnedRow::new(row);
338 }
339 });
340
341 let generic = (self.ret == "boolean" && user_fn.generic == 3).then(|| {
342 let compatible_type = types::ref_type(types::min_compatible_type(&self.args))
344 .parse::<TokenStream2>()
345 .unwrap();
346 quote! { ::<_, _, #compatible_type> }
347 });
348 let prebuilt_arg = match (&self.prebuild, optimize_const) {
349 (Some(_), true) => quote! { &self.prebuilt_arg, },
351 (Some(_), false) => quote! { &#prebuilt_arg_value, },
353 (None, _) => quote! {},
355 };
356 let variadic_args = variadic.then(|| quote! { &variadic_row, });
357 let context = user_fn.context.then(|| quote! { &self.context, });
358 let writer = user_fn.write.then(|| quote! { &mut writer, });
359 let await_ = user_fn.async_.then(|| quote! { .await });
360
361 let record_error = {
362 #[allow(clippy::disallowed_methods)] let inputs_args = inputs
365 .iter()
366 .zip(user_fn.args_option.iter())
367 .map(|(input, opt)| {
368 if *opt {
369 quote! { #input.map(|s| ScalarRefImpl::from(s)) }
370 } else {
371 quote! { Some(ScalarRefImpl::from(#input)) }
372 }
373 });
374 let inputs_args = quote! {
375 let args: &[DatumRef<'_>] = &[#(#inputs_args),*];
376 let args = args.iter().copied();
377 };
378 let var_args = variadic.then(|| {
379 quote! {
380 let args = args.chain(variadic_row.iter());
381 }
382 });
383
384 quote! {
385 #inputs_args
386 #var_args
387 errors.push(ExprError::function(
388 stringify!(#fn_name),
389 args,
390 e,
391 ));
392 }
393 };
394
395 let mut output = quote! { #fn_name #generic(
398 #(#non_prebuilt_inputs,)*
399 #variadic_args
400 #prebuilt_arg
401 #context
402 #writer
403 ) #await_ };
404 output = match user_fn.return_type_kind {
407 _ if self.ret == "void" => quote! { { #output; Option::<i32>::None } },
409 ReturnTypeKind::T => quote! { Some(#output) },
410 ReturnTypeKind::Option => output,
411 ReturnTypeKind::Result => quote! {
412 match #output {
413 Ok(x) => Some(x),
414 Err(e) => {
415 #record_error
416 None
417 }
418 }
419 },
420 ReturnTypeKind::ResultOption => quote! {
421 match #output {
422 Ok(x) => x,
423 Err(e) => {
424 #record_error
425 None
426 }
427 }
428 },
429 };
430 if self.prebuild.is_some() {
433 output = quote! {
434 match (#(#inputs,)*) {
435 (#(Some(#inputs),)*) => #output,
436 _ => None,
437 }
438 };
439 } else {
440 #[allow(clippy::disallowed_methods)] let some_inputs = inputs
442 .iter()
443 .zip(user_fn.args_option.iter())
444 .map(|(input, opt)| {
445 if *opt {
446 quote! { #input }
447 } else {
448 quote! { Some(#input) }
449 }
450 });
451 output = quote! {
452 match (#(#inputs,)*) {
453 (#(#some_inputs,)*) => #output,
454 _ => None,
455 }
456 };
457 };
458 let append_output = match user_fn.write {
460 true => quote! {{
461 let mut writer = builder.writer().begin();
462 if #output.is_some() {
463 writer.finish();
464 } else {
465 drop(writer);
466 builder.append_null();
467 }
468 }},
469 false if user_fn.core_return_type == "impl AsRef < [u8] >" => quote! {
470 builder.append(#output.as_ref().map(|s| s.as_ref()));
471 },
472 false => quote! {
473 let output #annotation = #output;
474 builder.append(output.as_ref().map(|s| s.as_scalar_ref()));
475 },
476 };
477 let row_output = match user_fn.write {
479 true => quote! {{
480 let mut writer = String::new();
481 #output.map(|_| writer.into())
482 }},
483 false if user_fn.core_return_type == "impl AsRef < [u8] >" => quote! {
484 #output.map(|s| s.as_ref().into())
485 },
486 false => quote! {{
487 let output #annotation = #output;
488 output.map(|s| s.into())
489 }},
490 };
491 let eval = if let Some(batch_fn) = &self.batch_fn {
493 assert!(
494 !variadic,
495 "customized batch function is not supported for variadic functions"
496 );
497 let fn_name = format_ident!("{}", batch_fn);
499 quote! {
500 let c = #fn_name(#(#arrays),*);
501 Arc::new(c.into())
502 }
503 } else if (types::is_primitive(&self.ret) || self.ret == "boolean")
504 && user_fn.is_pure()
505 && !variadic
506 && self.prebuild.is_none()
507 {
508 match self.args.len() {
510 0 => quote! {
511 let c = #ret_array_type::from_iter_bitmap(
512 std::iter::repeat_with(|| #fn_name()).take(input.capacity())
513 Bitmap::ones(input.capacity()),
514 );
515 Arc::new(c.into())
516 },
517 1 => quote! {
518 let c = #ret_array_type::from_iter_bitmap(
519 a0.raw_iter().map(|a| #fn_name(a)),
520 a0.null_bitmap().clone()
521 );
522 Arc::new(c.into())
523 },
524 2 => quote! {
525 #[allow(clippy::disallowed_methods)]
527 let c = #ret_array_type::from_iter_bitmap(
528 a0.raw_iter()
529 .zip(a1.raw_iter())
530 .map(|(a, b)| #fn_name #generic(a, b)),
531 a0.null_bitmap() & a1.null_bitmap(),
532 );
533 Arc::new(c.into())
534 },
535 n => todo!("SIMD optimization for {n} arguments"),
536 }
537 } else {
538 let let_variadic = variadic.then(|| {
540 quote! {
541 let variadic_row = variadic_input.row_at_unchecked_vis(i);
542 }
543 });
544 quote! {
545 let mut builder = #builder_type::with_type(input.capacity(), self.context.return_type.clone());
546
547 if input.is_compacted() {
548 for i in 0..input.capacity() {
549 #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
550 #let_variadic
551 #append_output
552 }
553 } else {
554 for i in 0..input.capacity() {
555 if unsafe { !input.visibility().is_set_unchecked(i) } {
556 builder.append_null();
557 continue;
558 }
559 #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
560 #let_variadic
561 #append_output
562 }
563 }
564 Arc::new(builder.finish().into())
565 }
566 };
567
568 Ok(quote! {
569 |return_type: DataType, children: Vec<risingwave_expr::expr::BoxedExpression>|
570 -> risingwave_expr::Result<risingwave_expr::expr::BoxedExpression>
571 {
572 use std::sync::Arc;
573 use risingwave_common::array::*;
574 use risingwave_common::types::*;
575 use risingwave_common::bitmap::Bitmap;
576 use risingwave_common::row::OwnedRow;
577 use risingwave_common::util::iter_util::ZipEqFast;
578
579 use risingwave_expr::expr::{Context, BoxedExpression};
580 use risingwave_expr::{ExprError, Result};
581 use risingwave_expr::codegen::*;
582
583 #check_children
584 let prebuilt_arg = #prebuild_const;
585 let context = Context {
586 return_type,
587 arg_types: children.iter().map(|c| c.return_type()).collect(),
588 variadic: #variadic,
589 };
590
591 #[derive(Debug)]
592 struct #struct_name {
593 context: Context,
594 children: Vec<BoxedExpression>,
595 prebuilt_arg: #prebuilt_arg_type,
596 }
597 #[async_trait]
598 impl risingwave_expr::expr::Expression for #struct_name {
599 fn return_type(&self) -> DataType {
600 self.context.return_type.clone()
601 }
602 async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
603 #(
604 let #array_refs = self.children[#children_indices].eval(input).await?;
605 let #arrays: &#arg_arrays = #array_refs.as_ref().into();
606 )*
607 #eval_variadic
608 let mut errors = vec![];
609 let array = { #eval };
610 if errors.is_empty() {
611 Ok(array)
612 } else {
613 Err(ExprError::Multiple(array, errors.into()))
614 }
615 }
616 async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
617 #(
618 let #datums = self.children[#children_indices].eval_row(input).await?;
619 let #inputs: Option<#arg_types> = #datums.as_ref().map(|s| s.as_scalar_ref_impl().try_into().unwrap());
620 )*
621 #eval_row_variadic
622 let mut errors: Vec<ExprError> = vec![];
623 let output = #row_output;
624 if let Some(err) = errors.into_iter().next() {
625 Err(err.into())
626 } else {
627 Ok(output)
628 }
629 }
630 }
631
632 Ok(Box::new(#struct_name {
633 context,
634 children,
635 prebuilt_arg,
636 }))
637 }
638 })
639 }
640
641 pub fn generate_aggregate_descriptor(
647 &self,
648 user_fn: &AggregateFnOrImpl,
649 build_fn: bool,
650 ) -> Result<TokenStream2> {
651 let name = self.name.clone();
652
653 let mut args = Vec::with_capacity(self.args.len());
654 for ty in &self.args {
655 args.push(sig_data_type(ty));
656 }
657 let ret = sig_data_type(&self.ret);
658 let state_type = match &self.state {
659 Some(ty) if ty != "ref" => {
660 let ty = data_type(ty);
661 quote! { Some(#ty) }
662 }
663 _ => quote! { None },
664 };
665 let append_only = match build_fn {
666 false => !user_fn.has_retract(),
667 true => self.append_only,
668 };
669
670 let pb_kind = format_ident!("{}", utils::to_camel_case(&name));
671 let ctor_name = match append_only {
672 false => format_ident!("{}", self.ident_name()),
673 true => format_ident!("{}_append_only", self.ident_name()),
674 };
675 let build_fn = if build_fn {
676 let name = format_ident!("{}", user_fn.as_fn().name);
677 quote! { #name }
678 } else if self.rewritten {
679 quote! { |_| Err(ExprError::UnsupportedFunction(#name.into())) }
680 } else {
681 self.generate_agg_build_fn(user_fn)?
682 };
683 let build_retractable = match append_only {
684 true => quote! { None },
685 false => quote! { Some(#build_fn) },
686 };
687 let build_append_only = match append_only {
688 false => quote! { None },
689 true => quote! { Some(#build_fn) },
690 };
691 let retractable_state_type = match append_only {
692 true => quote! { None },
693 false => state_type.clone(),
694 };
695 let append_only_state_type = match append_only {
696 false => quote! { None },
697 true => state_type,
698 };
699 let type_infer_fn = self.generate_type_infer_fn()?;
700 let deprecated = self.deprecated;
701
702 Ok(quote! {
703 #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
704 fn #ctor_name() -> risingwave_expr::sig::FuncSign {
705 use risingwave_common::types::{DataType, DataTypeName};
706 use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
707
708 FuncSign {
709 name: risingwave_pb::expr::agg_call::PbKind::#pb_kind.into(),
710 inputs_type: vec![#(#args),*],
711 variadic: false,
712 ret_type: #ret,
713 build: FuncBuilder::Aggregate {
714 retractable: #build_retractable,
715 append_only: #build_append_only,
716 retractable_state_type: #retractable_state_type,
717 append_only_state_type: #append_only_state_type,
718 },
719 type_infer: #type_infer_fn,
720 deprecated: #deprecated,
721 }
722 }
723 })
724 }
725
726 fn generate_agg_build_fn(&self, user_fn: &AggregateFnOrImpl) -> Result<TokenStream2> {
728 let custom_state = user_fn.accumulate().first_mut_ref_arg.as_ref();
731 let state_type: TokenStream2 = match (custom_state, &self.state) {
732 (Some(s), _) => s.parse().unwrap(),
733 (_, Some(state)) if state == "ref" => types::ref_type(&self.ret).parse().unwrap(),
734 (_, Some(state)) if state != "ref" => types::owned_type(state).parse().unwrap(),
735 _ => types::owned_type(&self.ret).parse().unwrap(),
736 };
737 let let_arrays = self
738 .args
739 .iter()
740 .enumerate()
741 .map(|(i, arg)| {
742 let array = format_ident!("a{i}");
743 let array_type: TokenStream2 = types::array_type(arg).parse().unwrap();
744 quote! {
745 let #array: &#array_type = input.column_at(#i).as_ref().into();
746 }
747 })
748 .collect_vec();
749 let let_values = (0..self.args.len())
750 .map(|i| {
751 let v = format_ident!("v{i}");
752 let a = format_ident!("a{i}");
753 quote! { let #v = unsafe { #a.value_at_unchecked(row_id) }; }
754 })
755 .collect_vec();
756 let downcast_state = if custom_state.is_some() {
757 quote! { let mut state: &mut #state_type = state0.downcast_mut(); }
758 } else if let Some(s) = &self.state
759 && s == "ref"
760 {
761 quote! { let mut state: Option<#state_type> = state0.as_datum_mut().as_ref().map(|x| x.as_scalar_ref_impl().try_into().unwrap()); }
762 } else {
763 quote! { let mut state: Option<#state_type> = state0.as_datum_mut().take().map(|s| s.try_into().unwrap()); }
764 };
765 let restore_state = if custom_state.is_some() {
766 quote! {}
767 } else if let Some(s) = &self.state
768 && s == "ref"
769 {
770 quote! { *state0.as_datum_mut() = state.map(|x| x.to_owned_scalar().into()); }
771 } else {
772 quote! { *state0.as_datum_mut() = state.map(|s| s.into()); }
773 };
774 let create_state = if custom_state.is_some() {
775 quote! {
776 fn create_state(&self) -> Result<AggregateState> {
777 Ok(AggregateState::Any(Box::<#state_type>::default()))
778 }
779 }
780 } else if let Some(state) = &self.init_state {
781 let state: TokenStream2 = state.parse().unwrap();
782 quote! {
783 fn create_state(&self) -> Result<AggregateState> {
784 Ok(AggregateState::Datum(Some(#state.into())))
785 }
786 }
787 } else {
788 quote! {}
790 };
791 let args = (0..self.args.len()).map(|i| format_ident!("v{i}"));
792 let args = quote! { #(#args,)* };
793 let panic_on_retract = {
794 let msg = format!(
795 "attempt to retract on aggregate function {}, but it is append-only",
796 self.name
797 );
798 quote! { assert_eq!(op, Op::Insert, #msg); }
799 };
800 let mut next_state = match user_fn {
801 AggregateFnOrImpl::Fn(f) => {
802 let context = f.context.then(|| quote! { &self.context, });
803 let fn_name = format_ident!("{}", f.name);
804 match f.retract {
805 true => {
806 quote! { #fn_name(state, #args matches!(op, Op::Delete | Op::UpdateDelete) #context) }
807 }
808 false => quote! {{
809 #panic_on_retract
810 #fn_name(state, #args #context)
811 }},
812 }
813 }
814 AggregateFnOrImpl::Impl(i) => {
815 let retract = match i.retract {
816 Some(_) => quote! { self.function.retract(state, #args) },
817 None => panic_on_retract,
818 };
819 quote! {
820 if matches!(op, Op::Delete | Op::UpdateDelete) {
821 #retract
822 } else {
823 self.function.accumulate(state, #args)
824 }
825 }
826 }
827 };
828 next_state = match user_fn.accumulate().return_type_kind {
829 ReturnTypeKind::T => quote! { Some(#next_state) },
830 ReturnTypeKind::Option => next_state,
831 ReturnTypeKind::Result => quote! { Some(#next_state?) },
832 ReturnTypeKind::ResultOption => quote! { #next_state? },
833 };
834 if user_fn.accumulate().args_option.iter().all(|b| !b) {
835 match self.args.len() {
836 0 => {
837 next_state = quote! {
838 match state {
839 Some(state) => #next_state,
840 None => state,
841 }
842 };
843 }
844 1 => {
845 let first_state = if self.init_state.is_some() {
846 quote! { unreachable!() }
848 } else if let Some(s) = &self.state
849 && s == "ref"
850 {
851 quote! { Some(v0) }
853 } else if let AggregateFnOrImpl::Impl(impl_) = user_fn
854 && impl_.create_state.is_some()
855 {
856 quote! {{
858 let state = self.function.create_state();
859 #next_state
860 }}
861 } else {
862 quote! {{
863 let state = #state_type::default();
864 #next_state
865 }}
866 };
867 next_state = quote! {
868 match (state, v0) {
869 (Some(state), Some(v0)) => #next_state,
870 (None, Some(v0)) => #first_state,
871 (state, None) => state,
872 }
873 };
874 }
875 _ => todo!("multiple arguments are not supported for non-option function"),
876 }
877 }
878 let update_state = if custom_state.is_some() {
879 quote! { _ = #next_state; }
880 } else {
881 quote! { state = #next_state; }
882 };
883 let get_result = if custom_state.is_some() {
884 quote! { Ok(state.downcast_ref::<#state_type>().into()) }
885 } else if let AggregateFnOrImpl::Impl(impl_) = user_fn
886 && impl_.finalize.is_some()
887 {
888 quote! {
889 let state = match state.as_datum() {
890 Some(s) => s.as_scalar_ref_impl().try_into().unwrap(),
891 None => return Ok(None),
892 };
893 Ok(Some(self.function.finalize(state).into()))
894 }
895 } else {
896 quote! { Ok(state.as_datum().clone()) }
897 };
898 let function_field = match user_fn {
899 AggregateFnOrImpl::Fn(_) => quote! {},
900 AggregateFnOrImpl::Impl(i) => {
901 let struct_name = format_ident!("{}", i.struct_name);
902 let generic = self.generic.as_ref().map(|g| {
903 let g = format_ident!("{g}");
904 quote! { <#g> }
905 });
906 quote! { function: #struct_name #generic, }
907 }
908 };
909 let function_new = match user_fn {
910 AggregateFnOrImpl::Fn(_) => quote! {},
911 AggregateFnOrImpl::Impl(i) => {
912 let struct_name = format_ident!("{}", i.struct_name);
913 let generic = self.generic.as_ref().map(|g| {
914 let g = format_ident!("{g}");
915 quote! { ::<#g> }
916 });
917 quote! { function: #struct_name #generic :: default(), }
918 }
919 };
920
921 Ok(quote! {
922 |agg| {
923 use std::collections::HashSet;
924 use std::ops::Range;
925 use risingwave_common::array::*;
926 use risingwave_common::types::*;
927 use risingwave_common::bail;
928 use risingwave_common::bitmap::Bitmap;
929 use risingwave_common_estimate_size::EstimateSize;
930
931 use risingwave_expr::expr::Context;
932 use risingwave_expr::Result;
933 use risingwave_expr::aggregate::AggregateState;
934 use risingwave_expr::codegen::async_trait;
935
936 let context = Context {
937 return_type: agg.return_type.clone(),
938 arg_types: agg.args.arg_types().to_owned(),
939 variadic: false,
940 };
941
942 struct Agg {
943 context: Context,
944 #function_field
945 }
946
947 #[async_trait]
948 impl risingwave_expr::aggregate::AggregateFunction for Agg {
949 fn return_type(&self) -> DataType {
950 self.context.return_type.clone()
951 }
952
953 #create_state
954
955 async fn update(&self, state0: &mut AggregateState, input: &StreamChunk) -> Result<()> {
956 #(#let_arrays)*
957 #downcast_state
958 for row_id in input.visibility().iter_ones() {
959 let op = unsafe { *input.ops().get_unchecked(row_id) };
960 #(#let_values)*
961 #update_state
962 }
963 #restore_state
964 Ok(())
965 }
966
967 async fn update_range(&self, state0: &mut AggregateState, input: &StreamChunk, range: Range<usize>) -> Result<()> {
968 assert!(range.end <= input.capacity());
969 #(#let_arrays)*
970 #downcast_state
971 if input.is_compacted() {
972 for row_id in range {
973 let op = unsafe { *input.ops().get_unchecked(row_id) };
974 #(#let_values)*
975 #update_state
976 }
977 } else {
978 for row_id in input.visibility().iter_ones() {
979 if row_id < range.start {
980 continue;
981 } else if row_id >= range.end {
982 break;
983 }
984 let op = unsafe { *input.ops().get_unchecked(row_id) };
985 #(#let_values)*
986 #update_state
987 }
988 }
989 #restore_state
990 Ok(())
991 }
992
993 async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
994 #get_result
995 }
996 }
997
998 Ok(Box::new(Agg {
999 context,
1000 #function_new
1001 }))
1002 }
1003 })
1004 }
1005
1006 fn generate_table_function_descriptor(
1010 &self,
1011 user_fn: &UserFunctionAttr,
1012 build_fn: bool,
1013 ) -> Result<TokenStream2> {
1014 let name = self.name.clone();
1015 let mut args = Vec::with_capacity(self.args.len());
1016 for ty in &self.args {
1017 args.push(sig_data_type(ty));
1018 }
1019 let ret = sig_data_type(&self.ret);
1020
1021 let pb_type = format_ident!("{}", utils::to_camel_case(&name));
1022 let ctor_name = format_ident!("{}", self.ident_name());
1023 let build_fn = if build_fn {
1024 let name = format_ident!("{}", user_fn.name);
1025 quote! { #name }
1026 } else if self.rewritten {
1027 quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) }
1028 } else {
1029 self.generate_build_table_function(user_fn)?
1030 };
1031 let type_infer_fn = self.generate_type_infer_fn()?;
1032 let deprecated = self.deprecated;
1033
1034 Ok(quote! {
1035 #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
1036 fn #ctor_name() -> risingwave_expr::sig::FuncSign {
1037 use risingwave_common::types::{DataType, DataTypeName};
1038 use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
1039
1040 FuncSign {
1041 name: risingwave_pb::expr::table_function::Type::#pb_type.into(),
1042 inputs_type: vec![#(#args),*],
1043 variadic: false,
1044 ret_type: #ret,
1045 build: FuncBuilder::Table(#build_fn),
1046 type_infer: #type_infer_fn,
1047 deprecated: #deprecated,
1048 }
1049 }
1050 })
1051 }
1052
1053 fn generate_build_table_function(&self, user_fn: &UserFunctionAttr) -> Result<TokenStream2> {
1054 let num_args = self.args.len();
1055 let return_types = output_types(&self.ret);
1056 let fn_name = format_ident!("{}", user_fn.name);
1057 let struct_name = format_ident!("{}", utils::to_camel_case(&self.ident_name()));
1058 let arg_ids = (0..num_args)
1059 .filter(|i| match &self.prebuild {
1060 Some(s) => !s.contains(&format!("${i}")),
1061 None => true,
1062 })
1063 .collect_vec();
1064 let const_ids = (0..num_args).filter(|i| match &self.prebuild {
1065 Some(s) => s.contains(&format!("${i}")),
1066 None => false,
1067 });
1068 let inputs: Vec<_> = arg_ids.iter().map(|i| format_ident!("i{i}")).collect();
1069 let all_child: Vec<_> = (0..num_args).map(|i| format_ident!("child{i}")).collect();
1070 let const_child: Vec<_> = const_ids.map(|i| format_ident!("child{i}")).collect();
1071 let child: Vec<_> = arg_ids.iter().map(|i| format_ident!("child{i}")).collect();
1072 let array_refs: Vec<_> = arg_ids.iter().map(|i| format_ident!("array{i}")).collect();
1073 let arrays: Vec<_> = arg_ids.iter().map(|i| format_ident!("a{i}")).collect();
1074 let arg_arrays = arg_ids
1075 .iter()
1076 .map(|i| format_ident!("{}", types::array_type(&self.args[*i])));
1077 let outputs = (0..return_types.len())
1078 .map(|i| format_ident!("o{i}"))
1079 .collect_vec();
1080 let builders = (0..return_types.len())
1081 .map(|i| format_ident!("builder{i}"))
1082 .collect_vec();
1083 let builder_types = return_types
1084 .iter()
1085 .map(|ty| format_ident!("{}Builder", types::array_type(ty)))
1086 .collect_vec();
1087 let return_types = if return_types.len() == 1 {
1088 vec![quote! { self.context.return_type.clone() }]
1089 } else {
1090 (0..return_types.len())
1091 .map(|i| quote! { self.context.return_type.as_struct().types().nth(#i).unwrap().clone() })
1092 .collect()
1093 };
1094 #[allow(clippy::disallowed_methods)]
1095 let optioned_outputs = user_fn
1096 .core_return_type
1097 .split(',')
1098 .map(|t| t.contains("Option"))
1099 .zip(&outputs)
1101 .map(|(optional, o)| match optional {
1102 false => quote! { Some(#o.as_scalar_ref()) },
1103 true => quote! { #o.map(|o| o.as_scalar_ref()) },
1104 })
1105 .collect_vec();
1106 let build_value_array = if return_types.len() == 1 {
1107 quote! { let [value_array] = value_arrays; }
1108 } else {
1109 quote! {
1110 let value_array = StructArray::new(
1111 self.context.return_type.as_struct().clone(),
1112 value_arrays.to_vec(),
1113 Bitmap::ones(len),
1114 ).into_ref();
1115 }
1116 };
1117 let context = user_fn.context.then(|| quote! { &self.context, });
1118 let prebuilt_arg = match &self.prebuild {
1119 Some(_) => quote! { &self.prebuilt_arg, },
1120 None => quote! {},
1121 };
1122 let prebuilt_arg_type = match &self.prebuild {
1123 Some(s) => s.split("::").next().unwrap().parse().unwrap(),
1124 None => quote! { () },
1125 };
1126 let prebuilt_arg_value = match &self.prebuild {
1127 Some(s) => s
1128 .replace('$', "child")
1129 .parse()
1130 .expect("invalid prebuild syntax"),
1131 None => quote! { () },
1132 };
1133 let iter = quote! { #fn_name(#(#inputs,)* #prebuilt_arg #context) };
1134 let mut iter = match user_fn.return_type_kind {
1135 ReturnTypeKind::T => quote! { #iter },
1136 ReturnTypeKind::Option => quote! { match #iter {
1137 Some(it) => it,
1138 None => continue,
1139 } },
1140 ReturnTypeKind::Result => quote! { match #iter {
1141 Ok(it) => it,
1142 Err(e) => {
1143 index_builder.append(Some(i as i32));
1144 #(#builders.append_null();)*
1145 error_builder.append_display(Some(e.as_report()));
1146 continue;
1147 }
1148 } },
1149 ReturnTypeKind::ResultOption => quote! { match #iter {
1150 Ok(Some(it)) => it,
1151 Ok(None) => continue,
1152 Err(e) => {
1153 index_builder.append(Some(i as i32));
1154 #(#builders.append_null();)*
1155 error_builder.append_display(Some(e.as_report()));
1156 continue;
1157 }
1158 } },
1159 };
1160 #[allow(clippy::disallowed_methods)] let some_inputs = inputs
1164 .iter()
1165 .zip(user_fn.args_option.iter())
1166 .map(|(input, opt)| {
1167 if *opt {
1168 quote! { #input }
1169 } else {
1170 quote! { Some(#input) }
1171 }
1172 });
1173 iter = quote! {
1174 match (#(#inputs,)*) {
1175 (#(#some_inputs,)*) => #iter,
1176 _ => continue,
1177 }
1178 };
1179 let iterator_item_type = user_fn.iterator_item_kind.clone().ok_or_else(|| {
1180 Error::new(
1181 user_fn.return_type_span,
1182 "expect `impl Iterator` in return type",
1183 )
1184 })?;
1185 let append_output = match iterator_item_type {
1186 ReturnTypeKind::T => quote! {
1187 let (#(#outputs),*) = output;
1188 #(#builders.append(#optioned_outputs);)* error_builder.append_null();
1189 },
1190 ReturnTypeKind::Option => quote! { match output {
1191 Some((#(#outputs),*)) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1192 None => { #(#builders.append_null();)* error_builder.append_null(); }
1193 } },
1194 ReturnTypeKind::Result => quote! { match output {
1195 Ok((#(#outputs),*)) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1196 Err(e) => { #(#builders.append_null();)* error_builder.append_display(Some(e.as_report())); }
1197 } },
1198 ReturnTypeKind::ResultOption => quote! { match output {
1199 Ok(Some((#(#outputs),*))) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1200 Ok(None) => { #(#builders.append_null();)* error_builder.append_null(); }
1201 Err(e) => { #(#builders.append_null();)* error_builder.append_display(Some(e.as_report())); }
1202 } },
1203 };
1204
1205 Ok(quote! {
1206 |return_type, chunk_size, children| {
1207 use risingwave_common::array::*;
1208 use risingwave_common::types::*;
1209 use risingwave_common::bitmap::Bitmap;
1210 use risingwave_common::util::iter_util::ZipEqFast;
1211 use risingwave_expr::expr::{BoxedExpression, Context};
1212 use risingwave_expr::{Result, ExprError};
1213 use risingwave_expr::codegen::*;
1214
1215 risingwave_expr::ensure!(children.len() == #num_args);
1216
1217 let context = Context {
1218 return_type: return_type.clone(),
1219 arg_types: children.iter().map(|c| c.return_type()).collect(),
1220 variadic: false,
1221 };
1222
1223 let mut iter = children.into_iter();
1224 #(let #all_child = iter.next().unwrap();)*
1225 #(
1226 let #const_child = #const_child.eval_const()?;
1227 let #const_child = match &#const_child {
1228 Some(s) => s.as_scalar_ref_impl().try_into()?,
1229 None => return Ok(risingwave_expr::table_function::empty(return_type)),
1231 };
1232 )*
1233
1234 #[derive(Debug)]
1235 struct #struct_name {
1236 context: Context,
1237 chunk_size: usize,
1238 #(#child: BoxedExpression,)*
1239 prebuilt_arg: #prebuilt_arg_type,
1240 }
1241 #[async_trait]
1242 impl risingwave_expr::table_function::TableFunction for #struct_name {
1243 fn return_type(&self) -> DataType {
1244 self.context.return_type.clone()
1245 }
1246 async fn eval<'a>(&'a self, input: &'a DataChunk) -> BoxStream<'a, Result<DataChunk>> {
1247 self.eval_inner(input)
1248 }
1249 }
1250 impl #struct_name {
1251 #[try_stream(boxed, ok = DataChunk, error = ExprError)]
1252 async fn eval_inner<'a>(&'a self, input: &'a DataChunk) {
1253 #(
1254 let #array_refs = self.#child.eval(input).await?;
1255 let #arrays: &#arg_arrays = #array_refs.as_ref().into();
1256 )*
1257
1258 let mut index_builder = I32ArrayBuilder::new(self.chunk_size);
1259 #(let mut #builders = #builder_types::with_type(self.chunk_size, #return_types);)*
1260 let mut error_builder = Utf8ArrayBuilder::new(self.chunk_size);
1261
1262 for i in 0..input.capacity() {
1263 if unsafe { !input.visibility().is_set_unchecked(i) } {
1264 continue;
1265 }
1266 #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
1267 for output in #iter {
1268 index_builder.append(Some(i as i32));
1269 #append_output
1270
1271 if index_builder.len() == self.chunk_size {
1272 let len = index_builder.len();
1273 let index_array = std::mem::replace(&mut index_builder, I32ArrayBuilder::new(self.chunk_size)).finish().into_ref();
1274 let value_arrays = [#(std::mem::replace(&mut #builders, #builder_types::with_type(self.chunk_size, #return_types)).finish().into_ref()),*];
1275 #build_value_array
1276 let error_array = std::mem::replace(&mut error_builder, Utf8ArrayBuilder::new(self.chunk_size)).finish().into_ref();
1277 if error_array.null_bitmap().any() {
1278 yield DataChunk::new(vec![index_array, value_array, error_array], self.chunk_size);
1279 } else {
1280 yield DataChunk::new(vec![index_array, value_array], self.chunk_size);
1281 }
1282 }
1283 }
1284 }
1285
1286 if index_builder.len() > 0 {
1287 let len = index_builder.len();
1288 let index_array = index_builder.finish().into_ref();
1289 let value_arrays = [#(#builders.finish().into_ref()),*];
1290 #build_value_array
1291 let error_array = error_builder.finish().into_ref();
1292 if error_array.null_bitmap().any() {
1293 yield DataChunk::new(vec![index_array, value_array, error_array], len);
1294 } else {
1295 yield DataChunk::new(vec![index_array, value_array], len);
1296 }
1297 }
1298 }
1299 }
1300
1301 Ok(Box::new(#struct_name {
1302 context,
1303 chunk_size,
1304 #(#child,)*
1305 prebuilt_arg: #prebuilt_arg_value,
1306 }))
1307 }
1308 })
1309 }
1310}
1311
1312fn sig_data_type(ty: &str) -> TokenStream2 {
1313 match ty {
1314 "any" => quote! { SigDataType::Any },
1315 "anyarray" => quote! { SigDataType::AnyArray },
1316 "anymap" => quote! { SigDataType::AnyMap },
1317 "vector" => quote! { SigDataType::Vector },
1318 "struct" => quote! { SigDataType::AnyStruct },
1319 _ if ty.starts_with("struct") && ty.contains("any") => quote! { SigDataType::AnyStruct },
1320 _ => {
1321 let datatype = data_type(ty);
1322 quote! { SigDataType::Exact(#datatype) }
1323 }
1324 }
1325}
1326
1327fn data_type(ty: &str) -> TokenStream2 {
1328 if let Some(ty) = ty.strip_suffix("[]") {
1329 let inner_type = data_type(ty);
1330 return quote! { DataType::List(Box::new(#inner_type)) };
1331 }
1332 if ty.starts_with("struct<") {
1333 return quote! { DataType::Struct(#ty.parse().expect("invalid struct type")) };
1334 }
1335 let variant = format_ident!("{}", types::data_type(ty));
1336 quote! { DataType::#variant }
1343}
1344
/// Extracts the output types of a function return type.
///
/// For `struct<name1 type1, name2 type2, ...>` this returns the field types in declaration
/// order. Any other type is returned unchanged as the single element.
///
/// ```ignore
/// output_types("int4") -> ["int4"]
/// output_types("struct<key varchar, value jsonb>") -> ["varchar", "jsonb"]
/// output_types("struct<a struct<b int4, c int4>, d int4>") -> ["struct<b int4, c int4>", "int4"]
/// ```
///
/// # Panics
///
/// Panics if a struct field is not of the form `name type`.
fn output_types(ty: &str) -> Vec<&str> {
    /// Splits `s` at commas that are not nested inside `<...>`.
    fn split_top_level_commas(s: &str) -> Vec<&str> {
        let mut parts = Vec::new();
        let mut depth = 0usize;
        let mut start = 0;
        for (i, c) in s.char_indices() {
            match c {
                '<' => depth += 1,
                '>' => depth = depth.saturating_sub(1),
                ',' if depth == 0 => {
                    parts.push(&s[start..i]);
                    // `,` is a single byte, so this stays on a char boundary.
                    start = i + 1;
                }
                _ => {}
            }
        }
        parts.push(&s[start..]);
        parts
    }

    let Some(fields) = ty
        .strip_prefix("struct<")
        .and_then(|s| s.strip_suffix('>'))
    else {
        return vec![ty];
    };
    split_top_level_commas(fields)
        .into_iter()
        .map(|field| {
            let field = field.trim();
            // The first token is the field name; everything after it is the (possibly nested)
            // field type, e.g. `a struct<b int4, c int4>` -> `struct<b int4, c int4>`.
            match field.split_once(char::is_whitespace) {
                Some((_, field_ty)) => field_ty.trim(),
                None => panic!("invalid struct field `{field}` in `{ty}`, expected `name type`"),
            }
        })
        .collect()
}