// risingwave_frontend/binder/expr/function/mod.rs
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use std::sync::Arc;
use anyhow::Context;
use itertools::Itertools;
use risingwave_common::bail_not_implemented;
use risingwave_common::catalog::{INFORMATION_SCHEMA_SCHEMA_NAME, PG_CATALOG_SCHEMA_NAME};
use risingwave_common::types::DataType;
use risingwave_expr::aggregate::AggType;
use risingwave_expr::window_function::WindowFuncKind;
use risingwave_sqlparser::ast::{self, Function, FunctionArg, FunctionArgExpr, Ident};
use risingwave_sqlparser::parser::ParserError;
use crate::binder::bind_context::Clause;
use crate::binder::{Binder, UdfContext};
use crate::catalog::function_catalog::FunctionCatalog;
use crate::error::{ErrorCode, Result, RwError};
use crate::expr::{
Expr, ExprImpl, ExprType, FunctionCallWithLambda, InputRef, TableFunction, TableFunctionType,
UserDefinedFunction,
};
mod aggregate;
mod builtin_scalar;
mod window;
/// Names of system functions that take no arguments.
const SYS_FUNCTION_WITHOUT_ARGS: &[&str] = &[
    "session_user",
    "user",
    "current_user",
    "current_role",
    "current_catalog",
    "current_schema",
    "current_timestamp",
];

/// Returns `true` if `ident` is an unquoted identifier whose normalized name
/// (`Ident::real_value`) is one of `SYS_FUNCTION_WITHOUT_ARGS`.
///
/// Quoted identifiers never match.
pub(super) fn is_sys_function_without_args(ident: &Ident) -> bool {
    // Do the cheap quote check first and normalize the name only once, rather
    // than calling `real_value()` for every candidate name.
    if ident.quote_style().is_some() {
        return false;
    }
    let name = ident.real_value();
    SYS_FUNCTION_WITHOUT_ARGS.contains(&name.as_str())
}
/// Upper bound on the nesting depth of SQL UDF calls being bound; reaching it
/// makes `Binder::bind_sql_udf` fail with a `BindError`.
const SQL_UDF_MAX_CALLING_DEPTH: u32 = 16;

/// Returns early from the enclosing function with an
/// `ErrorCode::InvalidInputSyntax` error carrying `$msg` when `$pred` is true.
///
/// `$msg` can be anything with `to_string()`; it is only evaluated when the
/// check fails, so passing a `format!(..)` costs nothing on the happy path.
macro_rules! reject_syntax {
    ($pred:expr, $msg:expr) => {
        if $pred {
            let message = $msg.to_string();
            return Err(ErrorCode::InvalidInputSyntax(message).into());
        }
    };
}
impl Binder {
    /// Binds a function call AST node into an expression.
    ///
    /// The function name may carry at most one schema qualifier. `pg_catalog` is
    /// accepted for any function; `information_schema` is accepted only for
    /// `_pg_expandarray`.
    ///
    /// Two names are special-cased before any argument is bound.
    /// `obj_description` / `col_description` become an empty string, and
    /// `array_transform` goes to its own lambda-aware path.
    ///
    /// After the arguments are bound, the call is resolved in this order:
    ///
    /// 1. an `AGGREGATE:`-prefixed scalar function, wrapped as an aggregate;
    /// 2. a user-defined function from the first valid schema (SQL UDFs are
    ///    bound immediately by expanding their body);
    /// 3. a window function, when an `OVER` clause is present;
    /// 4. an aggregate function (user-defined or builtin);
    /// 5. a table function (`file_scan`, `postgres_query`, `mysql_query`, table
    ///    UDFs, builtin table functions);
    /// 6. a scalar UDF, and finally a builtin scalar function.
    ///
    /// The ids of the user-defined functions referenced by the call are added to
    /// `self.included_udfs`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInputSyntax` when a call modifier (`AGGREGATE:`,
    /// `DISTINCT`, `ORDER BY`, `WITHIN GROUP`, `FILTER`, `OVER`, `VARIADIC`,
    /// `IGNORE NULLS`) is not allowed for the resolved function kind.
    /// Returns a not-implemented error for unsupported qualified names and for
    /// unrecognized window functions. Errors from binding the arguments or the
    /// resolved function are propagated.
    pub(in crate::binder) fn bind_function(
        &mut self,
        Function {
            scalar_as_agg,
            name,
            arg_list,
            within_group,
            filter,
            over,
        }: Function,
    ) -> Result<ExprImpl> {
        let func_name = match name.0.as_slice() {
            [name] => name.real_value(),
            [schema, name] => {
                let schema_name = schema.real_value();
                if schema_name == PG_CATALOG_SCHEMA_NAME {
                    name.real_value()
                } else if schema_name == INFORMATION_SCHEMA_SCHEMA_NAME {
                    let function_name = name.real_value();
                    if function_name != "_pg_expandarray" {
                        bail_not_implemented!(
                            issue = 12422,
                            "Unsupported function name under schema: {}",
                            schema_name
                        );
                    }
                    function_name
                } else {
                    bail_not_implemented!(
                        issue = 12422,
                        "Unsupported function name under schema: {}",
                        schema_name
                    );
                }
            }
            _ => bail_not_implemented!(issue = 112, "qualified function {}", name),
        };

        // Stubbed: these always evaluate to an empty string.
        if func_name == "obj_description" || func_name == "col_description" {
            return Ok(ExprImpl::literal_varchar("".to_string()));
        }

        // `array_transform` takes a lambda whose parameter must be typed from the
        // array's element type, so it has a dedicated binding path and accepts
        // no call modifiers.
        if func_name == "array_transform" {
            reject_syntax!(
                scalar_as_agg,
                "`AGGREGATE:` prefix is not allowed for `array_transform`"
            );
            reject_syntax!(
                !arg_list.is_args_only(),
                "keywords like `DISTINCT`, `ORDER BY` are not allowed in `array_transform` argument list"
            );
            reject_syntax!(
                within_group.is_some(),
                "`WITHIN GROUP` is not allowed in `array_transform` call"
            );
            reject_syntax!(
                filter.is_some(),
                "`FILTER` is not allowed in `array_transform` call"
            );
            reject_syntax!(
                over.is_some(),
                "`OVER` is not allowed in `array_transform` call"
            );
            return self.bind_array_transform(arg_list.args);
        }

        // One AST argument may bind to zero or more expressions (a bare `*`
        // binds to none), hence `flatten_ok`.
        let mut args: Vec<_> = arg_list
            .args
            .iter()
            .map(|arg| self.bind_function_arg(arg.clone()))
            .flatten_ok()
            .try_collect()?;

        // Ids of user-defined functions referenced by this call. They are
        // collected locally while catalog lookups borrow `self`, then merged
        // into `self.included_udfs`.
        let mut referred_udfs = HashSet::new();

        // `AGGREGATE:` prefix: bind the scalar function over list-typed input
        // refs (input `i` has type `list(type of argument i)`), then wrap the
        // bound expression as an aggregate.
        let wrapped_agg_type = if scalar_as_agg {
            let mut array_args = args
                .iter()
                .enumerate()
                .map(|(i, expr)| {
                    InputRef::new(i, DataType::List(Box::new(expr.return_type()))).into()
                })
                .collect_vec();
            let scalar_func_expr = if let Ok(schema) = self.first_valid_schema()
                && let Some(func) = schema.get_function_by_name_inputs(&func_name, &mut array_args)
            {
                referred_udfs.insert(func.id);
                if !func.kind.is_scalar() {
                    return Err(ErrorCode::InvalidInputSyntax(
                        "expect a scalar function after `AGGREGATE:`".to_string(),
                    )
                    .into());
                }
                if func.language == "sql" {
                    self.bind_sql_udf(func.clone(), array_args)?
                } else {
                    UserDefinedFunction::new(func.clone(), array_args).into()
                }
            } else {
                self.bind_builtin_scalar_function(&func_name, array_args, arg_list.variadic)?
            };
            Some(AggType::WrapScalar(scalar_func_expr.to_expr_proto()))
        } else {
            None
        };

        // Look up a user-defined function matching the name and the bound
        // argument types, unless `AGGREGATE:` already resolved the call.
        let udf = if wrapped_agg_type.is_none()
            && let Ok(schema) = self.first_valid_schema()
            && let Some(func) = schema
                .get_function_by_name_inputs(&func_name, &mut args)
                .cloned()
        {
            referred_udfs.insert(func.id);
            if func.language == "sql" {
                let name = format!("SQL user-defined function `{}`", func.name);
                reject_syntax!(
                    scalar_as_agg,
                    format!("`AGGREGATE:` prefix is not allowed for {}", name)
                );
                reject_syntax!(
                    !arg_list.is_args_only(),
                    format!(
                        "keywords like `DISTINCT`, `ORDER BY` are not allowed in {} argument list",
                        name
                    )
                );
                reject_syntax!(
                    within_group.is_some(),
                    format!("`WITHIN GROUP` is not allowed in {} call", name)
                );
                reject_syntax!(
                    filter.is_some(),
                    format!("`FILTER` is not allowed in {} call", name)
                );
                reject_syntax!(
                    over.is_some(),
                    format!("`OVER` is not allowed in {} call", name)
                );
                // This path returns early, so record the dependency here. The
                // `included_udfs` update after this `if` is never reached for
                // SQL UDFs; previously they were only recorded via `AGGREGATE:`.
                self.included_udfs.extend(referred_udfs);
                return self.bind_sql_udf(func, args);
            }
            Some(func)
        } else {
            None
        };

        self.included_udfs.extend(referred_udfs);

        let agg_type = if wrapped_agg_type.is_some() {
            wrapped_agg_type
        } else if let Some(ref udf) = udf
            && udf.kind.is_aggregate()
        {
            assert_ne!(udf.language, "sql", "SQL UDAF is not supported yet");
            Some(AggType::UserDefined(udf.as_ref().into()))
        } else if let Ok(agg_type) = AggType::from_str(&func_name) {
            Some(agg_type)
        } else {
            None
        };

        // Window function call: an aggregate used with `OVER`, or a dedicated
        // window function kind.
        if let Some(over) = over {
            reject_syntax!(
                arg_list.distinct,
                "`DISTINCT` is not allowed in window function call"
            );
            reject_syntax!(
                arg_list.variadic,
                "`VARIADIC` is not allowed in window function call"
            );
            reject_syntax!(
                !arg_list.order_by.is_empty(),
                "`ORDER BY` is not allowed in window function call argument list"
            );
            reject_syntax!(
                within_group.is_some(),
                "`WITHIN GROUP` is not allowed in window function call"
            );

            let kind = if let Some(agg_type) = agg_type {
                WindowFuncKind::Aggregate(agg_type)
            } else if let Ok(kind) = WindowFuncKind::from_str(&func_name) {
                kind
            } else {
                bail_not_implemented!(issue = 8961, "Unrecognized window function: {}", func_name);
            };
            return self.bind_window_function(kind, args, arg_list.ignore_nulls, filter, over);
        }

        // `IGNORE NULLS` is only meaningful for window functions.
        reject_syntax!(
            arg_list.ignore_nulls,
            "`IGNORE NULLS` is not allowed in aggregate/scalar/table function call"
        );

        if let Some(agg_type) = agg_type {
            reject_syntax!(
                arg_list.variadic,
                "`VARIADIC` is not allowed in aggregate function call"
            );
            return self.bind_aggregate_function(
                agg_type,
                arg_list.distinct,
                args,
                arg_list.order_by,
                within_group,
                filter,
            );
        }

        // From here on the call is a table or scalar function.
        reject_syntax!(
            arg_list.distinct,
            "`DISTINCT` is not allowed in scalar/table function call"
        );
        reject_syntax!(
            !arg_list.order_by.is_empty(),
            "`ORDER BY` is not allowed in scalar/table function call"
        );
        reject_syntax!(
            within_group.is_some(),
            "`WITHIN GROUP` is not allowed in scalar/table function call"
        );
        reject_syntax!(
            filter.is_some(),
            "`FILTER` is not allowed in scalar/table function call"
        );

        // Table functions.
        if func_name.eq_ignore_ascii_case("file_scan") {
            reject_syntax!(
                arg_list.variadic,
                "`VARIADIC` is not allowed in table function call"
            );
            self.ensure_table_function_allowed()?;
            return Ok(TableFunction::new_file_scan(args)?.into());
        }
        if func_name.eq("postgres_query") {
            reject_syntax!(
                arg_list.variadic,
                "`VARIADIC` is not allowed in table function call"
            );
            self.ensure_table_function_allowed()?;
            return Ok(TableFunction::new_postgres_query(args)
                .context("postgres_query error")?
                .into());
        }
        if func_name.eq("mysql_query") {
            reject_syntax!(
                arg_list.variadic,
                "`VARIADIC` is not allowed in table function call"
            );
            self.ensure_table_function_allowed()?;
            return Ok(TableFunction::new_mysql_query(args)
                .context("mysql_query error")?
                .into());
        }
        if let Some(ref udf) = udf
            && udf.kind.is_table()
        {
            reject_syntax!(
                arg_list.variadic,
                "`VARIADIC` is not allowed in table function call"
            );
            self.ensure_table_function_allowed()?;
            return Ok(TableFunction::new_user_defined(udf.clone(), args).into());
        }
        if let Ok(function_type) = TableFunctionType::from_str(&func_name) {
            reject_syntax!(
                arg_list.variadic,
                "`VARIADIC` is not allowed in table function call"
            );
            self.ensure_table_function_allowed()?;
            return Ok(TableFunction::new(function_type, args)?.into());
        }

        // Scalar UDF: aggregate and table UDFs were handled above.
        if let Some(ref udf) = udf {
            assert!(udf.kind.is_scalar());
            reject_syntax!(
                arg_list.variadic,
                "`VARIADIC` is not allowed in user-defined function call"
            );
            return Ok(UserDefinedFunction::new(udf.clone(), args).into());
        }

        self.bind_builtin_scalar_function(&func_name, args, arg_list.variadic)
    }

    /// Binds `array_transform(array, lambda)`.
    ///
    /// `array` must bind to exactly one expression of list type. `lambda` must
    /// be a lambda with exactly one parameter, whose body is bound with that
    /// parameter typed as the array's element type. The result is a
    /// `FunctionCallWithLambda` of type `list(<lambda body type>)`.
    ///
    /// # Errors
    ///
    /// Returns a `BindError` when any of the shape/type requirements above is
    /// not met. Errors from binding the array or the lambda body are propagated.
    fn bind_array_transform(&mut self, args: Vec<FunctionArg>) -> Result<ExprImpl> {
        let [array, lambda] = <[FunctionArg; 2]>::try_from(args).map_err(|args| -> RwError {
            ErrorCode::BindError(format!(
                "`array_transform` expect two inputs `array` and `lambda`, but {} were given",
                args.len()
            ))
            .into()
        })?;

        let bound_array = self.bind_function_arg(array)?;
        let [bound_array] =
            <[ExprImpl; 1]>::try_from(bound_array).map_err(|bound_array| -> RwError {
                ErrorCode::BindError(format!(
                    "The `array` argument for `array_transform` should be bound to one argument, but {} were got",
                    bound_array.len()
                ))
                .into()
            })?;

        let inner_ty = match bound_array.return_type() {
            DataType::List(ty) => *ty,
            real_type => {
                return Err(ErrorCode::BindError(format!(
                    "The `array` argument for `array_transform` should be an array, but {} were got",
                    real_type
                ))
                .into())
            }
        };

        let ast::FunctionArgExpr::Expr(ast::Expr::LambdaFunction {
            args: lambda_args,
            body: lambda_body,
        }) = lambda.get_expr()
        else {
            return Err(ErrorCode::BindError(
                "The `lambda` argument for `array_transform` should be a lambda function"
                    .to_string(),
            )
            .into());
        };

        let [lambda_arg] = <[Ident; 1]>::try_from(lambda_args).map_err(|args| -> RwError {
            ErrorCode::BindError(format!(
                "The `lambda` argument for `array_transform` should be a lambda function with one argument, but {} were given",
                args.len()
            ))
            .into()
        })?;

        let bound_lambda = self.bind_unary_lambda_function(inner_ty, lambda_arg, *lambda_body)?;

        let lambda_ret_type = bound_lambda.return_type();
        let transform_ret_type = DataType::List(Box::new(lambda_ret_type));

        Ok(ExprImpl::FunctionCallWithLambda(Box::new(
            FunctionCallWithLambda::new_unchecked(
                ExprType::ArrayTransform,
                vec![bound_array],
                bound_lambda,
                transform_ret_type,
            ),
        )))
    }

    /// Binds the body of a one-parameter lambda.
    ///
    /// While the body is bound, `self.context.lambda_args` holds exactly one
    /// entry: `arg`'s normalized name mapped to `(0, input_ty)`. The previous
    /// lambda scope is restored afterwards, including when binding fails.
    fn bind_unary_lambda_function(
        &mut self,
        input_ty: DataType,
        arg: Ident,
        body: ast::Expr,
    ) -> Result<ExprImpl> {
        let lambda_args = HashMap::from([(arg.real_value(), (0usize, input_ty))]);
        let orig_lambda_args = self.context.lambda_args.replace(lambda_args);
        // Restore the outer scope before propagating any error; a bare `?` here
        // would leave this lambda's argument visible to the rest of the query.
        let bound_body = self.bind_expr_inner(body);
        self.context.lambda_args = orig_lambda_args;
        bound_body
    }

    /// Fails with `InvalidInputSyntax` if the clause currently being bound does
    /// not accept table functions.
    ///
    /// `Clause::JoinOn`, `Where`, `Having`, `Filter`, `Values`, `Insert` and
    /// `GeneratedColumn` are rejected. `Clause::GroupBy`, `Clause::From`, and
    /// "no clause set" are allowed.
    fn ensure_table_function_allowed(&self) -> Result<()> {
        if let Some(clause) = self.context.clause {
            match clause {
                Clause::JoinOn
                | Clause::Where
                | Clause::Having
                | Clause::Filter
                | Clause::Values
                | Clause::Insert
                | Clause::GeneratedColumn => {
                    return Err(ErrorCode::InvalidInputSyntax(format!(
                        "table functions are not allowed in {}",
                        clause
                    ))
                    .into());
                }
                Clause::GroupBy | Clause::From => {}
            }
        }
        Ok(())
    }

    /// Binds a call to a SQL user-defined function.
    ///
    /// The function body is parsed, an expression is extracted from it, and that
    /// expression is bound with the call's arguments in scope. Argument `i` is
    /// visible under its declared name, or as `$<i + 1>` when the parameter is
    /// unnamed or has no recorded name.
    ///
    /// The caller's UDF argument scope and the nesting counter are restored
    /// once binding finishes, whether or not it succeeds.
    ///
    /// # Errors
    ///
    /// - `InvalidInputSyntax` if the function has no body, the body fails to
    ///   parse, or no expression can be extracted from it.
    /// - `BindError` if the nesting depth has reached `SQL_UDF_MAX_CALLING_DEPTH`.
    fn bind_sql_udf(
        &mut self,
        func: Arc<FunctionCatalog>,
        args: Vec<ExprImpl>,
    ) -> Result<ExprImpl> {
        let Some(body) = func.body.as_deref() else {
            return Err(
                ErrorCode::InvalidInputSyntax("`body` must exist for sql udf".to_string()).into(),
            );
        };

        let ast = match risingwave_sqlparser::parser::Parser::parse_sql(body) {
            Ok(ast) => ast,
            Err(ParserError::ParserError(err) | ParserError::TokenizerError(err)) => {
                return Err(ErrorCode::InvalidInputSyntax(err).into());
            }
        };

        // All fallible checks happen before any binder state is modified. An
        // early return below therefore cannot leave the UDF context or the
        // depth counter in an inconsistent state.
        if self.udf_context.global_count() >= SQL_UDF_MAX_CALLING_DEPTH {
            return Err(ErrorCode::BindError(format!(
                "function {} calling stack depth limit exceeded",
                func.name
            ))
            .into());
        }

        let Ok(expr) = UdfContext::extract_udf_expression(ast) else {
            return Err(ErrorCode::InvalidInputSyntax(
                "failed to parse the input query and extract the udf expression,\nplease recheck the syntax"
                    .to_string(),
            )
            .into());
        };

        let udf_context = args
            .into_iter()
            .enumerate()
            .map(|(i, arg)| {
                // `.get` instead of indexing: a missing name falls back to the
                // positional form rather than panicking.
                let key = match func.arg_names.get(i) {
                    Some(arg_name) if !arg_name.is_empty() => arg_name.clone(),
                    _ => format!("${}", i + 1),
                };
                (key, arg)
            })
            .collect::<HashMap<_, _>>();

        // Swap in the callee's argument scope, bind the body, then restore the
        // caller's scope and depth counter regardless of the outcome.
        let stashed_udf_context = self.udf_context.get_context();
        self.udf_context.update_context(udf_context);
        self.udf_context.incr_global_count();
        let bind_result = self.bind_expr(expr);
        self.udf_context.decr_global_count();
        self.udf_context.update_context(stashed_udf_context);
        bind_result
    }

    /// Binds a single function-argument expression.
    ///
    /// A plain expression binds to exactly one expression, and a bare `*` binds
    /// to none. Qualified wildcards, and `*` carrying extra items
    /// (`Wildcard(Some(..))`), are rejected with `InvalidInputSyntax`.
    pub(in crate::binder) fn bind_function_expr_arg(
        &mut self,
        arg_expr: FunctionArgExpr,
    ) -> Result<Vec<ExprImpl>> {
        match arg_expr {
            FunctionArgExpr::Expr(expr) => Ok(vec![self.bind_expr_inner(expr)?]),
            FunctionArgExpr::Wildcard(None) => Ok(vec![]),
            // `Wildcard(Some(..))` used to hit `unreachable!()`. Reject it like
            // the other unsupported wildcards so user input can never panic.
            FunctionArgExpr::QualifiedWildcard(_, _)
            | FunctionArgExpr::ExprQualifiedWildcard(_, _)
            | FunctionArgExpr::Wildcard(Some(_)) => Err(ErrorCode::InvalidInputSyntax(
                format!("unexpected wildcard {}", arg_expr),
            )
            .into()),
        }
    }

    /// Binds one function argument, which may expand to zero or more
    /// expressions (see `bind_function_expr_arg`).
    ///
    /// # Errors
    ///
    /// Named arguments are not supported and produce a not-implemented error.
    pub(in crate::binder) fn bind_function_arg(
        &mut self,
        arg: FunctionArg,
    ) -> Result<Vec<ExprImpl>> {
        match arg {
            FunctionArg::Unnamed(expr) => self.bind_function_expr_arg(expr),
            // Previously `todo!()`, which panicked on user-supplied SQL.
            FunctionArg::Named { .. } => bail_not_implemented!("named function arguments"),
        }
    }
}