risingwave_frontend/expr/
user_defined_function.rs1use std::sync::Arc;
16
17use itertools::Itertools;
18use risingwave_common::catalog::{FunctionId, Schema};
19use risingwave_common::types::DataType;
20
21use super::{Expr, ExprDisplay, ExprImpl};
22use crate::catalog::function_catalog::{FunctionCatalog, FunctionKind};
23
24#[derive(Debug, Clone, PartialEq, Eq, Hash)]
25pub struct UserDefinedFunction {
26 pub args: Vec<ExprImpl>,
27 pub catalog: Arc<FunctionCatalog>,
28}
29
30impl UserDefinedFunction {
31 pub fn new(catalog: Arc<FunctionCatalog>, args: Vec<ExprImpl>) -> Self {
32 Self { args, catalog }
33 }
34
35 pub(super) fn from_expr_proto(
36 udf: &risingwave_pb::expr::UserDefinedFunction,
37 return_type: DataType,
38 ) -> crate::error::Result<Self> {
39 let args: Vec<_> = udf
40 .get_children()
41 .iter()
42 .map(ExprImpl::from_expr_proto)
43 .try_collect()?;
44
45 let arg_types = udf.get_arg_types().iter().map_into().collect_vec();
47 let catalog = FunctionCatalog {
48 id: FunctionId::placeholder(),
50 name: udf.name.clone(),
51 owner: u32::MAX - 1,
53 kind: FunctionKind::Scalar,
54 arg_names: udf.arg_names.clone(),
55 arg_types,
56 return_type,
57 language: udf.language.clone(),
58 runtime: udf.runtime.clone(),
59 name_in_runtime: udf.name_in_runtime().map(|x| x.to_owned()),
60 body: udf.body.clone(),
61 link: udf.link.clone(),
62 compressed_binary: udf.compressed_binary.clone(),
63 always_retry_on_network_error: udf.always_retry_on_network_error,
64 is_batched: udf.is_batched,
65 is_async: udf.is_async,
66 created_at_epoch: None,
67 created_at_cluster_version: None,
68 };
69
70 Ok(Self {
71 args,
72 catalog: Arc::new(catalog),
73 })
74 }
75}
76
77impl Expr for UserDefinedFunction {
78 fn return_type(&self) -> DataType {
79 self.catalog.return_type.clone()
80 }
81
82 fn try_to_expr_proto(&self) -> Result<risingwave_pb::expr::ExprNode, String> {
83 use risingwave_pb::expr::expr_node::*;
84 use risingwave_pb::expr::*;
85
86 let children = self
87 .args
88 .iter()
89 .map(|arg| arg.try_to_expr_proto())
90 .try_collect()?;
91
92 Ok(ExprNode {
93 function_type: Type::Unspecified.into(),
94 return_type: Some(self.return_type().to_protobuf()),
95 rex_node: Some(RexNode::Udf(Box::new(UserDefinedFunction {
96 children,
97 name: self.catalog.name.clone(),
98 arg_names: self.catalog.arg_names.clone(),
99 arg_types: self
100 .catalog
101 .arg_types
102 .iter()
103 .map(|t| t.to_protobuf())
104 .collect(),
105 language: self.catalog.language.clone(),
106 runtime: self.catalog.runtime.clone(),
107 identifier: self.catalog.name_in_runtime.clone(),
108 link: self.catalog.link.clone(),
109 body: self.catalog.body.clone(),
110 compressed_binary: self.catalog.compressed_binary.clone(),
111 always_retry_on_network_error: self.catalog.always_retry_on_network_error,
112 is_async: self.catalog.is_async,
113 is_batched: self.catalog.is_batched,
114 version: PbUdfExprVersion::LATEST as _,
115 }))),
116 })
117 }
118}
119
120pub struct UserDefinedFunctionDisplay<'a> {
121 pub func_call: &'a UserDefinedFunction,
122 pub input_schema: &'a Schema,
123}
124
125impl std::fmt::Debug for UserDefinedFunctionDisplay<'_> {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 let that = self.func_call;
128 let mut builder = f.debug_tuple(&that.catalog.name);
129 that.args.iter().for_each(|arg| {
130 builder.field(&ExprDisplay {
131 expr: arg,
132 input_schema: self.input_schema,
133 });
134 });
135 builder.finish()
136 }
137}