risingwave_frontend/handler/
create_aggregate.rsuse anyhow::Context;
use either::Either;
use risingwave_common::catalog::FunctionId;
use risingwave_expr::sig::{CreateFunctionOptions, UdfKind};
use risingwave_pb::catalog::function::{AggregateFunction, Kind};
use risingwave_pb::catalog::Function;
use risingwave_sqlparser::ast::DataType as AstDataType;
use super::*;
use crate::{bind_data_type, Binder};
pub async fn handle_create_aggregate(
handler_args: HandlerArgs,
or_replace: bool,
if_not_exists: bool,
name: ObjectName,
args: Vec<OperateFunctionArg>,
returns: AstDataType,
params: CreateFunctionBody,
) -> Result<RwPgResponse> {
if or_replace {
bail_not_implemented!("CREATE OR REPLACE AGGREGATE");
}
let language = match params.language {
Some(lang) => {
let lang = lang.real_value().to_lowercase();
match &*lang {
"python" | "javascript" => lang,
_ => {
return Err(ErrorCode::InvalidParameterValue(format!(
"language {} is not supported",
lang
))
.into())
}
}
}
None => return Err(ErrorCode::InvalidParameterValue("no language".into()).into()),
};
let runtime = match params.runtime {
Some(_) => {
return Err(ErrorCode::InvalidParameterValue(
"runtime selection is currently not supported".to_owned(),
)
.into());
}
None => None,
};
let return_type = bind_data_type(&returns)?;
let mut arg_names = vec![];
let mut arg_types = vec![];
for arg in args {
arg_names.push(arg.name.map_or("".to_owned(), |n| n.real_value()));
arg_types.push(bind_data_type(&arg.data_type)?);
}
let session = &handler_args.session;
let db_name = &session.database();
let (schema_name, function_name) =
Binder::resolve_schema_qualified_name(db_name, name.clone())?;
let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?;
if let Either::Right(resp) = session.check_function_name_duplicated(
StatementType::CREATE_FUNCTION,
name,
&arg_types,
if_not_exists,
)? {
return Ok(resp);
}
let link = match ¶ms.using {
Some(CreateFunctionUsing::Link(l)) => Some(l.as_str()),
_ => None,
};
let base64_decoded = match ¶ms.using {
Some(CreateFunctionUsing::Base64(encoded)) => {
use base64::prelude::{Engine, BASE64_STANDARD};
let bytes = BASE64_STANDARD
.decode(encoded)
.context("invalid base64 encoding")?;
Some(bytes)
}
_ => None,
};
let create_fn = risingwave_expr::sig::find_udf_impl(&language, None, link)?.create_fn;
let output = create_fn(CreateFunctionOptions {
kind: UdfKind::Aggregate,
name: &function_name,
arg_names: &arg_names,
arg_types: &arg_types,
return_type: &return_type,
as_: params.as_.as_ref().map(|s| s.as_str()),
using_link: link,
using_base64_decoded: base64_decoded.as_deref(),
})?;
let function = Function {
id: FunctionId::placeholder().0,
schema_id,
database_id,
name: function_name,
kind: Some(Kind::Aggregate(AggregateFunction {})),
arg_names,
arg_types: arg_types.into_iter().map(|t| t.into()).collect(),
return_type: Some(return_type.into()),
language,
runtime,
identifier: Some(output.identifier),
link: link.map(|s| s.to_owned()),
body: output.body,
compressed_binary: output.compressed_binary,
owner: session.user_id(),
always_retry_on_network_error: false,
};
let catalog_writer = session.catalog_writer()?;
catalog_writer.create_function(function).await?;
Ok(PgResponse::empty_result(StatementType::CREATE_AGGREGATE))
}