risingwave_frontend/expr/
user_defined_function.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use itertools::Itertools;
use risingwave_common::catalog::FunctionId;
use risingwave_common::types::DataType;

use super::{Expr, ExprImpl};
use crate::catalog::function_catalog::{FunctionCatalog, FunctionKind};

/// A call to a user-defined function in a bound expression tree.
///
/// Pairs the function's catalog entry (which carries its name, signature, return type and
/// implementation details) with the argument expressions the call was bound to.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct UserDefinedFunction {
    /// Bound argument expressions passed to the function.
    pub args: Vec<ExprImpl>,
    /// Catalog entry describing the function being called.
    pub catalog: Arc<FunctionCatalog>,
}

impl UserDefinedFunction {
    /// Builds a UDF call from an already-resolved function catalog and its argument expressions.
    pub fn new(catalog: Arc<FunctionCatalog>, args: Vec<ExprImpl>) -> Self {
        Self { catalog, args }
    }

    /// Decodes a UDF call from its protobuf representation.
    ///
    /// Each child node of `udf` is decoded with [`ExprImpl::from_expr_proto`] and becomes an
    /// argument. A new [`FunctionCatalog`] is then synthesized from the remaining proto fields
    /// together with the separately supplied `return_type`.
    ///
    /// The proto carries neither a function id nor an owner. The synthesized catalog therefore
    /// uses [`FunctionId::placeholder`] and the sentinel owner `u32::MAX - 1`, and its kind is
    /// always [`FunctionKind::Scalar`].
    ///
    /// # Errors
    ///
    /// Returns the first error produced while decoding a child expression.
    pub(super) fn from_expr_proto(
        udf: &risingwave_pb::expr::UserDefinedFunction,
        return_type: DataType,
    ) -> crate::error::Result<Self> {
        // FIXME(yuhao): owner is not in udf proto, so a sentinel value stands in for it.
        const PLACEHOLDER_OWNER: u32 = u32::MAX - 1;

        // Decode the argument expressions, bailing out on the first failure.
        let mut args = Vec::with_capacity(udf.get_children().len());
        for child in udf.get_children().iter() {
            args.push(ExprImpl::from_expr_proto(child)?);
        }

        // Rebuild a catalog entry from what the proto does carry.
        let catalog = Arc::new(FunctionCatalog {
            // FIXME(yuhao): function id is not in udf proto.
            id: FunctionId::placeholder(),
            name: udf.name.clone(),
            owner: PLACEHOLDER_OWNER,
            kind: FunctionKind::Scalar,
            arg_names: udf.arg_names.clone(),
            arg_types: udf.get_arg_types().iter().map_into().collect_vec(),
            return_type,
            language: udf.language.clone(),
            runtime: udf.runtime.clone(),
            identifier: udf.identifier.clone(),
            body: udf.body.clone(),
            link: udf.link.clone(),
            compressed_binary: udf.compressed_binary.clone(),
            always_retry_on_network_error: udf.always_retry_on_network_error,
        });

        Ok(Self { args, catalog })
    }
}

impl Expr for UserDefinedFunction {
    /// Returns the return type recorded in the function's catalog entry.
    fn return_type(&self) -> DataType {
        let declared = &self.catalog.return_type;
        declared.clone()
    }

    /// Encodes this call as an `ExprNode`.
    ///
    /// The node's `function_type` is `Type::Unspecified`. The call itself travels in a
    /// `RexNode::Udf` payload holding:
    /// - the serialized arguments as `children`;
    /// - the catalog's name, argument names and types, language, runtime, identifier, link,
    ///   body, compressed binary and retry flag.
    ///
    /// The catalog's id, owner and kind are not serialized.
    fn to_expr_proto(&self) -> risingwave_pb::expr::ExprNode {
        use risingwave_pb::expr::expr_node::*;
        use risingwave_pb::expr::*;

        let catalog = &self.catalog;

        // Serialize the argument expressions and the declared argument types.
        let children = self.args.iter().map(Expr::to_expr_proto).collect_vec();
        let arg_types = catalog
            .arg_types
            .iter()
            .map(|ty| ty.to_protobuf())
            .collect_vec();

        // Inside this scope `UserDefinedFunction` resolves to the protobuf message.
        let udf = UserDefinedFunction {
            children,
            name: catalog.name.clone(),
            arg_names: catalog.arg_names.clone(),
            arg_types,
            language: catalog.language.clone(),
            runtime: catalog.runtime.clone(),
            identifier: catalog.identifier.clone(),
            link: catalog.link.clone(),
            body: catalog.body.clone(),
            compressed_binary: catalog.compressed_binary.clone(),
            always_retry_on_network_error: catalog.always_retry_on_network_error,
        };

        ExprNode {
            function_type: Type::Unspecified.into(),
            return_type: Some(self.return_type().to_protobuf()),
            rex_node: Some(RexNode::Udf(udf)),
        }
    }
}