risingwave_frontend/expr/
user_defined_function.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::sync::Arc;
16
17use itertools::Itertools;
18use risingwave_common::catalog::{FunctionId, Schema};
19use risingwave_common::types::DataType;
20
21use super::{Expr, ExprDisplay, ExprImpl};
22use crate::catalog::function_catalog::{FunctionCatalog, FunctionKind};
23
24#[derive(Debug, Clone, PartialEq, Eq, Hash)]
25pub struct UserDefinedFunction {
26    pub args: Vec<ExprImpl>,
27    pub catalog: Arc<FunctionCatalog>,
28}
29
30impl UserDefinedFunction {
31    pub fn new(catalog: Arc<FunctionCatalog>, args: Vec<ExprImpl>) -> Self {
32        Self { args, catalog }
33    }
34
35    pub(super) fn from_expr_proto(
36        udf: &risingwave_pb::expr::UserDefinedFunction,
37        return_type: DataType,
38    ) -> crate::error::Result<Self> {
39        let args: Vec<_> = udf
40            .get_children()
41            .iter()
42            .map(ExprImpl::from_expr_proto)
43            .try_collect()?;
44
45        // function catalog
46        let arg_types = udf.get_arg_types().iter().map_into().collect_vec();
47        let catalog = FunctionCatalog {
48            // FIXME(yuhao): function id is not in udf proto.
49            id: FunctionId::placeholder(),
50            name: udf.name.clone(),
51            // FIXME(yuhao): owner is not in udf proto.
52            owner: u32::MAX - 1,
53            kind: FunctionKind::Scalar,
54            arg_names: udf.arg_names.clone(),
55            arg_types,
56            return_type,
57            language: udf.language.clone(),
58            runtime: udf.runtime.clone(),
59            name_in_runtime: udf.name_in_runtime().map(|x| x.to_owned()),
60            body: udf.body.clone(),
61            link: udf.link.clone(),
62            compressed_binary: udf.compressed_binary.clone(),
63            always_retry_on_network_error: udf.always_retry_on_network_error,
64            is_batched: udf.is_batched,
65            is_async: udf.is_async,
66            created_at_epoch: None,
67            created_at_cluster_version: None,
68        };
69
70        Ok(Self {
71            args,
72            catalog: Arc::new(catalog),
73        })
74    }
75}
76
77impl Expr for UserDefinedFunction {
78    fn return_type(&self) -> DataType {
79        self.catalog.return_type.clone()
80    }
81
82    fn try_to_expr_proto(&self) -> Result<risingwave_pb::expr::ExprNode, String> {
83        use risingwave_pb::expr::expr_node::*;
84        use risingwave_pb::expr::*;
85
86        let children = self
87            .args
88            .iter()
89            .map(|arg| arg.try_to_expr_proto())
90            .try_collect()?;
91
92        Ok(ExprNode {
93            function_type: Type::Unspecified.into(),
94            return_type: Some(self.return_type().to_protobuf()),
95            rex_node: Some(RexNode::Udf(Box::new(UserDefinedFunction {
96                children,
97                name: self.catalog.name.clone(),
98                arg_names: self.catalog.arg_names.clone(),
99                arg_types: self
100                    .catalog
101                    .arg_types
102                    .iter()
103                    .map(|t| t.to_protobuf())
104                    .collect(),
105                language: self.catalog.language.clone(),
106                runtime: self.catalog.runtime.clone(),
107                identifier: self.catalog.name_in_runtime.clone(),
108                link: self.catalog.link.clone(),
109                body: self.catalog.body.clone(),
110                compressed_binary: self.catalog.compressed_binary.clone(),
111                always_retry_on_network_error: self.catalog.always_retry_on_network_error,
112                is_async: self.catalog.is_async,
113                is_batched: self.catalog.is_batched,
114                version: PbUdfExprVersion::LATEST as _,
115            }))),
116        })
117    }
118}
119
120pub struct UserDefinedFunctionDisplay<'a> {
121    pub func_call: &'a UserDefinedFunction,
122    pub input_schema: &'a Schema,
123}
124
125impl std::fmt::Debug for UserDefinedFunctionDisplay<'_> {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        let that = self.func_call;
128        let mut builder = f.debug_tuple(&that.catalog.name);
129        that.args.iter().for_each(|arg| {
130            builder.field(&ExprDisplay {
131                expr: arg,
132                input_schema: self.input_schema,
133            });
134        });
135        builder.finish()
136    }
137}