risingwave_frontend/handler/
create_aggregate.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use anyhow::Context;
16use either::Either;
17use risingwave_common::catalog::FunctionId;
18use risingwave_expr::sig::{CreateOptions, UdfKind};
19use risingwave_pb::catalog::Function;
20use risingwave_pb::catalog::function::{AggregateFunction, Kind};
21use risingwave_sqlparser::ast::DataType as AstDataType;
22
23use super::*;
24use crate::{Binder, bind_data_type};
25
26pub async fn handle_create_aggregate(
27    handler_args: HandlerArgs,
28    or_replace: bool,
29    if_not_exists: bool,
30    name: ObjectName,
31    args: Vec<OperateFunctionArg>,
32    returns: AstDataType,
33    params: CreateFunctionBody,
34) -> Result<RwPgResponse> {
35    if or_replace {
36        bail_not_implemented!("CREATE OR REPLACE AGGREGATE");
37    }
38
39    let udf_config = handler_args.session.env().udf_config();
40
41    // e.g., `language [ python / java / ...etc]`
42    let language = match params.language {
43        Some(lang) => {
44            let lang = lang.real_value().to_lowercase();
45            match &*lang {
46                "python" if udf_config.enable_embedded_python_udf => lang,
47                "javascript" if udf_config.enable_embedded_javascript_udf => lang,
48                "python" | "javascript" => {
49                    return Err(ErrorCode::InvalidParameterValue(format!(
50                        "{} UDF is not enabled in configuration",
51                        lang
52                    ))
53                    .into());
54                }
55                _ => {
56                    return Err(ErrorCode::InvalidParameterValue(format!(
57                        "language {} is not supported",
58                        lang
59                    ))
60                    .into());
61                }
62            }
63        }
64        None => return Err(ErrorCode::InvalidParameterValue("no language".into()).into()),
65    };
66
67    let runtime = match params.runtime {
68        Some(_) => {
69            return Err(ErrorCode::InvalidParameterValue(
70                "runtime selection is currently not supported".to_owned(),
71            )
72            .into());
73        }
74        None => None,
75    };
76
77    let return_type = bind_data_type(&returns)?;
78
79    let mut arg_names = vec![];
80    let mut arg_types = vec![];
81    for arg in args {
82        arg_names.push(arg.name.map_or("".to_owned(), |n| n.real_value()));
83        arg_types.push(bind_data_type(&arg.data_type)?);
84    }
85
86    // resolve database and schema id
87    let session = &handler_args.session;
88    let db_name = &session.database();
89    let (schema_name, function_name) =
90        Binder::resolve_schema_qualified_name(db_name, name.clone())?;
91    let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?;
92
93    // check if the function exists in the catalog
94    if let Either::Right(resp) = session.check_function_name_duplicated(
95        StatementType::CREATE_FUNCTION,
96        name,
97        &arg_types,
98        if_not_exists,
99    )? {
100        return Ok(resp);
101    }
102
103    let link = match &params.using {
104        Some(CreateFunctionUsing::Link(l)) => Some(l.as_str()),
105        _ => None,
106    };
107    let base64_decoded = match &params.using {
108        Some(CreateFunctionUsing::Base64(encoded)) => {
109            use base64::prelude::{BASE64_STANDARD, Engine};
110            let bytes = BASE64_STANDARD
111                .decode(encoded)
112                .context("invalid base64 encoding")?;
113            Some(bytes)
114        }
115        _ => None,
116    };
117
118    let create_fn = risingwave_expr::sig::find_udf_impl(&language, None, link)?.create_fn;
119    let output = create_fn(CreateOptions {
120        kind: UdfKind::Aggregate,
121        name: &function_name,
122        arg_names: &arg_names,
123        arg_types: &arg_types,
124        return_type: &return_type,
125        as_: params.as_.as_ref().map(|s| s.as_str()),
126        using_link: link,
127        using_base64_decoded: base64_decoded.as_deref(),
128    })?;
129
130    let function = Function {
131        id: FunctionId::placeholder().0,
132        schema_id,
133        database_id,
134        name: function_name,
135        kind: Some(Kind::Aggregate(AggregateFunction {})),
136        arg_names,
137        arg_types: arg_types.into_iter().map(|t| t.into()).collect(),
138        return_type: Some(return_type.into()),
139        language,
140        runtime,
141        name_in_runtime: Some(output.name_in_runtime),
142        link: link.map(|s| s.to_owned()),
143        body: output.body,
144        compressed_binary: output.compressed_binary,
145        owner: session.user_id(),
146        always_retry_on_network_error: false,
147        is_async: None,
148        is_batched: None,
149    };
150
151    let catalog_writer = session.catalog_writer()?;
152    catalog_writer.create_function(function).await?;
153
154    Ok(PgResponse::empty_result(StatementType::CREATE_AGGREGATE))
155}