risingwave_frontend/handler/
create_function.rsuse anyhow::Context;
use risingwave_common::catalog::FunctionId;
use risingwave_common::types::DataType;
use risingwave_expr::sig::{CreateFunctionOptions, UdfKind};
use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction};
use risingwave_pb::catalog::Function;
use super::*;
use crate::catalog::CatalogError;
use crate::{bind_data_type, Binder};
pub async fn handle_create_function(
handler_args: HandlerArgs,
or_replace: bool,
temporary: bool,
name: ObjectName,
args: Option<Vec<OperateFunctionArg>>,
returns: Option<CreateFunctionReturns>,
params: CreateFunctionBody,
with_options: CreateFunctionWithOptions,
) -> Result<RwPgResponse> {
if or_replace {
bail_not_implemented!("CREATE OR REPLACE FUNCTION");
}
if temporary {
bail_not_implemented!("CREATE TEMPORARY FUNCTION");
}
let language = match params.language {
Some(lang) => {
let lang = lang.real_value().to_lowercase();
match &*lang {
"python" | "java" | "wasm" | "rust" | "javascript" => lang,
_ => {
return Err(ErrorCode::InvalidParameterValue(format!(
"language {} is not supported",
lang
))
.into())
}
}
}
None => "".to_string(),
};
let runtime = match params.runtime {
Some(_) => {
return Err(ErrorCode::InvalidParameterValue(
"runtime selection is currently not supported".to_string(),
)
.into());
}
None => None,
};
let return_type;
let kind = match returns {
Some(CreateFunctionReturns::Value(data_type)) => {
return_type = bind_data_type(&data_type)?;
Kind::Scalar(ScalarFunction {})
}
Some(CreateFunctionReturns::Table(columns)) => {
if columns.len() == 1 {
return_type = bind_data_type(&columns[0].data_type)?;
} else {
let it = columns
.into_iter()
.map(|c| bind_data_type(&c.data_type).map(|ty| (ty, c.name.real_value())));
let (datatypes, names) = itertools::process_results(it, |it| it.unzip())?;
return_type = DataType::new_struct(datatypes, names);
}
Kind::Table(TableFunction {})
}
None => {
return Err(ErrorCode::InvalidParameterValue(
"return type must be specified".to_string(),
)
.into())
}
};
let mut arg_names = vec![];
let mut arg_types = vec![];
for arg in args.unwrap_or_default() {
arg_names.push(arg.name.map_or("".to_string(), |n| n.real_value()));
arg_types.push(bind_data_type(&arg.data_type)?);
}
let session = &handler_args.session;
let db_name = session.database();
let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, name)?;
let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?;
if (session.env().catalog_reader().read_guard())
.get_schema_by_id(&database_id, &schema_id)?
.get_function_by_name_args(&function_name, &arg_types)
.is_some()
{
let name = format!(
"{function_name}({})",
arg_types.iter().map(|t| t.to_string()).join(",")
);
return Err(CatalogError::Duplicated("function", name).into());
}
let link = match ¶ms.using {
Some(CreateFunctionUsing::Link(l)) => Some(l.as_str()),
_ => None,
};
let base64_decoded = match ¶ms.using {
Some(CreateFunctionUsing::Base64(encoded)) => {
use base64::prelude::{Engine, BASE64_STANDARD};
let bytes = BASE64_STANDARD
.decode(encoded)
.context("invalid base64 encoding")?;
Some(bytes)
}
_ => None,
};
let create_fn =
risingwave_expr::sig::find_udf_impl(&language, runtime.as_deref(), link)?.create_fn;
let output = create_fn(CreateFunctionOptions {
kind: match kind {
Kind::Scalar(_) => UdfKind::Scalar,
Kind::Table(_) => UdfKind::Table,
Kind::Aggregate(_) => unreachable!(),
},
name: &function_name,
arg_names: &arg_names,
arg_types: &arg_types,
return_type: &return_type,
as_: params.as_.as_ref().map(|s| s.as_str()),
using_link: link,
using_base64_decoded: base64_decoded.as_deref(),
})?;
let function = Function {
id: FunctionId::placeholder().0,
schema_id,
database_id,
name: function_name,
kind: Some(kind),
arg_names,
arg_types: arg_types.into_iter().map(|t| t.into()).collect(),
return_type: Some(return_type.into()),
language,
runtime,
identifier: Some(output.identifier),
link: link.map(|s| s.to_string()),
body: output.body,
compressed_binary: output.compressed_binary,
owner: session.user_id(),
always_retry_on_network_error: with_options
.always_retry_on_network_error
.unwrap_or_default(),
};
let catalog_writer = session.catalog_writer()?;
catalog_writer.create_function(function).await?;
Ok(PgResponse::empty_result(StatementType::CREATE_FUNCTION))
}