risingwave_frontend/expr/
user_defined_function.rs1use std::sync::Arc;
16
17use itertools::Itertools;
18use risingwave_common::catalog::{FunctionId, Schema};
19use risingwave_common::types::DataType;
20
21use super::{Expr, ExprDisplay, ExprImpl};
22use crate::catalog::function_catalog::{FunctionCatalog, FunctionKind};
23
24#[derive(Debug, Clone, PartialEq, Eq, Hash)]
25pub struct UserDefinedFunction {
26 pub args: Vec<ExprImpl>,
27 pub catalog: Arc<FunctionCatalog>,
28}
29
30impl UserDefinedFunction {
31 pub fn new(catalog: Arc<FunctionCatalog>, args: Vec<ExprImpl>) -> Self {
32 Self { args, catalog }
33 }
34
35 pub(super) fn from_expr_proto(
36 udf: &risingwave_pb::expr::UserDefinedFunction,
37 return_type: DataType,
38 ) -> crate::error::Result<Self> {
39 let args: Vec<_> = udf
40 .get_children()
41 .iter()
42 .map(ExprImpl::from_expr_proto)
43 .try_collect()?;
44
45 let arg_types = udf.get_arg_types().iter().map_into().collect_vec();
47 let catalog = FunctionCatalog {
48 id: FunctionId::placeholder(),
50 name: udf.name.clone(),
51 owner: u32::MAX - 1,
53 kind: FunctionKind::Scalar,
54 arg_names: udf.arg_names.clone(),
55 arg_types,
56 return_type,
57 language: udf.language.clone(),
58 runtime: udf.runtime.clone(),
59 name_in_runtime: udf.name_in_runtime().map(|x| x.to_owned()),
60 body: udf.body.clone(),
61 link: udf.link.clone(),
62 compressed_binary: udf.compressed_binary.clone(),
63 always_retry_on_network_error: udf.always_retry_on_network_error,
64 is_batched: udf.is_batched,
65 is_async: udf.is_async,
66 };
67
68 Ok(Self {
69 args,
70 catalog: Arc::new(catalog),
71 })
72 }
73}
74
75impl Expr for UserDefinedFunction {
76 fn return_type(&self) -> DataType {
77 self.catalog.return_type.clone()
78 }
79
80 fn to_expr_proto(&self) -> risingwave_pb::expr::ExprNode {
81 use risingwave_pb::expr::expr_node::*;
82 use risingwave_pb::expr::*;
83 ExprNode {
84 function_type: Type::Unspecified.into(),
85 return_type: Some(self.return_type().to_protobuf()),
86 rex_node: Some(RexNode::Udf(Box::new(UserDefinedFunction {
87 children: self.args.iter().map(Expr::to_expr_proto).collect(),
88 name: self.catalog.name.clone(),
89 arg_names: self.catalog.arg_names.clone(),
90 arg_types: self
91 .catalog
92 .arg_types
93 .iter()
94 .map(|t| t.to_protobuf())
95 .collect(),
96 language: self.catalog.language.clone(),
97 runtime: self.catalog.runtime.clone(),
98 identifier: self.catalog.name_in_runtime.clone(),
99 link: self.catalog.link.clone(),
100 body: self.catalog.body.clone(),
101 compressed_binary: self.catalog.compressed_binary.clone(),
102 always_retry_on_network_error: self.catalog.always_retry_on_network_error,
103 is_async: self.catalog.is_async,
104 is_batched: self.catalog.is_batched,
105 version: PbUdfExprVersion::LATEST as _,
106 }))),
107 }
108 }
109}
110
111pub struct UserDefinedFunctionDisplay<'a> {
112 pub func_call: &'a UserDefinedFunction,
113 pub input_schema: &'a Schema,
114}
115
116impl std::fmt::Debug for UserDefinedFunctionDisplay<'_> {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 let that = self.func_call;
119 let mut builder = f.debug_tuple(&that.catalog.name);
120 that.args.iter().for_each(|arg| {
121 builder.field(&ExprDisplay {
122 expr: arg,
123 input_schema: self.input_schema,
124 });
125 });
126 builder.finish()
127 }
128}