risingwave_frontend/expr/
user_defined_function.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use itertools::Itertools;
18use risingwave_common::catalog::{FunctionId, Schema};
19use risingwave_common::types::DataType;
20
21use super::{Expr, ExprDisplay, ExprImpl};
22use crate::catalog::function_catalog::{FunctionCatalog, FunctionKind};
23
24#[derive(Debug, Clone, PartialEq, Eq, Hash)]
25pub struct UserDefinedFunction {
26    pub args: Vec<ExprImpl>,
27    pub catalog: Arc<FunctionCatalog>,
28}
29
30impl UserDefinedFunction {
31    pub fn new(catalog: Arc<FunctionCatalog>, args: Vec<ExprImpl>) -> Self {
32        Self { args, catalog }
33    }
34
35    pub(super) fn from_expr_proto(
36        udf: &risingwave_pb::expr::UserDefinedFunction,
37        return_type: DataType,
38    ) -> crate::error::Result<Self> {
39        let args: Vec<_> = udf
40            .get_children()
41            .iter()
42            .map(ExprImpl::from_expr_proto)
43            .try_collect()?;
44
45        // function catalog
46        let arg_types = udf.get_arg_types().iter().map_into().collect_vec();
47        let catalog = FunctionCatalog {
48            // FIXME(yuhao): function id is not in udf proto.
49            id: FunctionId::placeholder(),
50            name: udf.name.clone(),
51            // FIXME(yuhao): owner is not in udf proto.
52            owner: u32::MAX - 1,
53            kind: FunctionKind::Scalar,
54            arg_names: udf.arg_names.clone(),
55            arg_types,
56            return_type,
57            language: udf.language.clone(),
58            runtime: udf.runtime.clone(),
59            name_in_runtime: udf.name_in_runtime().map(|x| x.to_owned()),
60            body: udf.body.clone(),
61            link: udf.link.clone(),
62            compressed_binary: udf.compressed_binary.clone(),
63            always_retry_on_network_error: udf.always_retry_on_network_error,
64            is_batched: udf.is_batched,
65            is_async: udf.is_async,
66        };
67
68        Ok(Self {
69            args,
70            catalog: Arc::new(catalog),
71        })
72    }
73}
74
75impl Expr for UserDefinedFunction {
76    fn return_type(&self) -> DataType {
77        self.catalog.return_type.clone()
78    }
79
80    fn to_expr_proto(&self) -> risingwave_pb::expr::ExprNode {
81        use risingwave_pb::expr::expr_node::*;
82        use risingwave_pb::expr::*;
83        ExprNode {
84            function_type: Type::Unspecified.into(),
85            return_type: Some(self.return_type().to_protobuf()),
86            rex_node: Some(RexNode::Udf(Box::new(UserDefinedFunction {
87                children: self.args.iter().map(Expr::to_expr_proto).collect(),
88                name: self.catalog.name.clone(),
89                arg_names: self.catalog.arg_names.clone(),
90                arg_types: self
91                    .catalog
92                    .arg_types
93                    .iter()
94                    .map(|t| t.to_protobuf())
95                    .collect(),
96                language: self.catalog.language.clone(),
97                runtime: self.catalog.runtime.clone(),
98                identifier: self.catalog.name_in_runtime.clone(),
99                link: self.catalog.link.clone(),
100                body: self.catalog.body.clone(),
101                compressed_binary: self.catalog.compressed_binary.clone(),
102                always_retry_on_network_error: self.catalog.always_retry_on_network_error,
103                is_async: self.catalog.is_async,
104                is_batched: self.catalog.is_batched,
105                version: PbUdfExprVersion::LATEST as _,
106            }))),
107        }
108    }
109}
110
111pub struct UserDefinedFunctionDisplay<'a> {
112    pub func_call: &'a UserDefinedFunction,
113    pub input_schema: &'a Schema,
114}
115
116impl std::fmt::Debug for UserDefinedFunctionDisplay<'_> {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        let that = self.func_call;
119        let mut builder = f.debug_tuple(&that.catalog.name);
120        that.args.iter().for_each(|arg| {
121            builder.field(&ExprDisplay {
122                expr: arg,
123                input_schema: self.input_schema,
124            });
125        });
126        builder.finish()
127    }
128}