risingwave_frontend/handler/
create_sql_function.rs1use std::collections::HashMap;
16
17use either::Either;
18use fancy_regex::Regex;
19use risingwave_common::catalog::FunctionId;
20use risingwave_common::types::{DataType, StructType};
21use risingwave_pb::catalog::PbFunction;
22use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction};
23use risingwave_sqlparser::parser::{Parser, ParserError};
24
25use super::*;
26use crate::binder::UdfContext;
27use crate::expr::{Expr, ExprImpl, Literal};
28use crate::{Binder, bind_data_type};
29
/// Classification of a bind error message produced while checking a SQL UDF body.
///
/// The classification decides how the offending item's name is pulled out of the message
/// so it can be highlighted in the user-facing error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ErrMsgType {
    /// The message contains the SQL UDF marker; the item is its last whitespace-separated token.
    Parameter,
    /// The message mentions `function`; the item is the name in the second token, before `(`.
    Function,
    /// No specific item can be derived from the message.
    None,
}
39
// ---- Markers used to classify bind error messages -------------------------------------

/// Substring whose presence classifies a bind error message as a parameter error.
pub const SQL_UDF_PATTERN: &str = "[sql udf]";

/// Substring whose presence classifies a bind error message as a function error.
const FUNCTION_KEYWORD: &str = "function";

// ---- Fragments of the user-facing error text ------------------------------------------

/// Generic message used whenever no more specific hint can be produced.
const DEFAULT_ERR_MSG: &str = "Failed to conduct semantic check";

/// Prefix printed in front of the echoed UDF body in detailed error messages.
const PROMPT: &str = "In SQL UDF definition: ";
50
51fn validate_err_msg(invalid_msg: &str) -> ErrMsgType {
54 if invalid_msg.contains(SQL_UDF_PATTERN) {
56 ErrMsgType::Parameter
57 } else if invalid_msg.contains(FUNCTION_KEYWORD) {
58 ErrMsgType::Function
59 } else {
60 ErrMsgType::None
62 }
63}
64
65fn extract_hint_display_target(err_msg_type: ErrMsgType, invalid_msg: &str) -> Option<&str> {
68 match err_msg_type {
69 ErrMsgType::Parameter => invalid_msg.split_whitespace().last(),
71 ErrMsgType::Function => {
73 let func = invalid_msg.split_whitespace().nth(1).unwrap_or("null");
74 func.find('(').map(|i| &func[0..i])
76 }
77 ErrMsgType::None => None,
79 }
80}
81
82fn find_target(input: &str, target: &str) -> Option<usize> {
85 let pattern = format!(r"(?<![A-Za-z]){0}(?![A-Za-z])", fancy_regex::escape(target));
89 let Ok(re) = Regex::new(&pattern) else {
90 return None;
91 };
92
93 let Ok(Some(ma)) = re.find(input) else {
94 return None;
95 };
96
97 Some(ma.start())
98}
99
100fn create_mock_udf_context(
102 arg_types: Vec<DataType>,
103 arg_names: Vec<String>,
104) -> HashMap<String, ExprImpl> {
105 let mut ret: HashMap<String, ExprImpl> = (1..=arg_types.len())
106 .map(|i| {
107 let mock_expr =
108 ExprImpl::Literal(Box::new(Literal::new(None, arg_types[i - 1].clone())));
109 (format!("${i}"), mock_expr)
110 })
111 .collect();
112
113 for (i, arg_name) in arg_names.into_iter().enumerate() {
114 let mock_expr = ExprImpl::Literal(Box::new(Literal::new(None, arg_types[i].clone())));
115 ret.insert(arg_name, mock_expr);
116 }
117
118 ret
119}
120
121pub async fn handle_create_sql_function(
122 handler_args: HandlerArgs,
123 or_replace: bool,
124 temporary: bool,
125 if_not_exists: bool,
126 name: ObjectName,
127 args: Option<Vec<OperateFunctionArg>>,
128 returns: Option<CreateFunctionReturns>,
129 params: CreateFunctionBody,
130) -> Result<RwPgResponse> {
131 if or_replace {
132 bail_not_implemented!("CREATE OR REPLACE FUNCTION");
133 }
134
135 if temporary {
136 bail_not_implemented!("CREATE TEMPORARY FUNCTION");
137 }
138
139 let language = "sql".to_owned();
140
141 if !matches!(params.language, Some(lang) if lang.real_value().to_lowercase() == "sql") {
143 return Err(ErrorCode::InvalidParameterValue(
144 "`language` for sql udf must be `sql`".to_owned(),
145 )
146 .into());
147 }
148
149 let body = match ¶ms.as_ {
152 Some(FunctionDefinition::SingleQuotedDef(s)) => s.clone(),
153 Some(FunctionDefinition::DoubleDollarDef(s)) => s.clone(),
154 Some(FunctionDefinition::Identifier(_)) => {
155 return Err(ErrorCode::InvalidParameterValue("expect quoted string".to_owned()).into());
156 }
157 None => {
158 if params.return_.is_none() {
159 return Err(ErrorCode::InvalidParameterValue(
160 "AS or RETURN must be specified".to_owned(),
161 )
162 .into());
163 }
164 format!("select {}", ¶ms.return_.unwrap().to_string())
168 }
169 };
170
171 if let Some(CreateFunctionUsing::Link(_)) = params.using {
173 return Err(ErrorCode::InvalidParameterValue(
174 "USING must NOT be specified with sql udf function".to_owned(),
175 )
176 .into());
177 };
178
179 let return_type;
181 let kind = match returns {
182 Some(CreateFunctionReturns::Value(data_type)) => {
183 return_type = bind_data_type(&data_type)?;
184 Kind::Scalar(ScalarFunction {})
185 }
186 Some(CreateFunctionReturns::Table(columns)) => {
187 if columns.len() == 1 {
188 return_type = bind_data_type(&columns[0].data_type)?;
190 } else {
191 let fields = columns
193 .iter()
194 .map(|c| Ok((c.name.real_value(), bind_data_type(&c.data_type)?)))
195 .collect::<Result<Vec<_>>>()?;
196 return_type = StructType::new(fields).into();
197 }
198 Kind::Table(TableFunction {})
199 }
200 None => {
201 return Err(ErrorCode::InvalidParameterValue(
202 "return type must be specified".to_owned(),
203 )
204 .into());
205 }
206 };
207
208 let mut arg_names = vec![];
209 let mut arg_types = vec![];
210 for arg in args.unwrap_or_default() {
211 arg_names.push(arg.name.map_or("".to_owned(), |n| n.real_value()));
212 arg_types.push(bind_data_type(&arg.data_type)?);
213 }
214
215 let session = &handler_args.session;
217 let db_name = &session.database();
218 let (schema_name, function_name) =
219 Binder::resolve_schema_qualified_name(db_name, name.clone())?;
220 let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?;
221
222 if let Either::Right(resp) = session.check_function_name_duplicated(
224 StatementType::CREATE_FUNCTION,
225 name,
226 &arg_types,
227 if_not_exists,
228 )? {
229 return Ok(resp);
230 }
231
232 let parse_result = Parser::parse_sql(body.as_str());
236 if let Err(ParserError::ParserError(err)) | Err(ParserError::TokenizerError(err)) = parse_result
237 {
238 return Err(ErrorCode::InvalidInputSyntax(err).into());
240 } else {
241 debug_assert!(parse_result.is_ok());
242
243 let ast = parse_result.unwrap();
245 let mut binder = Binder::new_for_system(session);
246
247 binder
248 .udf_context_mut()
249 .update_context(create_mock_udf_context(
250 arg_types.clone(),
251 arg_names.clone(),
252 ));
253
254 binder.udf_context_mut().incr_global_count();
257
258 if let Ok(expr) = UdfContext::extract_udf_expression(ast) {
259 match binder.bind_expr(expr) {
260 Ok(expr) => {
261 if expr.return_type() != return_type {
263 return Err(ErrorCode::InvalidInputSyntax(format!(
264 "\nreturn type mismatch detected\nexpected: [{}]\nactual: [{}]\nplease adjust your function definition accordingly",
265 return_type,
266 expr.return_type()
267 ))
268 .into());
269 }
270 }
271 Err(e) => {
272 if let ErrorCode::BindErrorRoot { expr: _, error } = e.inner() {
273 let invalid_msg = error.to_string();
274
275 let err_msg_type = validate_err_msg(invalid_msg.as_str());
277
278 let Some(invalid_item_name) =
281 extract_hint_display_target(err_msg_type, invalid_msg.as_str())
282 else {
283 return Err(
284 ErrorCode::InvalidInputSyntax(DEFAULT_ERR_MSG.into()).into()
285 );
286 };
287
288 let Some(idx) = find_target(body.as_str(), invalid_item_name) else {
290 return Err(
291 ErrorCode::InvalidInputSyntax(DEFAULT_ERR_MSG.into()).into()
292 );
293 };
294
295 let position = format!(
297 "{}{}",
298 " ".repeat(idx + PROMPT.len() + 1),
299 "^".repeat(invalid_item_name.len())
300 );
301
302 return Err(ErrorCode::InvalidInputSyntax(format!(
303 "{}\n{}\n{}`{}`\n{}",
304 DEFAULT_ERR_MSG, invalid_msg, PROMPT, body, position
305 ))
306 .into());
307 }
308
309 return Err(ErrorCode::InvalidInputSyntax(DEFAULT_ERR_MSG.into()).into());
311 }
312 }
313 } else {
314 return Err(ErrorCode::InvalidInputSyntax(
315 "failed to parse the input query and extract the udf expression,
316 please recheck the syntax"
317 .to_owned(),
318 )
319 .into());
320 }
321 }
322
323 let function = PbFunction {
325 id: FunctionId::placeholder().0,
326 schema_id,
327 database_id,
328 name: function_name,
329 kind: Some(kind),
330 arg_names,
331 arg_types: arg_types.into_iter().map(|t| t.into()).collect(),
332 return_type: Some(return_type.into()),
333 language,
334 runtime: None,
335 name_in_runtime: None, body: Some(body),
337 compressed_binary: None,
338 link: None,
339 owner: session.user_id(),
340 always_retry_on_network_error: false,
341 is_async: None,
342 is_batched: None,
343 };
344
345 let catalog_writer = session.catalog_writer()?;
346 catalog_writer.create_function(function).await?;
347
348 Ok(PgResponse::empty_result(StatementType::CREATE_FUNCTION))
349}