risingwave_frontend/handler/
create_aggregate.rs1use anyhow::Context;
16use either::Either;
17use risingwave_common::catalog::FunctionId;
18use risingwave_expr::sig::{CreateOptions, UdfKind};
19use risingwave_pb::catalog::Function;
20use risingwave_pb::catalog::function::{AggregateFunction, Kind};
21use risingwave_sqlparser::ast::DataType as AstDataType;
22
23use super::*;
24use crate::{Binder, bind_data_type};
25
26pub async fn handle_create_aggregate(
27 handler_args: HandlerArgs,
28 or_replace: bool,
29 if_not_exists: bool,
30 name: ObjectName,
31 args: Vec<OperateFunctionArg>,
32 returns: AstDataType,
33 params: CreateFunctionBody,
34) -> Result<RwPgResponse> {
35 if or_replace {
36 bail_not_implemented!("CREATE OR REPLACE AGGREGATE");
37 }
38
39 let udf_config = handler_args.session.env().udf_config();
40
41 let language = match params.language {
43 Some(lang) => {
44 let lang = lang.real_value().to_lowercase();
45 match &*lang {
46 "python" if udf_config.enable_embedded_python_udf => lang,
47 "javascript" if udf_config.enable_embedded_javascript_udf => lang,
48 "python" | "javascript" => {
49 return Err(ErrorCode::InvalidParameterValue(format!(
50 "{} UDF is not enabled in configuration",
51 lang
52 ))
53 .into());
54 }
55 _ => {
56 return Err(ErrorCode::InvalidParameterValue(format!(
57 "language {} is not supported",
58 lang
59 ))
60 .into());
61 }
62 }
63 }
64 None => return Err(ErrorCode::InvalidParameterValue("no language".into()).into()),
65 };
66
67 let runtime = match params.runtime {
68 Some(_) => {
69 return Err(ErrorCode::InvalidParameterValue(
70 "runtime selection is currently not supported".to_owned(),
71 )
72 .into());
73 }
74 None => None,
75 };
76
77 let return_type = bind_data_type(&returns)?;
78
79 let mut arg_names = vec![];
80 let mut arg_types = vec![];
81 for arg in args {
82 arg_names.push(arg.name.map_or("".to_owned(), |n| n.real_value()));
83 arg_types.push(bind_data_type(&arg.data_type)?);
84 }
85
86 let session = &handler_args.session;
88 let db_name = &session.database();
89 let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, &name)?;
90 let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?;
91
92 if let Either::Right(resp) = session.check_function_name_duplicated(
94 StatementType::CREATE_FUNCTION,
95 name,
96 &arg_types,
97 if_not_exists,
98 )? {
99 return Ok(resp);
100 }
101
102 let link = match ¶ms.using {
103 Some(CreateFunctionUsing::Link(l)) => Some(l.as_str()),
104 _ => None,
105 };
106 let base64_decoded = match ¶ms.using {
107 Some(CreateFunctionUsing::Base64(encoded)) => {
108 use base64::prelude::{BASE64_STANDARD, Engine};
109 let bytes = BASE64_STANDARD
110 .decode(encoded)
111 .context("invalid base64 encoding")?;
112 Some(bytes)
113 }
114 _ => None,
115 };
116
117 let create_fn = risingwave_expr::sig::find_udf_impl(&language, None, link)?.create_fn;
118 let output = create_fn(CreateOptions {
119 kind: UdfKind::Aggregate,
120 name: &function_name,
121 arg_names: &arg_names,
122 arg_types: &arg_types,
123 return_type: &return_type,
124 as_: params.as_.as_ref().map(|s| s.as_str()),
125 using_link: link,
126 using_base64_decoded: base64_decoded.as_deref(),
127 })?;
128
129 let function = Function {
130 id: FunctionId::placeholder().0,
131 schema_id,
132 database_id,
133 name: function_name,
134 kind: Some(Kind::Aggregate(AggregateFunction {})),
135 arg_names,
136 arg_types: arg_types.into_iter().map(|t| t.into()).collect(),
137 return_type: Some(return_type.into()),
138 language,
139 runtime,
140 name_in_runtime: Some(output.name_in_runtime),
141 link: link.map(|s| s.to_owned()),
142 body: output.body,
143 compressed_binary: output.compressed_binary,
144 owner: session.user_id(),
145 always_retry_on_network_error: false,
146 is_async: None,
147 is_batched: None,
148 created_at_epoch: None,
149 created_at_cluster_version: None,
150 };
151
152 let catalog_writer = session.catalog_writer()?;
153 catalog_writer.create_function(function).await?;
154
155 Ok(PgResponse::empty_result(StatementType::CREATE_AGGREGATE))
156}