risingwave_frontend/handler/
create_aggregate.rs1use anyhow::Context;
16use either::Either;
17use risingwave_common::catalog::FunctionId;
18use risingwave_expr::sig::{CreateOptions, UdfKind};
19use risingwave_pb::catalog::Function;
20use risingwave_pb::catalog::function::{AggregateFunction, Kind};
21use risingwave_sqlparser::ast::DataType as AstDataType;
22
23use super::*;
24use crate::{Binder, bind_data_type};
25
26pub async fn handle_create_aggregate(
27 handler_args: HandlerArgs,
28 or_replace: bool,
29 if_not_exists: bool,
30 name: ObjectName,
31 args: Vec<OperateFunctionArg>,
32 returns: AstDataType,
33 params: CreateFunctionBody,
34) -> Result<RwPgResponse> {
35 if or_replace {
36 bail_not_implemented!("CREATE OR REPLACE AGGREGATE");
37 }
38
39 let udf_config = handler_args.session.env().udf_config();
40
41 let language = match params.language {
43 Some(lang) => {
44 let lang = lang.real_value().to_lowercase();
45 match &*lang {
46 "python" if udf_config.enable_embedded_python_udf => lang,
47 "javascript" if udf_config.enable_embedded_javascript_udf => lang,
48 "python" | "javascript" => {
49 return Err(ErrorCode::InvalidParameterValue(format!(
50 "{} UDF is not enabled in configuration",
51 lang
52 ))
53 .into());
54 }
55 _ => {
56 return Err(ErrorCode::InvalidParameterValue(format!(
57 "language {} is not supported",
58 lang
59 ))
60 .into());
61 }
62 }
63 }
64 None => return Err(ErrorCode::InvalidParameterValue("no language".into()).into()),
65 };
66
67 let runtime = match params.runtime {
68 Some(_) => {
69 return Err(ErrorCode::InvalidParameterValue(
70 "runtime selection is currently not supported".to_owned(),
71 )
72 .into());
73 }
74 None => None,
75 };
76
77 let return_type = bind_data_type(&returns)?;
78
79 let mut arg_names = vec![];
80 let mut arg_types = vec![];
81 for arg in args {
82 arg_names.push(arg.name.map_or("".to_owned(), |n| n.real_value()));
83 arg_types.push(bind_data_type(&arg.data_type)?);
84 }
85
86 let session = &handler_args.session;
88 let db_name = &session.database();
89 let (schema_name, function_name) =
90 Binder::resolve_schema_qualified_name(db_name, name.clone())?;
91 let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?;
92
93 if let Either::Right(resp) = session.check_function_name_duplicated(
95 StatementType::CREATE_FUNCTION,
96 name,
97 &arg_types,
98 if_not_exists,
99 )? {
100 return Ok(resp);
101 }
102
103 let link = match ¶ms.using {
104 Some(CreateFunctionUsing::Link(l)) => Some(l.as_str()),
105 _ => None,
106 };
107 let base64_decoded = match ¶ms.using {
108 Some(CreateFunctionUsing::Base64(encoded)) => {
109 use base64::prelude::{BASE64_STANDARD, Engine};
110 let bytes = BASE64_STANDARD
111 .decode(encoded)
112 .context("invalid base64 encoding")?;
113 Some(bytes)
114 }
115 _ => None,
116 };
117
118 let create_fn = risingwave_expr::sig::find_udf_impl(&language, None, link)?.create_fn;
119 let output = create_fn(CreateOptions {
120 kind: UdfKind::Aggregate,
121 name: &function_name,
122 arg_names: &arg_names,
123 arg_types: &arg_types,
124 return_type: &return_type,
125 as_: params.as_.as_ref().map(|s| s.as_str()),
126 using_link: link,
127 using_base64_decoded: base64_decoded.as_deref(),
128 })?;
129
130 let function = Function {
131 id: FunctionId::placeholder().0,
132 schema_id,
133 database_id,
134 name: function_name,
135 kind: Some(Kind::Aggregate(AggregateFunction {})),
136 arg_names,
137 arg_types: arg_types.into_iter().map(|t| t.into()).collect(),
138 return_type: Some(return_type.into()),
139 language,
140 runtime,
141 name_in_runtime: Some(output.name_in_runtime),
142 link: link.map(|s| s.to_owned()),
143 body: output.body,
144 compressed_binary: output.compressed_binary,
145 owner: session.user_id(),
146 always_retry_on_network_error: false,
147 is_async: None,
148 is_batched: None,
149 };
150
151 let catalog_writer = session.catalog_writer()?;
152 catalog_writer.create_function(function).await?;
153
154 Ok(PgResponse::empty_result(StatementType::CREATE_AGGREGATE))
155}