risingwave_frontend/handler/
create_function.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::Context;
16use either::Either;
17use risingwave_common::catalog::FunctionId;
18use risingwave_common::types::StructType;
19use risingwave_expr::sig::{CreateOptions, UdfKind};
20use risingwave_pb::catalog::PbFunction;
21use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction};
22
23use super::*;
24use crate::{Binder, bind_data_type};
25
26pub async fn handle_create_function(
27    handler_args: HandlerArgs,
28    or_replace: bool,
29    temporary: bool,
30    if_not_exists: bool,
31    name: ObjectName,
32    args: Option<Vec<OperateFunctionArg>>,
33    returns: Option<CreateFunctionReturns>,
34    params: CreateFunctionBody,
35    with_options: CreateFunctionWithOptions,
36) -> Result<RwPgResponse> {
37    if or_replace {
38        bail_not_implemented!("CREATE OR REPLACE FUNCTION");
39    }
40    if temporary {
41        bail_not_implemented!("CREATE TEMPORARY FUNCTION");
42    }
43
44    let udf_config = handler_args.session.env().udf_config();
45
46    // e.g., `language [ python / javascript / ...etc]`
47    let language = match params.language {
48        Some(lang) => {
49            let lang = lang.real_value().to_lowercase();
50            match &*lang {
51                "java" => lang, // only support external UDF for Java
52                "python" if udf_config.enable_embedded_python_udf => lang,
53                "javascript" if udf_config.enable_embedded_javascript_udf => lang,
54                "rust" | "wasm" if udf_config.enable_embedded_wasm_udf => lang,
55                "python" | "javascript" | "rust" | "wasm" => {
56                    return Err(ErrorCode::InvalidParameterValue(format!(
57                        "{} UDF is not enabled in configuration",
58                        lang
59                    ))
60                    .into());
61                }
62                _ => {
63                    return Err(ErrorCode::InvalidParameterValue(format!(
64                        "language {} is not supported",
65                        lang
66                    ))
67                    .into());
68                }
69            }
70        }
71        // Empty language is acceptable since we only require the external server implements the
72        // correct protocol.
73        None => "".to_owned(),
74    };
75
76    let runtime = match params.runtime {
77        Some(_) => {
78            return Err(ErrorCode::InvalidParameterValue(
79                "runtime selection is currently not supported".to_owned(),
80            )
81            .into());
82        }
83        None => None,
84    };
85
86    let return_type;
87    let kind = match returns {
88        Some(CreateFunctionReturns::Value(data_type)) => {
89            return_type = bind_data_type(&data_type)?;
90            Kind::Scalar(ScalarFunction {})
91        }
92        Some(CreateFunctionReturns::Table(columns)) => {
93            if columns.len() == 1 {
94                // return type is the original type for single column
95                return_type = bind_data_type(&columns[0].data_type)?;
96            } else {
97                // return type is a struct for multiple columns
98                let it = columns
99                    .into_iter()
100                    .map(|c| bind_data_type(&c.data_type).map(|ty| (c.name.real_value(), ty)));
101                let fields = it.try_collect::<_, Vec<_>, _>()?;
102                return_type = StructType::new(fields).into();
103            }
104            Kind::Table(TableFunction {})
105        }
106        None => {
107            return Err(ErrorCode::InvalidParameterValue(
108                "return type must be specified".to_owned(),
109            )
110            .into());
111        }
112    };
113
114    let mut arg_names = vec![];
115    let mut arg_types = vec![];
116    for arg in args.unwrap_or_default() {
117        arg_names.push(arg.name.map_or("".to_owned(), |n| n.real_value()));
118        arg_types.push(bind_data_type(&arg.data_type)?);
119    }
120
121    // resolve database and schema id
122    let session = &handler_args.session;
123    let db_name = &session.database();
124    let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, &name)?;
125    let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?;
126
127    // check if the function exists in the catalog
128    if let Either::Right(resp) = session.check_function_name_duplicated(
129        StatementType::CREATE_FUNCTION,
130        name,
131        &arg_types,
132        if_not_exists,
133    )? {
134        return Ok(resp);
135    }
136
137    let link = match &params.using {
138        Some(CreateFunctionUsing::Link(l)) => Some(l.as_str()),
139        _ => None,
140    };
141    let base64_decoded = match &params.using {
142        Some(CreateFunctionUsing::Base64(encoded)) => {
143            use base64::prelude::{BASE64_STANDARD, Engine};
144            let bytes = BASE64_STANDARD
145                .decode(encoded)
146                .context("invalid base64 encoding")?;
147            Some(bytes)
148        }
149        _ => None,
150    };
151
152    let create_fn =
153        risingwave_expr::sig::find_udf_impl(&language, runtime.as_deref(), link)?.create_fn;
154    let output = create_fn(CreateOptions {
155        kind: match kind {
156            Kind::Scalar(_) => UdfKind::Scalar,
157            Kind::Table(_) => UdfKind::Table,
158            Kind::Aggregate(_) => unreachable!(),
159        },
160        name: &function_name,
161        arg_names: &arg_names,
162        arg_types: &arg_types,
163        return_type: &return_type,
164        as_: params.as_.as_ref().map(|s| s.as_str()),
165        using_link: link,
166        using_base64_decoded: base64_decoded.as_deref(),
167    })?;
168
169    let function = PbFunction {
170        id: FunctionId::placeholder().0,
171        schema_id,
172        database_id,
173        name: function_name,
174        kind: Some(kind),
175        arg_names,
176        arg_types: arg_types.into_iter().map(|t| t.into()).collect(),
177        return_type: Some(return_type.into()),
178        language,
179        runtime,
180        name_in_runtime: Some(output.name_in_runtime),
181        link: link.map(|s| s.to_owned()),
182        body: output.body,
183        compressed_binary: output.compressed_binary,
184        owner: session.user_id(),
185        always_retry_on_network_error: with_options
186            .always_retry_on_network_error
187            .unwrap_or_default(),
188        is_async: with_options.r#async,
189        is_batched: with_options.batch,
190        created_at_epoch: None,
191        created_at_cluster_version: None,
192    };
193
194    let catalog_writer = session.catalog_writer()?;
195    catalog_writer.create_function(function).await?;
196
197    Ok(PgResponse::empty_result(StatementType::CREATE_FUNCTION))
198}