risingwave_frontend/handler/
create_sql_function.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use either::Either;
16use risingwave_common::catalog::FunctionId;
17use risingwave_common::types::StructType;
18use risingwave_pb::catalog::PbFunction;
19use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction};
20
21use super::*;
22use crate::expr::{Expr, Literal};
23use crate::{Binder, bind_data_type};
24
25pub async fn handle_create_sql_function(
26    handler_args: HandlerArgs,
27    or_replace: bool,
28    temporary: bool,
29    if_not_exists: bool,
30    name: ObjectName,
31    args: Option<Vec<OperateFunctionArg>>,
32    returns: Option<CreateFunctionReturns>,
33    params: CreateFunctionBody,
34) -> Result<RwPgResponse> {
35    if or_replace {
36        bail_not_implemented!("CREATE OR REPLACE FUNCTION");
37    }
38
39    if temporary {
40        bail_not_implemented!("CREATE TEMPORARY FUNCTION");
41    }
42
43    let language = "sql".to_owned();
44
45    // Just a basic sanity check for `language`
46    if !matches!(params.language, Some(lang) if lang.real_value().to_lowercase() == "sql") {
47        return Err(ErrorCode::InvalidParameterValue(
48            "`language` for sql udf must be `sql`".to_owned(),
49        )
50        .into());
51    }
52
53    // SQL udf function supports both single quote (i.e., as 'select $1 + $2')
54    // and double dollar (i.e., as $$select $1 + $2$$) for as clause
55    let body = match &params.as_ {
56        Some(FunctionDefinition::SingleQuotedDef(s)) => s.clone(),
57        Some(FunctionDefinition::DoubleDollarDef(s)) => s.clone(),
58        Some(FunctionDefinition::Identifier(_)) => {
59            return Err(ErrorCode::InvalidParameterValue("expect quoted string".to_owned()).into());
60        }
61        None => {
62            if params.return_.is_none() {
63                return Err(ErrorCode::InvalidParameterValue(
64                    "AS or RETURN must be specified".to_owned(),
65                )
66                .into());
67            }
68            // Otherwise this is a return expression
69            // Note: this is a current work around, and we are assuming return sql udf
70            // will NOT involve complex syntax, so just reuse the logic for select definition
71            format!("select {}", &params.return_.unwrap().to_string())
72        }
73    };
74
75    // Sanity check for link, this must be none with sql udf function
76    if let Some(CreateFunctionUsing::Link(_)) = params.using {
77        return Err(ErrorCode::InvalidParameterValue(
78            "USING must NOT be specified with sql udf function".to_owned(),
79        )
80        .into());
81    };
82
83    // Get return type for the current sql udf function
84    let return_type;
85    let kind = match returns {
86        Some(CreateFunctionReturns::Value(data_type)) => {
87            return_type = bind_data_type(&data_type)?;
88            Kind::Scalar(ScalarFunction {})
89        }
90        Some(CreateFunctionReturns::Table(columns)) => {
91            if columns.len() == 1 {
92                // return type is the original type for single column
93                return_type = bind_data_type(&columns[0].data_type)?;
94            } else {
95                // return type is a struct for multiple columns
96                let fields = columns
97                    .iter()
98                    .map(|c| Ok((c.name.real_value(), bind_data_type(&c.data_type)?)))
99                    .collect::<Result<Vec<_>>>()?;
100                return_type = StructType::new(fields).into();
101            }
102            Kind::Table(TableFunction {})
103        }
104        None => {
105            return Err(ErrorCode::InvalidParameterValue(
106                "return type must be specified".to_owned(),
107            )
108            .into());
109        }
110    };
111
112    let mut arg_names = vec![];
113    let mut arg_types = vec![];
114    for arg in args.unwrap_or_default() {
115        arg_names.push(arg.name.map_or("".to_owned(), |n| n.real_value()));
116        arg_types.push(bind_data_type(&arg.data_type)?);
117    }
118
119    // resolve database and schema id
120    let session = &handler_args.session;
121    let db_name = &session.database();
122    let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, &name)?;
123    let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?;
124
125    // check if function exists
126    if let Either::Right(resp) = session.check_function_name_duplicated(
127        StatementType::CREATE_FUNCTION,
128        name,
129        &arg_types,
130        if_not_exists,
131    )? {
132        return Ok(resp);
133    }
134
135    // Try bind the function call with mock arguments.
136    // Note that the parsing here is just basic syntax / semantic check, the result will NOT be stored
137    // e.g., The provided function body contains invalid syntax, return type mismatch, ..., etc.
138    {
139        let mut binder = Binder::new_for_system(session);
140        let args = arg_types
141            .iter()
142            .map(|ty| Literal::new(None, ty.clone()).into() /* NULL */)
143            .collect();
144
145        let expr = binder.bind_sql_udf_inner(&body, &arg_names, args)?;
146
147        // Check if the return type mismatches
148        if expr.return_type() != return_type {
149            return Err(ErrorCode::InvalidInputSyntax(format!(
150                "return type mismatch detected\nexpected: [{}]\nactual: [{}]\nplease adjust your function definition accordingly",
151                return_type,
152                expr.return_type()
153            ))
154            .into());
155        }
156    }
157
158    // Create the actual function, will be stored in function catalog
159    let function = PbFunction {
160        id: FunctionId::placeholder().0,
161        schema_id,
162        database_id,
163        name: function_name,
164        kind: Some(kind),
165        arg_names,
166        arg_types: arg_types.into_iter().map(|t| t.into()).collect(),
167        return_type: Some(return_type.into()),
168        language,
169        runtime: None,
170        name_in_runtime: None, // None for SQL UDF
171        body: Some(body),
172        compressed_binary: None,
173        link: None,
174        owner: session.user_id(),
175        always_retry_on_network_error: false,
176        is_async: None,
177        is_batched: None,
178        created_at_epoch: None,
179        created_at_cluster_version: None,
180    };
181
182    let catalog_writer = session.catalog_writer()?;
183    catalog_writer.create_function(function).await?;
184
185    Ok(PgResponse::empty_result(StatementType::CREATE_FUNCTION))
186}