risingwave_frontend/handler/
create_sql_function.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use either::Either;
18use fancy_regex::Regex;
19use risingwave_common::catalog::FunctionId;
20use risingwave_common::types::{DataType, StructType};
21use risingwave_pb::catalog::PbFunction;
22use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction};
23use risingwave_sqlparser::parser::{Parser, ParserError};
24
25use super::*;
26use crate::binder::UdfContext;
27use crate::expr::{Expr, ExprImpl, Literal};
28use crate::{Binder, bind_data_type};
29
/// Category of a bind-error message, used to decide how a hint can be rendered.
///
/// Invalid parameters are checked first, then non-existent functions; any other
/// message falls into `ErrMsgType::None`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ErrMsgType {
    /// The message refers to a SQL UDF parameter that could not be resolved.
    Parameter,
    /// The message refers to a function that could not be found.
    Function,
    /// No improved hint is available for this message; the default error is used.
    None,
}
39
// ---- Error-message classification ----

/// Marker substring identifying bind errors about SQL UDF parameters
/// (e.g. a named parameter that cannot be found).
pub const SQL_UDF_PATTERN: &str = "[sql udf]";

/// Substring identifying bind errors about a function that cannot be resolved.
const FUNCTION_KEYWORD: &str = "function";

// ---- Error-message rendering ----

/// Generic message returned when a SQL UDF body fails its semantic check and no
/// more specific hint can be produced.
const DEFAULT_ERR_MSG: &str = "Failed to conduct semantic check";

/// Prefix printed before the echoed UDF body in hint messages.
const PROMPT: &str = "In SQL UDF definition: ";
50
51/// Validate the error message to see if
52/// it's possible to improve the display to users
53fn validate_err_msg(invalid_msg: &str) -> ErrMsgType {
54    // First try invalid parameters
55    if invalid_msg.contains(SQL_UDF_PATTERN) {
56        ErrMsgType::Parameter
57    } else if invalid_msg.contains(FUNCTION_KEYWORD) {
58        ErrMsgType::Function
59    } else {
60        // Nothing could be better display
61        ErrMsgType::None
62    }
63}
64
65/// Extract the target name to hint display
66/// according to the type of the error message item
67fn extract_hint_display_target(err_msg_type: ErrMsgType, invalid_msg: &str) -> Option<&str> {
68    match err_msg_type {
69        // e.g., [sql udf] failed to find named parameter <target name>
70        ErrMsgType::Parameter => invalid_msg.split_whitespace().last(),
71        // e.g., function <target name> does not exist
72        ErrMsgType::Function => {
73            let func = invalid_msg.split_whitespace().nth(1).unwrap_or("null");
74            // Note: we do not want the parenthesis
75            func.find('(').map(|i| &func[0..i])
76        }
77        // Nothing to hint display, return default error message
78        ErrMsgType::None => None,
79    }
80}
81
82/// Find the pattern for better hint display
83/// return the exact index where the pattern first appears
84fn find_target(input: &str, target: &str) -> Option<usize> {
85    // Regex pattern to find `target` not preceded or followed by an ASCII letter
86    // The pattern uses negative lookbehind (?<!...) and lookahead (?!...) to ensure
87    // the target is not surrounded by ASCII alphabetic characters
88    let pattern = format!(r"(?<![A-Za-z]){0}(?![A-Za-z])", fancy_regex::escape(target));
89    let Ok(re) = Regex::new(&pattern) else {
90        return None;
91    };
92
93    let Ok(Some(ma)) = re.find(input) else {
94        return None;
95    };
96
97    Some(ma.start())
98}
99
100/// Create a mock `udf_context`, which is used for semantic check
101fn create_mock_udf_context(
102    arg_types: Vec<DataType>,
103    arg_names: Vec<String>,
104) -> HashMap<String, ExprImpl> {
105    let mut ret: HashMap<String, ExprImpl> = (1..=arg_types.len())
106        .map(|i| {
107            let mock_expr =
108                ExprImpl::Literal(Box::new(Literal::new(None, arg_types[i - 1].clone())));
109            (format!("${i}"), mock_expr)
110        })
111        .collect();
112
113    for (i, arg_name) in arg_names.into_iter().enumerate() {
114        let mock_expr = ExprImpl::Literal(Box::new(Literal::new(None, arg_types[i].clone())));
115        ret.insert(arg_name, mock_expr);
116    }
117
118    ret
119}
120
121pub async fn handle_create_sql_function(
122    handler_args: HandlerArgs,
123    or_replace: bool,
124    temporary: bool,
125    if_not_exists: bool,
126    name: ObjectName,
127    args: Option<Vec<OperateFunctionArg>>,
128    returns: Option<CreateFunctionReturns>,
129    params: CreateFunctionBody,
130) -> Result<RwPgResponse> {
131    if or_replace {
132        bail_not_implemented!("CREATE OR REPLACE FUNCTION");
133    }
134
135    if temporary {
136        bail_not_implemented!("CREATE TEMPORARY FUNCTION");
137    }
138
139    let language = "sql".to_owned();
140
141    // Just a basic sanity check for `language`
142    if !matches!(params.language, Some(lang) if lang.real_value().to_lowercase() == "sql") {
143        return Err(ErrorCode::InvalidParameterValue(
144            "`language` for sql udf must be `sql`".to_owned(),
145        )
146        .into());
147    }
148
149    // SQL udf function supports both single quote (i.e., as 'select $1 + $2')
150    // and double dollar (i.e., as $$select $1 + $2$$) for as clause
151    let body = match &params.as_ {
152        Some(FunctionDefinition::SingleQuotedDef(s)) => s.clone(),
153        Some(FunctionDefinition::DoubleDollarDef(s)) => s.clone(),
154        Some(FunctionDefinition::Identifier(_)) => {
155            return Err(ErrorCode::InvalidParameterValue("expect quoted string".to_owned()).into());
156        }
157        None => {
158            if params.return_.is_none() {
159                return Err(ErrorCode::InvalidParameterValue(
160                    "AS or RETURN must be specified".to_owned(),
161                )
162                .into());
163            }
164            // Otherwise this is a return expression
165            // Note: this is a current work around, and we are assuming return sql udf
166            // will NOT involve complex syntax, so just reuse the logic for select definition
167            format!("select {}", &params.return_.unwrap().to_string())
168        }
169    };
170
171    // Sanity check for link, this must be none with sql udf function
172    if let Some(CreateFunctionUsing::Link(_)) = params.using {
173        return Err(ErrorCode::InvalidParameterValue(
174            "USING must NOT be specified with sql udf function".to_owned(),
175        )
176        .into());
177    };
178
179    // Get return type for the current sql udf function
180    let return_type;
181    let kind = match returns {
182        Some(CreateFunctionReturns::Value(data_type)) => {
183            return_type = bind_data_type(&data_type)?;
184            Kind::Scalar(ScalarFunction {})
185        }
186        Some(CreateFunctionReturns::Table(columns)) => {
187            if columns.len() == 1 {
188                // return type is the original type for single column
189                return_type = bind_data_type(&columns[0].data_type)?;
190            } else {
191                // return type is a struct for multiple columns
192                let fields = columns
193                    .iter()
194                    .map(|c| Ok((c.name.real_value(), bind_data_type(&c.data_type)?)))
195                    .collect::<Result<Vec<_>>>()?;
196                return_type = StructType::new(fields).into();
197            }
198            Kind::Table(TableFunction {})
199        }
200        None => {
201            return Err(ErrorCode::InvalidParameterValue(
202                "return type must be specified".to_owned(),
203            )
204            .into());
205        }
206    };
207
208    let mut arg_names = vec![];
209    let mut arg_types = vec![];
210    for arg in args.unwrap_or_default() {
211        arg_names.push(arg.name.map_or("".to_owned(), |n| n.real_value()));
212        arg_types.push(bind_data_type(&arg.data_type)?);
213    }
214
215    // resolve database and schema id
216    let session = &handler_args.session;
217    let db_name = &session.database();
218    let (schema_name, function_name) =
219        Binder::resolve_schema_qualified_name(db_name, name.clone())?;
220    let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?;
221
222    // check if function exists
223    if let Either::Right(resp) = session.check_function_name_duplicated(
224        StatementType::CREATE_FUNCTION,
225        name,
226        &arg_types,
227        if_not_exists,
228    )? {
229        return Ok(resp);
230    }
231
232    // Parse function body here
233    // Note that the parsing here is just basic syntax / semantic check, the result will NOT be stored
234    // e.g., The provided function body contains invalid syntax, return type mismatch, ..., etc.
235    let parse_result = Parser::parse_sql(body.as_str());
236    if let Err(ParserError::ParserError(err)) | Err(ParserError::TokenizerError(err)) = parse_result
237    {
238        // Here we just return the original parse error message
239        return Err(ErrorCode::InvalidInputSyntax(err).into());
240    } else {
241        debug_assert!(parse_result.is_ok());
242
243        // Conduct semantic check (e.g., see if the inner calling functions exist, etc.)
244        let ast = parse_result.unwrap();
245        let mut binder = Binder::new_for_system(session);
246
247        binder
248            .udf_context_mut()
249            .update_context(create_mock_udf_context(
250                arg_types.clone(),
251                arg_names.clone(),
252            ));
253
254        // Need to set the initial global count to 1
255        // otherwise the context will not be probed during the semantic check
256        binder.udf_context_mut().incr_global_count();
257
258        if let Ok(expr) = UdfContext::extract_udf_expression(ast) {
259            match binder.bind_expr(expr) {
260                Ok(expr) => {
261                    // Check if the return type mismatches
262                    if expr.return_type() != return_type {
263                        return Err(ErrorCode::InvalidInputSyntax(format!(
264                            "\nreturn type mismatch detected\nexpected: [{}]\nactual: [{}]\nplease adjust your function definition accordingly",
265                            return_type,
266                            expr.return_type()
267                        ))
268                        .into());
269                    }
270                }
271                Err(e) => {
272                    if let ErrorCode::BindErrorRoot { expr: _, error } = e.inner() {
273                        let invalid_msg = error.to_string();
274
275                        // First validate the message
276                        let err_msg_type = validate_err_msg(invalid_msg.as_str());
277
278                        // Get the name of the invalid item
279                        // We will just display the first one found
280                        let Some(invalid_item_name) =
281                            extract_hint_display_target(err_msg_type, invalid_msg.as_str())
282                        else {
283                            return Err(
284                                ErrorCode::InvalidInputSyntax(DEFAULT_ERR_MSG.into()).into()
285                            );
286                        };
287
288                        // Find the invalid parameter / column / function
289                        let Some(idx) = find_target(body.as_str(), invalid_item_name) else {
290                            return Err(
291                                ErrorCode::InvalidInputSyntax(DEFAULT_ERR_MSG.into()).into()
292                            );
293                        };
294
295                        // The exact error position for `^` to point to
296                        let position = format!(
297                            "{}{}",
298                            " ".repeat(idx + PROMPT.len() + 1),
299                            "^".repeat(invalid_item_name.len())
300                        );
301
302                        return Err(ErrorCode::InvalidInputSyntax(format!(
303                            "{}\n{}\n{}`{}`\n{}",
304                            DEFAULT_ERR_MSG, invalid_msg, PROMPT, body, position
305                        ))
306                        .into());
307                    }
308
309                    // Otherwise return the default error message
310                    return Err(ErrorCode::InvalidInputSyntax(DEFAULT_ERR_MSG.into()).into());
311                }
312            }
313        } else {
314            return Err(ErrorCode::InvalidInputSyntax(
315                "failed to parse the input query and extract the udf expression,
316                please recheck the syntax"
317                    .to_owned(),
318            )
319            .into());
320        }
321    }
322
323    // Create the actual function, will be stored in function catalog
324    let function = PbFunction {
325        id: FunctionId::placeholder().0,
326        schema_id,
327        database_id,
328        name: function_name,
329        kind: Some(kind),
330        arg_names,
331        arg_types: arg_types.into_iter().map(|t| t.into()).collect(),
332        return_type: Some(return_type.into()),
333        language,
334        runtime: None,
335        name_in_runtime: None, // None for SQL UDF
336        body: Some(body),
337        compressed_binary: None,
338        link: None,
339        owner: session.user_id(),
340        always_retry_on_network_error: false,
341        is_async: None,
342        is_batched: None,
343    };
344
345    let catalog_writer = session.catalog_writer()?;
346    catalog_writer.create_function(function).await?;
347
348    Ok(PgResponse::empty_result(StatementType::CREATE_FUNCTION))
349}