1use itertools::Itertools;
18use proc_macro2::{Ident, Span};
19use quote::{format_ident, quote};
20
21use super::*;
22
23impl FunctionAttr {
24 pub fn expand(&self) -> Vec<Self> {
26 if self
28 .args
29 .last()
30 .is_some_and(|arg| arg.starts_with("variadic"))
31 {
32 let mut attrs = Vec::new();
36 attrs.extend(
37 FunctionAttr {
38 args: {
39 let mut args = self.args.clone();
40 *args.last_mut().unwrap() = "...".to_owned();
41 args
42 },
43 ..self.clone()
44 }
45 .expand(),
46 );
47 attrs.extend(
48 FunctionAttr {
49 name: format!("{}_variadic", self.name),
50 args: {
51 let mut args = self.args.clone();
52 let last = args.last_mut().unwrap();
53 *last = last.strip_prefix("variadic ").unwrap().into();
54 args
55 },
56 ..self.clone()
57 }
58 .expand(),
59 );
60 return attrs;
61 }
62 let args = self.args.iter().map(|ty| types::expand_type_wildcard(ty));
63 let ret = types::expand_type_wildcard(&self.ret);
64 let mut attrs = Vec::new();
65 for (args, mut ret) in args.multi_cartesian_product().cartesian_product(ret) {
66 if ret == "auto" {
67 ret = types::min_compatible_type(&args);
68 }
69 let attr = FunctionAttr {
70 args: args.iter().map(|s| s.to_string()).collect(),
71 ret: ret.to_owned(),
72 ..self.clone()
73 };
74 attrs.push(attr);
75 }
76 attrs
77 }
78
79 fn generate_type_infer_fn(&self) -> Result<TokenStream2> {
81 if let Some(func) = &self.type_infer {
82 if func == "unreachable" {
83 return Ok(
84 quote! { |_| unreachable!("type inference for this function should be specially handled in frontend, and should not call sig.type_infer") },
85 );
86 }
87 return Ok(func.parse().unwrap());
89 } else if self.ret == "any" {
90 if let Some(i) = self.args.iter().position(|t| t == "any") {
92 return Ok(quote! { |args| Ok(args[#i].clone()) });
94 }
95 if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
96 return Ok(quote! { |args| Ok(args[#i].as_list().clone()) });
98 }
99 } else if self.ret == "anyarray" {
100 if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
101 return Ok(quote! { |args| Ok(args[#i].clone()) });
103 }
104 if let Some(i) = self.args.iter().position(|t| t == "any") {
105 return Ok(quote! { |args| Ok(DataType::List(Box::new(args[#i].clone()))) });
107 }
108 } else if self.ret == "struct" {
109 if let Some(i) = self.args.iter().position(|t| t == "struct") {
110 return Ok(quote! { |args| Ok(args[#i].clone()) });
112 }
113 } else if self.ret == "anymap" {
114 if let Some(i) = self.args.iter().position(|t| t == "anymap") {
115 return Ok(quote! { |args| Ok(args[#i].clone()) });
117 }
118 } else {
119 let ty = data_type(&self.ret);
121 return Ok(quote! { |_| Ok(#ty) });
122 }
123 Err(Error::new(
124 Span::call_site(),
125 "type inference function cannot be automatically derived. You should provide: `type_infer = \"|args| Ok(...)\"`",
126 ))
127 }
128
129 pub fn generate_function_descriptor(
137 &self,
138 user_fn: &UserFunctionAttr,
139 build_fn: bool,
140 ) -> Result<TokenStream2> {
141 if self.is_table_function {
142 return self.generate_table_function_descriptor(user_fn, build_fn);
143 }
144 let name = self.name.clone();
145 let variadic = matches!(self.args.last(), Some(t) if t == "...");
146 let args = match variadic {
147 true => &self.args[..self.args.len() - 1],
148 false => &self.args[..],
149 }
150 .iter()
151 .map(|ty| sig_data_type(ty))
152 .collect_vec();
153 let ret = sig_data_type(&self.ret);
154
155 let pb_type = format_ident!("{}", utils::to_camel_case(&name));
156 let ctor_name = format_ident!("{}", self.ident_name());
157 let build_fn = if build_fn {
158 let name = format_ident!("{}", user_fn.name);
159 quote! { #name }
160 } else if self.rewritten {
161 quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) }
162 } else {
163 self.generate_build_scalar_function(user_fn, true)?
165 };
166 let type_infer_fn = self.generate_type_infer_fn()?;
167 let deprecated = self.deprecated;
168
169 Ok(quote! {
170 #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
171 fn #ctor_name() -> risingwave_expr::sig::FuncSign {
172 use risingwave_common::types::{DataType, DataTypeName};
173 use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
174
175 FuncSign {
176 name: risingwave_pb::expr::expr_node::Type::#pb_type.into(),
177 inputs_type: vec![#(#args),*],
178 variadic: #variadic,
179 ret_type: #ret,
180 build: FuncBuilder::Scalar(#build_fn),
181 type_infer: #type_infer_fn,
182 deprecated: #deprecated,
183 }
184 }
185 })
186 }
187
188 fn generate_build_scalar_function(
193 &self,
194 user_fn: &UserFunctionAttr,
195 optimize_const: bool,
196 ) -> Result<TokenStream2> {
197 let variadic = matches!(self.args.last(), Some(t) if t == "...");
198 let num_args = self.args.len() - if variadic { 1 } else { 0 };
199 let fn_name = format_ident!("{}", user_fn.name);
200 let struct_name = match optimize_const {
201 true => format_ident!("{}OptimizeConst", utils::to_camel_case(&self.ident_name())),
202 false => format_ident!("{}", utils::to_camel_case(&self.ident_name())),
203 };
204
205 let prebuilt_indices = match &self.prebuild {
220 Some(s) => (0..num_args)
221 .filter(|i| s.contains(&format!("${i}")))
222 .collect_vec(),
223 None => vec![],
224 };
225 let non_prebuilt_indices = match &self.prebuild {
226 Some(s) => (0..num_args)
227 .filter(|i| !s.contains(&format!("${i}")))
228 .collect_vec(),
229 _ => (0..num_args).collect_vec(),
230 };
231 let children_indices = match optimize_const {
232 #[allow(clippy::redundant_clone)] true => non_prebuilt_indices.clone(),
234 false => (0..num_args).collect_vec(),
235 };
236
237 fn idents(prefix: &str, indices: &[usize]) -> Vec<Ident> {
239 indices
240 .iter()
241 .map(|i| format_ident!("{prefix}{i}"))
242 .collect()
243 }
244 let inputs = idents("i", &children_indices);
245 let prebuilt_inputs = idents("i", &prebuilt_indices);
246 let non_prebuilt_inputs = idents("i", &non_prebuilt_indices);
247 let array_refs = idents("array", &children_indices);
248 let arrays = idents("a", &children_indices);
249 let datums = idents("v", &children_indices);
250 let arg_arrays = children_indices
251 .iter()
252 .map(|i| format_ident!("{}", types::array_type(&self.args[*i])));
253 let arg_types = children_indices.iter().map(|i| {
254 types::ref_type(&self.args[*i])
255 .parse::<TokenStream2>()
256 .unwrap()
257 });
258 let annotation: TokenStream2 = match user_fn.core_return_type.as_str() {
259 "T" | "T1" | "T2" | "T3" => format!(": Option<{}>", types::owned_type(&self.ret))
261 .parse()
262 .unwrap(),
263 _ => quote! {},
264 };
265 let ret_array_type = format_ident!("{}", types::array_type(&self.ret));
266 let builder_type = format_ident!("{}Builder", types::array_type(&self.ret));
267 let prebuilt_arg_type = match &self.prebuild {
268 Some(s) if optimize_const => s.split("::").next().unwrap().parse().unwrap(),
269 _ => quote! { () },
270 };
271 let prebuilt_arg_value = match &self.prebuild {
272 Some(s) => s
276 .replace('$', "i")
277 .parse()
278 .expect("invalid prebuild syntax"),
279 None => quote! { () },
280 };
281 let prebuild_const = if self.prebuild.is_some() && optimize_const {
282 let build_general = self.generate_build_scalar_function(user_fn, false)?;
283 quote! {{
284 let build_general = #build_general;
285 #(
286 let #prebuilt_inputs = match children[#prebuilt_indices].eval_const() {
288 Ok(s) => s,
289 Err(_) => return build_general(return_type, children),
291 };
292 let #prebuilt_inputs = match &#prebuilt_inputs {
294 Some(s) => s.as_scalar_ref_impl().try_into()?,
295 None => return Ok(Box::new(risingwave_expr::expr::LiteralExpression::new(
297 return_type,
298 None,
299 ))),
300 };
301 )*
302 #prebuilt_arg_value
303 }}
304 } else {
305 quote! { () }
306 };
307
308 let check_children = match variadic {
310 true => quote! { risingwave_expr::ensure!(children.len() >= #num_args); },
311 false => quote! { risingwave_expr::ensure!(children.len() == #num_args); },
312 };
313
314 let eval_variadic = variadic.then(|| {
316 quote! {
317 let mut columns = Vec::with_capacity(self.children.len() - #num_args);
318 for child in &self.children[#num_args..] {
319 columns.push(child.eval(input).await?);
320 }
321 let variadic_input = DataChunk::new(columns, input.visibility().clone());
322 }
323 });
324 let eval_row_variadic = variadic.then(|| {
326 quote! {
327 let mut row = Vec::with_capacity(self.children.len() - #num_args);
328 for child in &self.children[#num_args..] {
329 row.push(child.eval_row(input).await?);
330 }
331 let variadic_row = OwnedRow::new(row);
332 }
333 });
334
335 let generic = (self.ret == "boolean" && user_fn.generic == 3).then(|| {
336 let compatible_type = types::ref_type(types::min_compatible_type(&self.args))
338 .parse::<TokenStream2>()
339 .unwrap();
340 quote! { ::<_, _, #compatible_type> }
341 });
342 let prebuilt_arg = match (&self.prebuild, optimize_const) {
343 (Some(_), true) => quote! { &self.prebuilt_arg, },
345 (Some(_), false) => quote! { &#prebuilt_arg_value, },
347 (None, _) => quote! {},
349 };
350 let variadic_args = variadic.then(|| quote! { &variadic_row, });
351 let context = user_fn.context.then(|| quote! { &self.context, });
352 let writer = user_fn.write.then(|| quote! { &mut writer, });
353 let await_ = user_fn.async_.then(|| quote! { .await });
354
355 let record_error = {
356 #[allow(clippy::disallowed_methods)] let inputs_args = inputs
359 .iter()
360 .zip(user_fn.args_option.iter())
361 .map(|(input, opt)| {
362 if *opt {
363 quote! { #input.map(|s| ScalarRefImpl::from(s)) }
364 } else {
365 quote! { Some(ScalarRefImpl::from(#input)) }
366 }
367 });
368 let inputs_args = quote! {
369 let args: &[DatumRef<'_>] = &[#(#inputs_args),*];
370 let args = args.iter().copied();
371 };
372 let var_args = variadic.then(|| {
373 quote! {
374 let args = args.chain(variadic_row.iter());
375 }
376 });
377
378 quote! {
379 #inputs_args
380 #var_args
381 errors.push(ExprError::function(
382 stringify!(#fn_name),
383 args,
384 e,
385 ));
386 }
387 };
388
389 let mut output = quote! { #fn_name #generic(
392 #(#non_prebuilt_inputs,)*
393 #variadic_args
394 #prebuilt_arg
395 #context
396 #writer
397 ) #await_ };
398 output = match user_fn.return_type_kind {
401 _ if self.ret == "void" => quote! { { #output; Option::<i32>::None } },
403 ReturnTypeKind::T => quote! { Some(#output) },
404 ReturnTypeKind::Option => output,
405 ReturnTypeKind::Result => quote! {
406 match #output {
407 Ok(x) => Some(x),
408 Err(e) => {
409 #record_error
410 None
411 }
412 }
413 },
414 ReturnTypeKind::ResultOption => quote! {
415 match #output {
416 Ok(x) => x,
417 Err(e) => {
418 #record_error
419 None
420 }
421 }
422 },
423 };
424 if self.prebuild.is_some() {
427 output = quote! {
428 match (#(#inputs,)*) {
429 (#(Some(#inputs),)*) => #output,
430 _ => None,
431 }
432 };
433 } else {
434 #[allow(clippy::disallowed_methods)] let some_inputs = inputs
436 .iter()
437 .zip(user_fn.args_option.iter())
438 .map(|(input, opt)| {
439 if *opt {
440 quote! { #input }
441 } else {
442 quote! { Some(#input) }
443 }
444 });
445 output = quote! {
446 match (#(#inputs,)*) {
447 (#(#some_inputs,)*) => #output,
448 _ => None,
449 }
450 };
451 };
452 let append_output = match user_fn.write {
454 true => quote! {{
455 let mut writer = builder.writer().begin();
456 if #output.is_some() {
457 writer.finish();
458 } else {
459 drop(writer);
460 builder.append_null();
461 }
462 }},
463 false if user_fn.core_return_type == "impl AsRef < [u8] >" => quote! {
464 builder.append(#output.as_ref().map(|s| s.as_ref()));
465 },
466 false => quote! {
467 let output #annotation = #output;
468 builder.append(output.as_ref().map(|s| s.as_scalar_ref()));
469 },
470 };
471 let row_output = match user_fn.write {
473 true => quote! {{
474 let mut writer = String::new();
475 #output.map(|_| writer.into())
476 }},
477 false if user_fn.core_return_type == "impl AsRef < [u8] >" => quote! {
478 #output.map(|s| s.as_ref().into())
479 },
480 false => quote! {{
481 let output #annotation = #output;
482 output.map(|s| s.into())
483 }},
484 };
485 let eval = if let Some(batch_fn) = &self.batch_fn {
487 assert!(
488 !variadic,
489 "customized batch function is not supported for variadic functions"
490 );
491 let fn_name = format_ident!("{}", batch_fn);
493 quote! {
494 let c = #fn_name(#(#arrays),*);
495 Arc::new(c.into())
496 }
497 } else if (types::is_primitive(&self.ret) || self.ret == "boolean")
498 && user_fn.is_pure()
499 && !variadic
500 && self.prebuild.is_none()
501 {
502 match self.args.len() {
504 0 => quote! {
505 let c = #ret_array_type::from_iter_bitmap(
506 std::iter::repeat_with(|| #fn_name()).take(input.capacity())
507 Bitmap::ones(input.capacity()),
508 );
509 Arc::new(c.into())
510 },
511 1 => quote! {
512 let c = #ret_array_type::from_iter_bitmap(
513 a0.raw_iter().map(|a| #fn_name(a)),
514 a0.null_bitmap().clone()
515 );
516 Arc::new(c.into())
517 },
518 2 => quote! {
519 #[allow(clippy::disallowed_methods)]
521 let c = #ret_array_type::from_iter_bitmap(
522 a0.raw_iter()
523 .zip(a1.raw_iter())
524 .map(|(a, b)| #fn_name #generic(a, b)),
525 a0.null_bitmap() & a1.null_bitmap(),
526 );
527 Arc::new(c.into())
528 },
529 n => todo!("SIMD optimization for {n} arguments"),
530 }
531 } else {
532 let let_variadic = variadic.then(|| {
534 quote! {
535 let variadic_row = variadic_input.row_at_unchecked_vis(i);
536 }
537 });
538 quote! {
539 let mut builder = #builder_type::with_type(input.capacity(), self.context.return_type.clone());
540
541 if input.is_compacted() {
542 for i in 0..input.capacity() {
543 #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
544 #let_variadic
545 #append_output
546 }
547 } else {
548 for i in 0..input.capacity() {
549 if unsafe { !input.visibility().is_set_unchecked(i) } {
550 builder.append_null();
551 continue;
552 }
553 #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
554 #let_variadic
555 #append_output
556 }
557 }
558 Arc::new(builder.finish().into())
559 }
560 };
561
562 Ok(quote! {
563 |return_type: DataType, children: Vec<risingwave_expr::expr::BoxedExpression>|
564 -> risingwave_expr::Result<risingwave_expr::expr::BoxedExpression>
565 {
566 use std::sync::Arc;
567 use risingwave_common::array::*;
568 use risingwave_common::types::*;
569 use risingwave_common::bitmap::Bitmap;
570 use risingwave_common::row::OwnedRow;
571 use risingwave_common::util::iter_util::ZipEqFast;
572
573 use risingwave_expr::expr::{Context, BoxedExpression};
574 use risingwave_expr::{ExprError, Result};
575 use risingwave_expr::codegen::*;
576
577 #check_children
578 let prebuilt_arg = #prebuild_const;
579 let context = Context {
580 return_type,
581 arg_types: children.iter().map(|c| c.return_type()).collect(),
582 variadic: #variadic,
583 };
584
585 #[derive(Debug)]
586 struct #struct_name {
587 context: Context,
588 children: Vec<BoxedExpression>,
589 prebuilt_arg: #prebuilt_arg_type,
590 }
591 #[async_trait]
592 impl risingwave_expr::expr::Expression for #struct_name {
593 fn return_type(&self) -> DataType {
594 self.context.return_type.clone()
595 }
596 async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
597 #(
598 let #array_refs = self.children[#children_indices].eval(input).await?;
599 let #arrays: &#arg_arrays = #array_refs.as_ref().into();
600 )*
601 #eval_variadic
602 let mut errors = vec![];
603 let array = { #eval };
604 if errors.is_empty() {
605 Ok(array)
606 } else {
607 Err(ExprError::Multiple(array, errors.into()))
608 }
609 }
610 async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
611 #(
612 let #datums = self.children[#children_indices].eval_row(input).await?;
613 let #inputs: Option<#arg_types> = #datums.as_ref().map(|s| s.as_scalar_ref_impl().try_into().unwrap());
614 )*
615 #eval_row_variadic
616 let mut errors: Vec<ExprError> = vec![];
617 let output = #row_output;
618 if let Some(err) = errors.into_iter().next() {
619 Err(err.into())
620 } else {
621 Ok(output)
622 }
623 }
624 }
625
626 Ok(Box::new(#struct_name {
627 context,
628 children,
629 prebuilt_arg,
630 }))
631 }
632 })
633 }
634
635 pub fn generate_aggregate_descriptor(
641 &self,
642 user_fn: &AggregateFnOrImpl,
643 build_fn: bool,
644 ) -> Result<TokenStream2> {
645 let name = self.name.clone();
646
647 let mut args = Vec::with_capacity(self.args.len());
648 for ty in &self.args {
649 args.push(sig_data_type(ty));
650 }
651 let ret = sig_data_type(&self.ret);
652 let state_type = match &self.state {
653 Some(ty) if ty != "ref" => {
654 let ty = data_type(ty);
655 quote! { Some(#ty) }
656 }
657 _ => quote! { None },
658 };
659 let append_only = match build_fn {
660 false => !user_fn.has_retract(),
661 true => self.append_only,
662 };
663
664 let pb_kind = format_ident!("{}", utils::to_camel_case(&name));
665 let ctor_name = match append_only {
666 false => format_ident!("{}", self.ident_name()),
667 true => format_ident!("{}_append_only", self.ident_name()),
668 };
669 let build_fn = if build_fn {
670 let name = format_ident!("{}", user_fn.as_fn().name);
671 quote! { #name }
672 } else if self.rewritten {
673 quote! { |_| Err(ExprError::UnsupportedFunction(#name.into())) }
674 } else {
675 self.generate_agg_build_fn(user_fn)?
676 };
677 let build_retractable = match append_only {
678 true => quote! { None },
679 false => quote! { Some(#build_fn) },
680 };
681 let build_append_only = match append_only {
682 false => quote! { None },
683 true => quote! { Some(#build_fn) },
684 };
685 let retractable_state_type = match append_only {
686 true => quote! { None },
687 false => state_type.clone(),
688 };
689 let append_only_state_type = match append_only {
690 false => quote! { None },
691 true => state_type,
692 };
693 let type_infer_fn = self.generate_type_infer_fn()?;
694 let deprecated = self.deprecated;
695
696 Ok(quote! {
697 #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
698 fn #ctor_name() -> risingwave_expr::sig::FuncSign {
699 use risingwave_common::types::{DataType, DataTypeName};
700 use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
701
702 FuncSign {
703 name: risingwave_pb::expr::agg_call::PbKind::#pb_kind.into(),
704 inputs_type: vec![#(#args),*],
705 variadic: false,
706 ret_type: #ret,
707 build: FuncBuilder::Aggregate {
708 retractable: #build_retractable,
709 append_only: #build_append_only,
710 retractable_state_type: #retractable_state_type,
711 append_only_state_type: #append_only_state_type,
712 },
713 type_infer: #type_infer_fn,
714 deprecated: #deprecated,
715 }
716 }
717 })
718 }
719
    /// Generates the `build` closure for an aggregate function.
    ///
    /// The returned tokens form a closure `|agg| { ... }`. It defines a local struct `Agg`
    /// implementing `risingwave_expr::aggregate::AggregateFunction` and returns it boxed.
    /// The implementation provides `update`, `update_range` and `get_result`, plus
    /// `create_state` when a custom state or `init_state` is present. Each of these calls
    /// `user_fn` (either a plain function or an impl struct stored in the `function` field).
    ///
    /// State representation, as chosen below:
    /// * custom state: the first `&mut` argument type of `accumulate`, stored behind
    ///   `AggregateState` and accessed with `downcast_mut`/`downcast_ref`;
    /// * `state = "ref"`: a reference type of the return type, read from and written back
    ///   to the datum of `AggregateState`;
    /// * otherwise: an owned type (of `state` if given, else of the return type) taken out of
    ///   the datum and written back after the update loop.
    ///
    /// # Errors
    /// This function currently always returns `Ok`. Invalid type strings panic via `unwrap`,
    /// and non-option functions with more than one argument hit `todo!`.
    fn generate_agg_build_fn(&self, user_fn: &AggregateFnOrImpl) -> Result<TokenStream2> {
        // If the first argument of accumulate function is a `&mut` reference, it's a custom state.
        let custom_state = user_fn.accumulate().first_mut_ref_arg.as_ref();
        // Type of the `state` variable inside the generated update loop.
        let state_type: TokenStream2 = match (custom_state, &self.state) {
            (Some(s), _) => s.parse().unwrap(),
            (_, Some(state)) if state == "ref" => types::ref_type(&self.ret).parse().unwrap(),
            (_, Some(state)) if state != "ref" => types::owned_type(state).parse().unwrap(),
            _ => types::owned_type(&self.ret).parse().unwrap(),
        };
        // `let aN: &ArrayType = input.column_at(N).as_ref().into();` for every argument.
        let let_arrays = self
            .args
            .iter()
            .enumerate()
            .map(|(i, arg)| {
                let array = format_ident!("a{i}");
                let array_type: TokenStream2 = types::array_type(arg).parse().unwrap();
                quote! {
                    let #array: &#array_type = input.column_at(#i).as_ref().into();
                }
            })
            .collect_vec();
        // `let vN = aN.value_at_unchecked(row_id);` for every argument.
        let let_values = (0..self.args.len())
            .map(|i| {
                let v = format_ident!("v{i}");
                let a = format_ident!("a{i}");
                quote! { let #v = unsafe { #a.value_at_unchecked(row_id) }; }
            })
            .collect_vec();
        // Load the state out of `state0` before the update loop.
        let downcast_state = if custom_state.is_some() {
            quote! { let mut state: &mut #state_type = state0.downcast_mut(); }
        } else if let Some(s) = &self.state
            && s == "ref"
        {
            quote! { let mut state: Option<#state_type> = state0.as_datum_mut().as_ref().map(|x| x.as_scalar_ref_impl().try_into().unwrap()); }
        } else {
            quote! { let mut state: Option<#state_type> = state0.as_datum_mut().take().map(|s| s.try_into().unwrap()); }
        };
        // Write the state back into `state0` after the update loop. A custom state is
        // mutated in place, so nothing needs to be written back.
        let restore_state = if custom_state.is_some() {
            quote! {}
        } else if let Some(s) = &self.state
            && s == "ref"
        {
            quote! { *state0.as_datum_mut() = state.map(|x| x.to_owned_scalar().into()); }
        } else {
            quote! { *state0.as_datum_mut() = state.map(|s| s.into()); }
        };
        // Override `create_state` only for custom state or an explicit `init_state`.
        let create_state = if custom_state.is_some() {
            quote! {
                fn create_state(&self) -> Result<AggregateState> {
                    Ok(AggregateState::Any(Box::<#state_type>::default()))
                }
            }
        } else if let Some(state) = &self.init_state {
            let state: TokenStream2 = state.parse().unwrap();
            quote! {
                fn create_state(&self) -> Result<AggregateState> {
                    Ok(AggregateState::Datum(Some(#state.into())))
                }
            }
        } else {
            // by default: `AggregateState::Datum(None)`
            quote! {}
        };
        // `v0, v1, ...,` passed after `state` to the user function.
        let args = (0..self.args.len()).map(|i| format_ident!("v{i}"));
        let args = quote! { #(#args,)* };
        // Emitted where the user function cannot retract: asserts the row is an insert.
        let panic_on_retract = {
            let msg = format!(
                "attempt to retract on aggregate function {}, but it is append-only",
                self.name
            );
            quote! { assert_eq!(op, Op::Insert, #msg); }
        };
        // Expression computing the next state for one row.
        let mut next_state = match user_fn {
            AggregateFnOrImpl::Fn(f) => {
                let context = f.context.then(|| quote! { &self.context, });
                let fn_name = format_ident!("{}", f.name);
                match f.retract {
                    // The function takes a `retract: bool` flag after its value arguments.
                    true => {
                        quote! { #fn_name(state, #args matches!(op, Op::Delete | Op::UpdateDelete) #context) }
                    }
                    false => quote! {{
                        #panic_on_retract
                        #fn_name(state, #args #context)
                    }},
                }
            }
            AggregateFnOrImpl::Impl(i) => {
                let retract = match i.retract {
                    Some(_) => quote! { self.function.retract(state, #args) },
                    None => panic_on_retract,
                };
                quote! {
                    if matches!(op, Op::Delete | Op::UpdateDelete) {
                        #retract
                    } else {
                        self.function.accumulate(state, #args)
                    }
                }
            }
        };
        // Normalize the accumulate return kind to `Option<state>`; errors propagate with `?`.
        next_state = match user_fn.accumulate().return_type_kind {
            ReturnTypeKind::T => quote! { Some(#next_state) },
            ReturnTypeKind::Option => next_state,
            ReturnTypeKind::Result => quote! { Some(#next_state?) },
            ReturnTypeKind::ResultOption => quote! { #next_state? },
        };
        // When no accumulate parameter is `Option`, NULL state/inputs are handled here
        // instead of inside the user function.
        if user_fn.accumulate().args_option.iter().all(|b| !b) {
            match self.args.len() {
                0 => {
                    // A NULL state stays NULL.
                    next_state = quote! {
                        match state {
                            Some(state) => #next_state,
                            None => state,
                        }
                    };
                }
                1 => {
                    // Value used when the state is NULL and the input is not.
                    let first_state = if self.init_state.is_some() {
                        // for count, the state will never be None
                        quote! { unreachable!() }
                    } else if let Some(s) = &self.state
                        && s == "ref"
                    {
                        // for min/max/first/last, the state is the first value
                        quote! { Some(v0) }
                    } else if let AggregateFnOrImpl::Impl(impl_) = user_fn
                        && impl_.create_state.is_some()
                    {
                        // use user-defined create_state function
                        quote! {{
                            let state = self.function.create_state();
                            #next_state
                        }}
                    } else {
                        quote! {{
                            let state = #state_type::default();
                            #next_state
                        }}
                    };
                    // A NULL input leaves the state unchanged.
                    next_state = quote! {
                        match (state, v0) {
                            (Some(state), Some(v0)) => #next_state,
                            (None, Some(v0)) => #first_state,
                            (state, None) => state,
                        }
                    };
                }
                _ => todo!("multiple arguments are not supported for non-option function"),
            }
        }
        // A custom state is updated in place; otherwise the new state replaces the old one.
        let update_state = if custom_state.is_some() {
            quote! { _ = #next_state; }
        } else {
            quote! { state = #next_state; }
        };
        let get_result = if custom_state.is_some() {
            quote! { Ok(state.downcast_ref::<#state_type>().into()) }
        } else if let AggregateFnOrImpl::Impl(impl_) = user_fn
            && impl_.finalize.is_some()
        {
            quote! {
                let state = match state.as_datum() {
                    Some(s) => s.as_scalar_ref_impl().try_into().unwrap(),
                    None => return Ok(None),
                };
                Ok(Some(self.function.finalize(state).into()))
            }
        } else {
            quote! { Ok(state.as_datum().clone()) }
        };
        // For an impl-based aggregate, `Agg` stores the user struct (with its optional
        // generic parameter) in a `function` field initialized with `default()`.
        let function_field = match user_fn {
            AggregateFnOrImpl::Fn(_) => quote! {},
            AggregateFnOrImpl::Impl(i) => {
                let struct_name = format_ident!("{}", i.struct_name);
                let generic = self.generic.as_ref().map(|g| {
                    let g = format_ident!("{g}");
                    quote! { <#g> }
                });
                quote! { function: #struct_name #generic, }
            }
        };
        let function_new = match user_fn {
            AggregateFnOrImpl::Fn(_) => quote! {},
            AggregateFnOrImpl::Impl(i) => {
                let struct_name = format_ident!("{}", i.struct_name);
                let generic = self.generic.as_ref().map(|g| {
                    let g = format_ident!("{g}");
                    quote! { ::<#g> }
                });
                quote! { function: #struct_name #generic :: default(), }
            }
        };

        Ok(quote! {
            |agg| {
                use std::collections::HashSet;
                use std::ops::Range;
                use risingwave_common::array::*;
                use risingwave_common::types::*;
                use risingwave_common::bail;
                use risingwave_common::bitmap::Bitmap;
                use risingwave_common_estimate_size::EstimateSize;

                use risingwave_expr::expr::Context;
                use risingwave_expr::Result;
                use risingwave_expr::aggregate::AggregateState;
                use risingwave_expr::codegen::async_trait;

                let context = Context {
                    return_type: agg.return_type.clone(),
                    arg_types: agg.args.arg_types().to_owned(),
                    variadic: false,
                };

                struct Agg {
                    context: Context,
                    #function_field
                }

                #[async_trait]
                impl risingwave_expr::aggregate::AggregateFunction for Agg {
                    fn return_type(&self) -> DataType {
                        self.context.return_type.clone()
                    }

                    #create_state

                    async fn update(&self, state0: &mut AggregateState, input: &StreamChunk) -> Result<()> {
                        #(#let_arrays)*
                        #downcast_state
                        for row_id in input.visibility().iter_ones() {
                            let op = unsafe { *input.ops().get_unchecked(row_id) };
                            #(#let_values)*
                            #update_state
                        }
                        #restore_state
                        Ok(())
                    }

                    async fn update_range(&self, state0: &mut AggregateState, input: &StreamChunk, range: Range<usize>) -> Result<()> {
                        assert!(range.end <= input.capacity());
                        #(#let_arrays)*
                        #downcast_state
                        if input.is_compacted() {
                            for row_id in range {
                                let op = unsafe { *input.ops().get_unchecked(row_id) };
                                #(#let_values)*
                                #update_state
                            }
                        } else {
                            for row_id in input.visibility().iter_ones() {
                                if row_id < range.start {
                                    continue;
                                } else if row_id >= range.end {
                                    break;
                                }
                                let op = unsafe { *input.ops().get_unchecked(row_id) };
                                #(#let_values)*
                                #update_state
                            }
                        }
                        #restore_state
                        Ok(())
                    }

                    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
                        #get_result
                    }
                }

                Ok(Box::new(Agg {
                    context,
                    #function_new
                }))
            }
        })
    }
999
1000 fn generate_table_function_descriptor(
1004 &self,
1005 user_fn: &UserFunctionAttr,
1006 build_fn: bool,
1007 ) -> Result<TokenStream2> {
1008 let name = self.name.clone();
1009 let mut args = Vec::with_capacity(self.args.len());
1010 for ty in &self.args {
1011 args.push(sig_data_type(ty));
1012 }
1013 let ret = sig_data_type(&self.ret);
1014
1015 let pb_type = format_ident!("{}", utils::to_camel_case(&name));
1016 let ctor_name = format_ident!("{}", self.ident_name());
1017 let build_fn = if build_fn {
1018 let name = format_ident!("{}", user_fn.name);
1019 quote! { #name }
1020 } else if self.rewritten {
1021 quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) }
1022 } else {
1023 self.generate_build_table_function(user_fn)?
1024 };
1025 let type_infer_fn = self.generate_type_infer_fn()?;
1026 let deprecated = self.deprecated;
1027
1028 Ok(quote! {
1029 #[risingwave_expr::codegen::linkme::distributed_slice(risingwave_expr::sig::FUNCTIONS)]
1030 fn #ctor_name() -> risingwave_expr::sig::FuncSign {
1031 use risingwave_common::types::{DataType, DataTypeName};
1032 use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder};
1033
1034 FuncSign {
1035 name: risingwave_pb::expr::table_function::Type::#pb_type.into(),
1036 inputs_type: vec![#(#args),*],
1037 variadic: false,
1038 ret_type: #ret,
1039 build: FuncBuilder::Table(#build_fn),
1040 type_infer: #type_infer_fn,
1041 deprecated: #deprecated,
1042 }
1043 }
1044 })
1045 }
1046
1047 fn generate_build_table_function(&self, user_fn: &UserFunctionAttr) -> Result<TokenStream2> {
1048 let num_args = self.args.len();
1049 let return_types = output_types(&self.ret);
1050 let fn_name = format_ident!("{}", user_fn.name);
1051 let struct_name = format_ident!("{}", utils::to_camel_case(&self.ident_name()));
1052 let arg_ids = (0..num_args)
1053 .filter(|i| match &self.prebuild {
1054 Some(s) => !s.contains(&format!("${i}")),
1055 None => true,
1056 })
1057 .collect_vec();
1058 let const_ids = (0..num_args).filter(|i| match &self.prebuild {
1059 Some(s) => s.contains(&format!("${i}")),
1060 None => false,
1061 });
1062 let inputs: Vec<_> = arg_ids.iter().map(|i| format_ident!("i{i}")).collect();
1063 let all_child: Vec<_> = (0..num_args).map(|i| format_ident!("child{i}")).collect();
1064 let const_child: Vec<_> = const_ids.map(|i| format_ident!("child{i}")).collect();
1065 let child: Vec<_> = arg_ids.iter().map(|i| format_ident!("child{i}")).collect();
1066 let array_refs: Vec<_> = arg_ids.iter().map(|i| format_ident!("array{i}")).collect();
1067 let arrays: Vec<_> = arg_ids.iter().map(|i| format_ident!("a{i}")).collect();
1068 let arg_arrays = arg_ids
1069 .iter()
1070 .map(|i| format_ident!("{}", types::array_type(&self.args[*i])));
1071 let outputs = (0..return_types.len())
1072 .map(|i| format_ident!("o{i}"))
1073 .collect_vec();
1074 let builders = (0..return_types.len())
1075 .map(|i| format_ident!("builder{i}"))
1076 .collect_vec();
1077 let builder_types = return_types
1078 .iter()
1079 .map(|ty| format_ident!("{}Builder", types::array_type(ty)))
1080 .collect_vec();
1081 let return_types = if return_types.len() == 1 {
1082 vec![quote! { self.context.return_type.clone() }]
1083 } else {
1084 (0..return_types.len())
1085 .map(|i| quote! { self.context.return_type.as_struct().types().nth(#i).unwrap().clone() })
1086 .collect()
1087 };
1088 #[allow(clippy::disallowed_methods)]
1089 let optioned_outputs = user_fn
1090 .core_return_type
1091 .split(',')
1092 .map(|t| t.contains("Option"))
1093 .zip(&outputs)
1095 .map(|(optional, o)| match optional {
1096 false => quote! { Some(#o.as_scalar_ref()) },
1097 true => quote! { #o.map(|o| o.as_scalar_ref()) },
1098 })
1099 .collect_vec();
1100 let build_value_array = if return_types.len() == 1 {
1101 quote! { let [value_array] = value_arrays; }
1102 } else {
1103 quote! {
1104 let value_array = StructArray::new(
1105 self.context.return_type.as_struct().clone(),
1106 value_arrays.to_vec(),
1107 Bitmap::ones(len),
1108 ).into_ref();
1109 }
1110 };
1111 let context = user_fn.context.then(|| quote! { &self.context, });
1112 let prebuilt_arg = match &self.prebuild {
1113 Some(_) => quote! { &self.prebuilt_arg, },
1114 None => quote! {},
1115 };
1116 let prebuilt_arg_type = match &self.prebuild {
1117 Some(s) => s.split("::").next().unwrap().parse().unwrap(),
1118 None => quote! { () },
1119 };
1120 let prebuilt_arg_value = match &self.prebuild {
1121 Some(s) => s
1122 .replace('$', "child")
1123 .parse()
1124 .expect("invalid prebuild syntax"),
1125 None => quote! { () },
1126 };
1127 let iter = quote! { #fn_name(#(#inputs,)* #prebuilt_arg #context) };
1128 let mut iter = match user_fn.return_type_kind {
1129 ReturnTypeKind::T => quote! { #iter },
1130 ReturnTypeKind::Option => quote! { match #iter {
1131 Some(it) => it,
1132 None => continue,
1133 } },
1134 ReturnTypeKind::Result => quote! { match #iter {
1135 Ok(it) => it,
1136 Err(e) => {
1137 index_builder.append(Some(i as i32));
1138 #(#builders.append_null();)*
1139 error_builder.append_display(Some(e.as_report()));
1140 continue;
1141 }
1142 } },
1143 ReturnTypeKind::ResultOption => quote! { match #iter {
1144 Ok(Some(it)) => it,
1145 Ok(None) => continue,
1146 Err(e) => {
1147 index_builder.append(Some(i as i32));
1148 #(#builders.append_null();)*
1149 error_builder.append_display(Some(e.as_report()));
1150 continue;
1151 }
1152 } },
1153 };
1154 #[allow(clippy::disallowed_methods)] let some_inputs = inputs
1158 .iter()
1159 .zip(user_fn.args_option.iter())
1160 .map(|(input, opt)| {
1161 if *opt {
1162 quote! { #input }
1163 } else {
1164 quote! { Some(#input) }
1165 }
1166 });
1167 iter = quote! {
1168 match (#(#inputs,)*) {
1169 (#(#some_inputs,)*) => #iter,
1170 _ => continue,
1171 }
1172 };
1173 let iterator_item_type = user_fn.iterator_item_kind.clone().ok_or_else(|| {
1174 Error::new(
1175 user_fn.return_type_span,
1176 "expect `impl Iterator` in return type",
1177 )
1178 })?;
1179 let append_output = match iterator_item_type {
1180 ReturnTypeKind::T => quote! {
1181 let (#(#outputs),*) = output;
1182 #(#builders.append(#optioned_outputs);)* error_builder.append_null();
1183 },
1184 ReturnTypeKind::Option => quote! { match output {
1185 Some((#(#outputs),*)) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1186 None => { #(#builders.append_null();)* error_builder.append_null(); }
1187 } },
1188 ReturnTypeKind::Result => quote! { match output {
1189 Ok((#(#outputs),*)) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1190 Err(e) => { #(#builders.append_null();)* error_builder.append_display(Some(e.as_report())); }
1191 } },
1192 ReturnTypeKind::ResultOption => quote! { match output {
1193 Ok(Some((#(#outputs),*))) => { #(#builders.append(#optioned_outputs);)* error_builder.append_null(); }
1194 Ok(None) => { #(#builders.append_null();)* error_builder.append_null(); }
1195 Err(e) => { #(#builders.append_null();)* error_builder.append_display(Some(e.as_report())); }
1196 } },
1197 };
1198
1199 Ok(quote! {
1200 |return_type, chunk_size, children| {
1201 use risingwave_common::array::*;
1202 use risingwave_common::types::*;
1203 use risingwave_common::bitmap::Bitmap;
1204 use risingwave_common::util::iter_util::ZipEqFast;
1205 use risingwave_expr::expr::{BoxedExpression, Context};
1206 use risingwave_expr::{Result, ExprError};
1207 use risingwave_expr::codegen::*;
1208
1209 risingwave_expr::ensure!(children.len() == #num_args);
1210
1211 let context = Context {
1212 return_type: return_type.clone(),
1213 arg_types: children.iter().map(|c| c.return_type()).collect(),
1214 variadic: false,
1215 };
1216
1217 let mut iter = children.into_iter();
1218 #(let #all_child = iter.next().unwrap();)*
1219 #(
1220 let #const_child = #const_child.eval_const()?;
1221 let #const_child = match &#const_child {
1222 Some(s) => s.as_scalar_ref_impl().try_into()?,
1223 None => return Ok(risingwave_expr::table_function::empty(return_type)),
1225 };
1226 )*
1227
1228 #[derive(Debug)]
1229 struct #struct_name {
1230 context: Context,
1231 chunk_size: usize,
1232 #(#child: BoxedExpression,)*
1233 prebuilt_arg: #prebuilt_arg_type,
1234 }
1235 #[async_trait]
1236 impl risingwave_expr::table_function::TableFunction for #struct_name {
1237 fn return_type(&self) -> DataType {
1238 self.context.return_type.clone()
1239 }
1240 async fn eval<'a>(&'a self, input: &'a DataChunk) -> BoxStream<'a, Result<DataChunk>> {
1241 self.eval_inner(input)
1242 }
1243 }
1244 impl #struct_name {
1245 #[try_stream(boxed, ok = DataChunk, error = ExprError)]
1246 async fn eval_inner<'a>(&'a self, input: &'a DataChunk) {
1247 #(
1248 let #array_refs = self.#child.eval(input).await?;
1249 let #arrays: &#arg_arrays = #array_refs.as_ref().into();
1250 )*
1251
1252 let mut index_builder = I32ArrayBuilder::new(self.chunk_size);
1253 #(let mut #builders = #builder_types::with_type(self.chunk_size, #return_types);)*
1254 let mut error_builder = Utf8ArrayBuilder::new(self.chunk_size);
1255
1256 for i in 0..input.capacity() {
1257 if unsafe { !input.visibility().is_set_unchecked(i) } {
1258 continue;
1259 }
1260 #(let #inputs = unsafe { #arrays.value_at_unchecked(i) };)*
1261 for output in #iter {
1262 index_builder.append(Some(i as i32));
1263 #append_output
1264
1265 if index_builder.len() == self.chunk_size {
1266 let len = index_builder.len();
1267 let index_array = std::mem::replace(&mut index_builder, I32ArrayBuilder::new(self.chunk_size)).finish().into_ref();
1268 let value_arrays = [#(std::mem::replace(&mut #builders, #builder_types::with_type(self.chunk_size, #return_types)).finish().into_ref()),*];
1269 #build_value_array
1270 let error_array = std::mem::replace(&mut error_builder, Utf8ArrayBuilder::new(self.chunk_size)).finish().into_ref();
1271 if error_array.null_bitmap().any() {
1272 yield DataChunk::new(vec![index_array, value_array, error_array], self.chunk_size);
1273 } else {
1274 yield DataChunk::new(vec![index_array, value_array], self.chunk_size);
1275 }
1276 }
1277 }
1278 }
1279
1280 if index_builder.len() > 0 {
1281 let len = index_builder.len();
1282 let index_array = index_builder.finish().into_ref();
1283 let value_arrays = [#(#builders.finish().into_ref()),*];
1284 #build_value_array
1285 let error_array = error_builder.finish().into_ref();
1286 if error_array.null_bitmap().any() {
1287 yield DataChunk::new(vec![index_array, value_array, error_array], len);
1288 } else {
1289 yield DataChunk::new(vec![index_array, value_array], len);
1290 }
1291 }
1292 }
1293 }
1294
1295 Ok(Box::new(#struct_name {
1296 context,
1297 chunk_size,
1298 #(#child,)*
1299 prebuilt_arg: #prebuilt_arg_value,
1300 }))
1301 }
1302 })
1303 }
1304}
1305
1306fn sig_data_type(ty: &str) -> TokenStream2 {
1307 match ty {
1308 "any" => quote! { SigDataType::Any },
1309 "anyarray" => quote! { SigDataType::AnyArray },
1310 "anymap" => quote! { SigDataType::AnyMap },
1311 "struct" => quote! { SigDataType::AnyStruct },
1312 _ if ty.starts_with("struct") && ty.contains("any") => quote! { SigDataType::AnyStruct },
1313 _ => {
1314 let datatype = data_type(ty);
1315 quote! { SigDataType::Exact(#datatype) }
1316 }
1317 }
1318}
1319
1320fn data_type(ty: &str) -> TokenStream2 {
1321 if let Some(ty) = ty.strip_suffix("[]") {
1322 let inner_type = data_type(ty);
1323 return quote! { DataType::List(Box::new(#inner_type)) };
1324 }
1325 if ty.starts_with("struct<") {
1326 return quote! { DataType::Struct(#ty.parse().expect("invalid struct type")) };
1327 }
1328 let variant = format_ident!("{}", types::data_type(ty));
1329 quote! { DataType::#variant }
1336}
1337
/// Returns the per-column output types of a function return type.
///
/// A `struct<name1 type1, name2 type2, ...>` type is flattened into its field
/// types (`[type1, type2, ...]`). Any other type, including a list of structs
/// such as `struct<..>[]`, is returned as a single element.
///
/// Fields are separated only at top-level commas. Commas nested inside `<...>`
/// or `(...)` belong to a field's own type, so e.g.
/// `struct<a struct<b int4, c int4>, d varchar>` yields
/// `["struct<b int4, c int4>", "varchar"]`.
///
/// # Panics
///
/// Panics if a struct field is not of the form `name type` (for example in
/// `struct<>`). This runs inside a procedural macro, so the panic is reported
/// as a compile error.
fn output_types(ty: &str) -> Vec<&str> {
    /// Returns the type part of a single `name type` struct field declaration.
    fn field_type<'a>(field: &'a str, ty: &str) -> &'a str {
        let field = field.trim();
        match field.split_once(char::is_whitespace) {
            // Everything after the field name is the type, so multi-token or nested
            // types are kept intact.
            Some((_name, field_ty)) => field_ty.trim(),
            None => panic!("invalid struct field `{}` in type `{}`", field, ty),
        }
    }

    let fields = match ty
        .strip_prefix("struct<")
        .and_then(|s| s.strip_suffix('>'))
    {
        Some(fields) => fields,
        None => return vec![ty],
    };

    let mut types = Vec::new();
    // Nesting depth of `<...>` / `(...)`. Only commas at depth 0 separate fields.
    let mut depth = 0usize;
    let mut start = 0;
    for (i, c) in fields.char_indices() {
        match c {
            '<' | '(' => depth += 1,
            '>' | ')' => depth = depth.saturating_sub(1),
            ',' if depth == 0 => {
                types.push(field_type(&fields[start..i], ty));
                // ',' is one byte, so `i + 1` is a valid char boundary.
                start = i + 1;
            }
            _ => {}
        }
    }
    types.push(field_type(&fields[start..], ty));
    types
}