// risingwave_frontend/handler/create_aggregate.rs
use anyhow::Context;
use risingwave_common::catalog::FunctionId;
use risingwave_expr::sig::{CreateFunctionOptions, UdfKind};
use risingwave_pb::catalog::function::{AggregateFunction, Kind};
use risingwave_pb::catalog::Function;
use risingwave_sqlparser::ast::DataType as AstDataType;
use super::*;
use crate::catalog::CatalogError;
use crate::{bind_data_type, Binder};
/// Handles `CREATE AGGREGATE`, registering a user-defined aggregate function in the catalog.
///
/// Flow: validate the statement options, bind the argument and return types, reject an
/// existing function with the same name and argument types in the target schema, build the
/// function through the UDF implementation selected by `LANGUAGE` (and the `USING` clause, if
/// any), and finally send the resulting [`Function`] to the catalog writer.
///
/// # Errors
///
/// - `CREATE OR REPLACE AGGREGATE` is rejected as not implemented.
/// - `InvalidParameterValue` if `LANGUAGE` is missing or is not `python` / `javascript`
///   (compared case-insensitively), or if a `RUNTIME` is specified.
/// - `CatalogError::Duplicated` if a function with the same name and argument types already
///   exists in the target schema.
/// - An error if the `USING BASE64` payload is not valid base64, or if type binding, schema
///   resolution, UDF creation, or the catalog write fails.
pub async fn handle_create_aggregate(
    handler_args: HandlerArgs,
    or_replace: bool,
    name: ObjectName,
    args: Vec<OperateFunctionArg>,
    returns: AstDataType,
    params: CreateFunctionBody,
) -> Result<RwPgResponse> {
    if or_replace {
        bail_not_implemented!("CREATE OR REPLACE AGGREGATE");
    }

    // The language is mandatory; it is lowercased before both validation and storage.
    let Some(lang) = params.language else {
        return Err(ErrorCode::InvalidParameterValue("no language".into()).into());
    };
    let language = lang.real_value().to_lowercase();
    if !matches!(language.as_str(), "python" | "javascript") {
        return Err(ErrorCode::InvalidParameterValue(format!(
            "language {} is not supported",
            language
        ))
        .into());
    }

    // An explicit runtime is rejected outright; the stored function always has `runtime: None`.
    if params.runtime.is_some() {
        return Err(ErrorCode::InvalidParameterValue(
            "runtime selection is currently not supported".to_string(),
        )
        .into());
    }

    let return_type = bind_data_type(&returns)?;
    let mut arg_names = Vec::with_capacity(args.len());
    let mut arg_types = Vec::with_capacity(args.len());
    for arg in args {
        // Unnamed arguments are recorded with an empty name.
        arg_names.push(arg.name.map_or_else(String::new, |n| n.real_value()));
        arg_types.push(bind_data_type(&arg.data_type)?);
    }

    let session = &handler_args.session;
    let db_name = session.database();
    let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, name)?;
    let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?;

    // Reject an existing function with the same name and argument types before building the UDF.
    if (session.env().catalog_reader().read_guard())
        .get_schema_by_id(&database_id, &schema_id)?
        .get_function_by_name_args(&function_name, &arg_types)
        .is_some()
    {
        let name = format!(
            "{function_name}({})",
            arg_types.iter().map(|t| t.to_string()).join(",")
        );
        return Err(CatalogError::Duplicated("function", name).into());
    }

    // Resolve the `USING` clause: either a link to an external implementation or an inline
    // base64-encoded payload, which is decoded here.
    let (link, base64_decoded) = match &params.using {
        Some(CreateFunctionUsing::Link(l)) => (Some(l.as_str()), None),
        Some(CreateFunctionUsing::Base64(encoded)) => {
            use base64::prelude::{Engine, BASE64_STANDARD};
            let bytes = BASE64_STANDARD
                .decode(encoded)
                .context("invalid base64 encoding")?;
            (None, Some(bytes))
        }
        _ => (None, None),
    };

    let create_fn = risingwave_expr::sig::find_udf_impl(&language, None, link)?.create_fn;
    let output = create_fn(CreateFunctionOptions {
        kind: UdfKind::Aggregate,
        name: &function_name,
        arg_names: &arg_names,
        arg_types: &arg_types,
        return_type: &return_type,
        as_: params.as_.as_ref().map(|s| s.as_str()),
        using_link: link,
        using_base64_decoded: base64_decoded.as_deref(),
    })?;

    let function = Function {
        // The real id is presumably assigned by the catalog service — confirm against the writer.
        id: FunctionId::placeholder().0,
        schema_id,
        database_id,
        name: function_name,
        kind: Some(Kind::Aggregate(AggregateFunction {})),
        arg_names,
        arg_types: arg_types.into_iter().map(|t| t.into()).collect(),
        return_type: Some(return_type.into()),
        language,
        runtime: None,
        identifier: Some(output.identifier),
        link: link.map(|s| s.to_string()),
        body: output.body,
        compressed_binary: output.compressed_binary,
        owner: session.user_id(),
        always_retry_on_network_error: false,
    };

    let catalog_writer = session.catalog_writer()?;
    catalog_writer.create_function(function).await?;

    Ok(PgResponse::empty_result(StatementType::CREATE_AGGREGATE))
}