risingwave_frontend/handler/
create_function.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::Context;
16use either::Either;
17use risingwave_common::catalog::FunctionId;
18use risingwave_common::types::StructType;
19use risingwave_expr::sig::{CreateOptions, UdfKind};
20use risingwave_pb::catalog::PbFunction;
21use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction};
22
23use super::*;
24use crate::{Binder, bind_data_type};
25
26pub async fn handle_create_function(
27    handler_args: HandlerArgs,
28    or_replace: bool,
29    temporary: bool,
30    if_not_exists: bool,
31    name: ObjectName,
32    args: Option<Vec<OperateFunctionArg>>,
33    returns: Option<CreateFunctionReturns>,
34    params: CreateFunctionBody,
35    with_options: CreateFunctionWithOptions,
36) -> Result<RwPgResponse> {
37    if or_replace {
38        bail_not_implemented!("CREATE OR REPLACE FUNCTION");
39    }
40    if temporary {
41        bail_not_implemented!("CREATE TEMPORARY FUNCTION");
42    }
43
44    let udf_config = handler_args.session.env().udf_config();
45
46    // e.g., `language [ python / javascript / ...etc]`
47    let language = match params.language {
48        Some(lang) => {
49            let lang = lang.real_value().to_lowercase();
50            match &*lang {
51                "java" => lang, // only support external UDF for Java
52                "python" if udf_config.enable_embedded_python_udf => lang,
53                "javascript" if udf_config.enable_embedded_javascript_udf => lang,
54                "rust" | "wasm" if udf_config.enable_embedded_wasm_udf => lang,
55                "python" | "javascript" | "rust" | "wasm" => {
56                    return Err(ErrorCode::InvalidParameterValue(format!(
57                        "{} UDF is not enabled in configuration",
58                        lang
59                    ))
60                    .into());
61                }
62                _ => {
63                    return Err(ErrorCode::InvalidParameterValue(format!(
64                        "language {} is not supported",
65                        lang
66                    ))
67                    .into());
68                }
69            }
70        }
71        // Empty language is acceptable since we only require the external server implements the
72        // correct protocol.
73        None => "".to_owned(),
74    };
75
76    let runtime = match params.runtime {
77        Some(_) => {
78            return Err(ErrorCode::InvalidParameterValue(
79                "runtime selection is currently not supported".to_owned(),
80            )
81            .into());
82        }
83        None => None,
84    };
85
86    let return_type;
87    let kind = match returns {
88        Some(CreateFunctionReturns::Value(data_type)) => {
89            return_type = bind_data_type(&data_type)?;
90            Kind::Scalar(ScalarFunction {})
91        }
92        Some(CreateFunctionReturns::Table(columns)) => {
93            if columns.len() == 1 {
94                // return type is the original type for single column
95                return_type = bind_data_type(&columns[0].data_type)?;
96            } else {
97                // return type is a struct for multiple columns
98                let it = columns
99                    .into_iter()
100                    .map(|c| bind_data_type(&c.data_type).map(|ty| (c.name.real_value(), ty)));
101                let fields = it.try_collect::<_, Vec<_>, _>()?;
102                return_type = StructType::new(fields).into();
103            }
104            Kind::Table(TableFunction {})
105        }
106        None => {
107            return Err(ErrorCode::InvalidParameterValue(
108                "return type must be specified".to_owned(),
109            )
110            .into());
111        }
112    };
113
114    let mut arg_names = vec![];
115    let mut arg_types = vec![];
116    for arg in args.unwrap_or_default() {
117        arg_names.push(arg.name.map_or("".to_owned(), |n| n.real_value()));
118        arg_types.push(bind_data_type(&arg.data_type)?);
119    }
120
121    // resolve database and schema id
122    let session = &handler_args.session;
123    let db_name = &session.database();
124    let (schema_name, function_name) =
125        Binder::resolve_schema_qualified_name(db_name, name.clone())?;
126    let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?;
127
128    // check if the function exists in the catalog
129    if let Either::Right(resp) = session.check_function_name_duplicated(
130        StatementType::CREATE_FUNCTION,
131        name,
132        &arg_types,
133        if_not_exists,
134    )? {
135        return Ok(resp);
136    }
137
138    let link = match &params.using {
139        Some(CreateFunctionUsing::Link(l)) => Some(l.as_str()),
140        _ => None,
141    };
142    let base64_decoded = match &params.using {
143        Some(CreateFunctionUsing::Base64(encoded)) => {
144            use base64::prelude::{BASE64_STANDARD, Engine};
145            let bytes = BASE64_STANDARD
146                .decode(encoded)
147                .context("invalid base64 encoding")?;
148            Some(bytes)
149        }
150        _ => None,
151    };
152
153    let create_fn =
154        risingwave_expr::sig::find_udf_impl(&language, runtime.as_deref(), link)?.create_fn;
155    let output = create_fn(CreateOptions {
156        kind: match kind {
157            Kind::Scalar(_) => UdfKind::Scalar,
158            Kind::Table(_) => UdfKind::Table,
159            Kind::Aggregate(_) => unreachable!(),
160        },
161        name: &function_name,
162        arg_names: &arg_names,
163        arg_types: &arg_types,
164        return_type: &return_type,
165        as_: params.as_.as_ref().map(|s| s.as_str()),
166        using_link: link,
167        using_base64_decoded: base64_decoded.as_deref(),
168    })?;
169
170    let function = PbFunction {
171        id: FunctionId::placeholder().0,
172        schema_id,
173        database_id,
174        name: function_name,
175        kind: Some(kind),
176        arg_names,
177        arg_types: arg_types.into_iter().map(|t| t.into()).collect(),
178        return_type: Some(return_type.into()),
179        language,
180        runtime,
181        name_in_runtime: Some(output.name_in_runtime),
182        link: link.map(|s| s.to_owned()),
183        body: output.body,
184        compressed_binary: output.compressed_binary,
185        owner: session.user_id(),
186        always_retry_on_network_error: with_options
187            .always_retry_on_network_error
188            .unwrap_or_default(),
189        is_async: with_options.r#async,
190        is_batched: with_options.batch,
191    };
192
193    let catalog_writer = session.catalog_writer()?;
194    catalog_writer.create_function(function).await?;
195
196    Ok(PgResponse::empty_result(StatementType::CREATE_FUNCTION))
197}