1use std::collections::HashMap;
16use std::sync::LazyLock;
17
18use bk_tree::{BKTree, metrics};
19use itertools::Itertools;
20use risingwave_common::session_config::USER_NAME_WILD_CARD;
21use risingwave_common::types::{DataType, ListValue, ScalarImpl, Timestamptz};
22use risingwave_common::{bail_not_implemented, current_cluster_version, no_function};
23use thiserror_ext::AsReport;
24
25use crate::Binder;
26use crate::binder::Clause;
27use crate::error::{ErrorCode, Result};
28use crate::expr::{CastContext, Expr, ExprImpl, ExprType, FunctionCall, Literal, Now};
29
30impl Binder {
31 pub(super) fn bind_builtin_scalar_function(
32 &mut self,
33 function_name: &str,
34 inputs: Vec<ExprImpl>,
35 variadic: bool,
36 ) -> Result<ExprImpl> {
37 type Inputs = Vec<ExprImpl>;
38
39 type Handle = Box<dyn Fn(&mut Binder, Inputs) -> Result<ExprImpl> + Sync + Send>;
40
41 fn rewrite(r#type: ExprType, rewriter: fn(Inputs) -> Result<Inputs>) -> Handle {
42 Box::new(move |_binder, mut inputs| {
43 inputs = (rewriter)(inputs)?;
44 Ok(FunctionCall::new(r#type, inputs)?.into())
45 })
46 }
47
48 fn raw_call(r#type: ExprType) -> Handle {
49 rewrite(r#type, Ok)
50 }
51
52 fn guard_by_len<const E: usize>(
53 handle: impl Fn(&mut Binder, [ExprImpl; E]) -> Result<ExprImpl> + Sync + Send + 'static,
54 ) -> Handle {
55 Box::new(move |binder, inputs| {
56 let input_len = inputs.len();
57 let Ok(inputs) = inputs.try_into() else {
58 return Err(ErrorCode::ExprError(
59 format!("unexpected arguments number {}, expect {}", input_len, E).into(),
60 )
61 .into());
62 };
63 handle(binder, inputs)
64 })
65 }
66
67 fn raw<F: Fn(&mut Binder, Inputs) -> Result<ExprImpl> + Sync + Send + 'static>(
68 f: F,
69 ) -> Handle {
70 Box::new(f)
71 }
72
73 fn dispatch_by_len(mapping: Vec<(usize, Handle)>) -> Handle {
74 Box::new(move |binder, inputs| {
75 for (len, handle) in &mapping {
76 if inputs.len() == *len {
77 return handle(binder, inputs);
78 }
79 }
80 Err(ErrorCode::ExprError("unexpected arguments number".into()).into())
81 })
82 }
83
84 fn raw_literal(literal: ExprImpl) -> Handle {
85 Box::new(move |_binder, _inputs| Ok(literal.clone()))
86 }
87
88 fn now() -> Handle {
89 guard_by_len(move |binder, []| {
90 binder.ensure_now_function_allowed()?;
91 Ok(Now.into())
94 })
95 }
96
97 fn pi() -> Handle {
98 raw_literal(ExprImpl::literal_f64(std::f64::consts::PI))
99 }
100
101 fn proctime() -> Handle {
102 Box::new(move |binder, inputs| {
103 binder.ensure_proctime_function_allowed()?;
104 raw_call(ExprType::Proctime)(binder, inputs)
105 })
106 }
107
108 fn session_user() -> Handle {
110 guard_by_len(|binder, []| {
111 Ok(ExprImpl::literal_varchar(
112 binder.auth_context.user_name.clone(),
113 ))
114 })
115 }
116
117 fn current_user() -> Handle {
121 guard_by_len(|binder, []| {
122 Ok(ExprImpl::literal_varchar(
123 binder.auth_context.user_name.clone(),
124 ))
125 })
126 }
127
128 fn current_database() -> Handle {
131 guard_by_len(|binder, []| Ok(ExprImpl::literal_varchar(binder.db_name.clone())))
132 }
133
134 static HANDLES: LazyLock<HashMap<&'static str, Handle>> = LazyLock::new(|| {
138 [
139 (
140 "booleq",
141 rewrite(ExprType::Equal, rewrite_two_bool_inputs),
142 ),
143 (
144 "boolne",
145 rewrite(ExprType::NotEqual, rewrite_two_bool_inputs),
146 ),
147 ("coalesce", rewrite(ExprType::Coalesce, |inputs| {
148 if inputs.iter().any(ExprImpl::has_table_function) {
149 return Err(ErrorCode::BindError("table functions are not allowed in COALESCE".into()).into());
150 }
151 Ok(inputs)
152 })),
153 (
154 "nullif",
155 rewrite(ExprType::Case, rewrite_nullif_to_case_when),
156 ),
157 (
158 "round",
159 dispatch_by_len(vec![
160 (2, raw_call(ExprType::RoundDigit)),
161 (1, raw_call(ExprType::Round)),
162 ]),
163 ),
164 ("pow", raw_call(ExprType::Pow)),
165 ("power", raw_call(ExprType::Pow)),
167 ("ceil", raw_call(ExprType::Ceil)),
168 ("ceiling", raw_call(ExprType::Ceil)),
169 ("floor", raw_call(ExprType::Floor)),
170 ("trunc", raw_call(ExprType::Trunc)),
171 ("abs", raw_call(ExprType::Abs)),
172 ("exp", raw_call(ExprType::Exp)),
173 ("ln", raw_call(ExprType::Ln)),
174 ("log", raw_call(ExprType::Log10)),
175 ("log10", raw_call(ExprType::Log10)),
176 ("mod", raw_call(ExprType::Modulus)),
177 ("sin", raw_call(ExprType::Sin)),
178 ("cos", raw_call(ExprType::Cos)),
179 ("tan", raw_call(ExprType::Tan)),
180 ("cot", raw_call(ExprType::Cot)),
181 ("asin", raw_call(ExprType::Asin)),
182 ("acos", raw_call(ExprType::Acos)),
183 ("atan", raw_call(ExprType::Atan)),
184 ("atan2", raw_call(ExprType::Atan2)),
185 ("sind", raw_call(ExprType::Sind)),
186 ("cosd", raw_call(ExprType::Cosd)),
187 ("cotd", raw_call(ExprType::Cotd)),
188 ("tand", raw_call(ExprType::Tand)),
189 ("sinh", raw_call(ExprType::Sinh)),
190 ("cosh", raw_call(ExprType::Cosh)),
191 ("tanh", raw_call(ExprType::Tanh)),
192 ("coth", raw_call(ExprType::Coth)),
193 ("asinh", raw_call(ExprType::Asinh)),
194 ("acosh", raw_call(ExprType::Acosh)),
195 ("atanh", raw_call(ExprType::Atanh)),
196 ("asind", raw_call(ExprType::Asind)),
197 ("acosd", raw_call(ExprType::Acosd)),
198 ("atand", raw_call(ExprType::Atand)),
199 ("atan2d", raw_call(ExprType::Atan2d)),
200 ("degrees", raw_call(ExprType::Degrees)),
201 ("radians", raw_call(ExprType::Radians)),
202 ("sqrt", raw_call(ExprType::Sqrt)),
203 ("cbrt", raw_call(ExprType::Cbrt)),
204 ("sign", raw_call(ExprType::Sign)),
205 ("scale", raw_call(ExprType::Scale)),
206 ("min_scale", raw_call(ExprType::MinScale)),
207 ("trim_scale", raw_call(ExprType::TrimScale)),
208 ("gamma", raw_call(ExprType::Gamma)),
209 ("lgamma", raw_call(ExprType::Lgamma)),
210 (
212 "to_timestamp",
213 dispatch_by_len(vec![
214 (1, raw_call(ExprType::SecToTimestamptz)),
215 (2, raw_call(ExprType::CharToTimestamptz)),
216 ]),
217 ),
218 ("date_trunc", raw_call(ExprType::DateTrunc)),
219 ("date_bin", raw_call(ExprType::DateBin)),
220 ("date_part", raw_call(ExprType::DatePart)),
221 ("make_date", raw_call(ExprType::MakeDate)),
222 ("make_time", raw_call(ExprType::MakeTime)),
223 ("make_timestamp", raw_call(ExprType::MakeTimestamp)),
224 ("make_timestamptz", raw_call(ExprType::MakeTimestamptz)),
225 ("timezone", guard_by_len(|_binder, [arg0, arg1]| {
226 Ok(FunctionCall::new(ExprType::AtTimeZone, vec![arg1, arg0])?.into())
228 })),
229 ("to_date", raw_call(ExprType::CharToDate)),
230 ("substr", raw_call(ExprType::Substr)),
232 ("length", raw_call(ExprType::Length)),
233 ("upper", raw_call(ExprType::Upper)),
234 ("lower", raw_call(ExprType::Lower)),
235 ("trim", raw_call(ExprType::Trim)),
236 ("replace", raw_call(ExprType::Replace)),
237 ("overlay", raw_call(ExprType::Overlay)),
238 ("btrim", raw_call(ExprType::Trim)),
239 ("ltrim", raw_call(ExprType::Ltrim)),
240 ("rtrim", raw_call(ExprType::Rtrim)),
241 ("md5", raw_call(ExprType::Md5)),
242 ("to_char", raw_call(ExprType::ToChar)),
243 (
244 "concat",
245 rewrite(ExprType::ConcatWs, rewrite_concat_to_concat_ws),
246 ),
247 ("concat_ws", raw_call(ExprType::ConcatWs)),
248 ("format", raw_call(ExprType::Format)),
249 ("translate", raw_call(ExprType::Translate)),
250 ("split_part", raw_call(ExprType::SplitPart)),
251 ("char_length", raw_call(ExprType::CharLength)),
252 ("character_length", raw_call(ExprType::CharLength)),
253 ("repeat", raw_call(ExprType::Repeat)),
254 ("ascii", raw_call(ExprType::Ascii)),
255 ("octet_length", raw_call(ExprType::OctetLength)),
256 ("bit_length", raw_call(ExprType::BitLength)),
257 ("regexp_match", raw_call(ExprType::RegexpMatch)),
258 ("regexp_replace", raw_call(ExprType::RegexpReplace)),
259 ("regexp_count", raw_call(ExprType::RegexpCount)),
260 ("regexp_split_to_array", raw_call(ExprType::RegexpSplitToArray)),
261 ("chr", raw_call(ExprType::Chr)),
262 ("starts_with", raw_call(ExprType::StartsWith)),
263 ("initcap", raw_call(ExprType::Initcap)),
264 ("lpad", raw_call(ExprType::Lpad)),
265 ("rpad", raw_call(ExprType::Rpad)),
266 ("reverse", raw_call(ExprType::Reverse)),
267 ("strpos", raw_call(ExprType::Position)),
268 ("to_ascii", raw_call(ExprType::ToAscii)),
269 ("to_hex", raw_call(ExprType::ToHex)),
270 ("quote_ident", raw_call(ExprType::QuoteIdent)),
271 ("quote_literal", guard_by_len(|_binder, [mut input]| {
272 if input.return_type() != DataType::Varchar {
273 FunctionCall::cast_mut(&mut input, &DataType::Varchar, CastContext::Explicit)?;
276 }
277 Ok(FunctionCall::new_unchecked(ExprType::QuoteLiteral, vec![input], DataType::Varchar).into())
278 })),
279 ("quote_nullable", guard_by_len(|_binder, [mut input]| {
280 if input.return_type() != DataType::Varchar {
281 FunctionCall::cast_mut(&mut input, &DataType::Varchar, CastContext::Explicit)?;
284 }
285 Ok(FunctionCall::new_unchecked(ExprType::QuoteNullable, vec![input], DataType::Varchar).into())
286 })),
287 ("string_to_array", raw_call(ExprType::StringToArray)),
288 ("get_bit", raw_call(ExprType::GetBit)),
289 ("get_byte", raw_call(ExprType::GetByte)),
290 ("set_bit", raw_call(ExprType::SetBit)),
291 ("set_byte", raw_call(ExprType::SetByte)),
292 ("bit_count", raw_call(ExprType::BitCount)),
293 ("encode", raw_call(ExprType::Encode)),
294 ("decode", raw_call(ExprType::Decode)),
295 ("convert_from", raw_call(ExprType::ConvertFrom)),
296 ("convert_to", raw_call(ExprType::ConvertTo)),
297 ("sha1", raw_call(ExprType::Sha1)),
298 ("sha224", raw_call(ExprType::Sha224)),
299 ("sha256", raw_call(ExprType::Sha256)),
300 ("sha384", raw_call(ExprType::Sha384)),
301 ("sha512", raw_call(ExprType::Sha512)),
302 ("encrypt", raw_call(ExprType::Encrypt)),
303 ("decrypt", raw_call(ExprType::Decrypt)),
304 ("hmac", raw_call(ExprType::Hmac)),
305 ("crc32", raw_call(ExprType::Crc32)),
306 ("crc32c", raw_call(ExprType::Crc32c)),
307 ("secure_compare", raw_call(ExprType::SecureCompare)),
308 ("left", raw_call(ExprType::Left)),
309 ("right", raw_call(ExprType::Right)),
310 ("inet_aton", raw_call(ExprType::InetAton)),
311 ("inet_ntoa", raw_call(ExprType::InetNtoa)),
312 ("int8send", raw_call(ExprType::PgwireSend)),
313 ("int8recv", guard_by_len(|_binder, [mut input]| {
314 let hint = if !input.is_untyped() && input.return_type() == DataType::Varchar {
316 " Consider `decode` or cast."
317 } else {
318 ""
319 };
320 input.cast_implicit_mut(&DataType::Bytea).map_err(|e| {
321 ErrorCode::BindError(format!("{} in `recv`.{hint}", e.as_report()))
322 })?;
323 Ok(FunctionCall::new_unchecked(ExprType::PgwireRecv, vec![input], DataType::Int64).into())
324 })),
325 ("array_cat", raw_call(ExprType::ArrayCat)),
327 ("array_append", raw_call(ExprType::ArrayAppend)),
328 ("array_join", raw_call(ExprType::ArrayToString)),
329 ("array_prepend", raw_call(ExprType::ArrayPrepend)),
330 ("array_to_string", raw_call(ExprType::ArrayToString)),
331 ("array_distinct", raw_call(ExprType::ArrayDistinct)),
332 ("array_min", raw_call(ExprType::ArrayMin)),
333 ("array_sort", raw_call(ExprType::ArraySort)),
334 ("array_length", raw_call(ExprType::ArrayLength)),
335 ("cardinality", raw_call(ExprType::Cardinality)),
336 ("array_remove", raw_call(ExprType::ArrayRemove)),
337 ("array_replace", raw_call(ExprType::ArrayReplace)),
338 ("array_reverse", raw_call(ExprType::ArrayReverse)),
339 ("array_max", raw_call(ExprType::ArrayMax)),
340 ("array_sum", raw_call(ExprType::ArraySum)),
341 ("array_position", raw_call(ExprType::ArrayPosition)),
342 ("array_positions", raw_call(ExprType::ArrayPositions)),
343 ("array_contains", raw_call(ExprType::ArrayContains)),
344 ("arraycontains", raw_call(ExprType::ArrayContains)),
345 ("array_contained", raw_call(ExprType::ArrayContained)),
346 ("arraycontained", raw_call(ExprType::ArrayContained)),
347 ("array_overlaps", raw_call(ExprType::ArrayOverlaps)),
348 ("array_is_overlap", raw_call(ExprType::ArrayOverlaps)),
349 ("array_is_intersect", raw_call(ExprType::ArrayOverlaps)),
350 ("array_flatten", guard_by_len(|_binder, [input]| {
351 input.ensure_array_type().map_err(|_| ErrorCode::BindError("array_flatten expects `any[][]` input".into()))?;
352 let return_type = input.return_type().into_list_elem();
353 if !return_type.is_array() {
354 return Err(ErrorCode::BindError("array_flatten expects `any[][]` input".into()).into());
355 }
356 Ok(FunctionCall::new_unchecked(ExprType::ArrayFlatten, vec![input], return_type).into())
357 })),
358 ("trim_array", raw_call(ExprType::TrimArray)),
359 (
360 "array_ndims",
361 guard_by_len(|_binder, [input]| {
362 input.ensure_array_type()?;
363
364 let n = input.return_type().array_ndims()
365 .try_into().map_err(|_| ErrorCode::BindError("array_ndims integer overflow".into()))?;
366 Ok(ExprImpl::literal_int(n))
367 }),
368 ),
369 (
370 "array_lower",
371 guard_by_len(|binder, [arg0, arg1]| {
372 let ndims_expr = binder.bind_builtin_scalar_function("array_ndims", vec![arg0], false)?;
374 let arg1 = arg1.cast_implicit(&DataType::Int32)?;
375
376 FunctionCall::new(
377 ExprType::Case,
378 vec![
379 FunctionCall::new(
380 ExprType::And,
381 vec![
382 FunctionCall::new(ExprType::LessThan, vec![ExprImpl::literal_int(0), arg1.clone()])?.into(),
383 FunctionCall::new(ExprType::LessThanOrEqual, vec![arg1, ndims_expr])?.into(),
384 ],
385 )?.into(),
386 ExprImpl::literal_int(1),
387 ],
388 ).map(Into::into)
389 }),
390 ),
391 ("array_upper", raw_call(ExprType::ArrayLength)), ("array_dims", raw_call(ExprType::ArrayDims)),
393 ("hex_to_int256", raw_call(ExprType::HexToInt256)),
395 ("jsonb_object_field", raw_call(ExprType::JsonbAccess)),
397 ("jsonb_array_element", raw_call(ExprType::JsonbAccess)),
398 ("jsonb_object_field_text", raw_call(ExprType::JsonbAccessStr)),
399 ("jsonb_array_element_text", raw_call(ExprType::JsonbAccessStr)),
400 ("jsonb_extract_path", raw_call(ExprType::JsonbExtractPath)),
401 ("jsonb_extract_path_text", raw_call(ExprType::JsonbExtractPathText)),
402 ("jsonb_typeof", raw_call(ExprType::JsonbTypeof)),
403 ("jsonb_array_length", raw_call(ExprType::JsonbArrayLength)),
404 ("jsonb_concat", raw_call(ExprType::JsonbConcat)),
405 ("jsonb_object", raw_call(ExprType::JsonbObject)),
406 ("jsonb_pretty", raw_call(ExprType::JsonbPretty)),
407 ("jsonb_contains", raw_call(ExprType::JsonbContains)),
408 ("jsonb_contained", raw_call(ExprType::JsonbContained)),
409 ("jsonb_exists", raw_call(ExprType::JsonbExists)),
410 ("jsonb_exists_any", raw_call(ExprType::JsonbExistsAny)),
411 ("jsonb_exists_all", raw_call(ExprType::JsonbExistsAll)),
412 ("jsonb_delete", raw_call(ExprType::Subtract)),
413 ("jsonb_delete_path", raw_call(ExprType::JsonbDeletePath)),
414 ("jsonb_strip_nulls", raw_call(ExprType::JsonbStripNulls)),
415 ("to_jsonb", raw_call(ExprType::ToJsonb)),
416 ("jsonb_build_array", raw_call(ExprType::JsonbBuildArray)),
417 ("jsonb_build_object", raw_call(ExprType::JsonbBuildObject)),
418 ("jsonb_populate_record", raw_call(ExprType::JsonbPopulateRecord)),
419 ("jsonb_path_match", raw_call(ExprType::JsonbPathMatch)),
420 ("jsonb_path_exists", raw_call(ExprType::JsonbPathExists)),
421 ("jsonb_path_query_array", raw_call(ExprType::JsonbPathQueryArray)),
422 ("jsonb_path_query_first", raw_call(ExprType::JsonbPathQueryFirst)),
423 ("jsonb_set", raw_call(ExprType::JsonbSet)),
424 ("jsonb_populate_map", raw_call(ExprType::JsonbPopulateMap)),
425 ("jsonb_to_array", raw_call(ExprType::JsonbToArray)),
426 ("map_from_entries", raw_call(ExprType::MapFromEntries)),
428 ("map_access", raw_call(ExprType::MapAccess)),
429 ("map_keys", raw_call(ExprType::MapKeys)),
430 ("map_values", raw_call(ExprType::MapValues)),
431 ("map_entries", raw_call(ExprType::MapEntries)),
432 ("map_from_key_values", raw_call(ExprType::MapFromKeyValues)),
433 ("map_cat", raw_call(ExprType::MapCat)),
434 ("map_contains", raw_call(ExprType::MapContains)),
435 ("map_delete", raw_call(ExprType::MapDelete)),
436 ("map_insert", raw_call(ExprType::MapInsert)),
437 ("map_length", raw_call(ExprType::MapLength)),
438 ("l2_distance", raw_call(ExprType::L2Distance)),
440 ("cosine_distance", raw_call(ExprType::CosineDistance)),
441 ("l1_distance", raw_call(ExprType::L1Distance)),
442 ("inner_product", raw_call(ExprType::InnerProduct)),
443 ("vector_norm", raw_call(ExprType::L2Norm)),
444 ("l2_normalize", raw_call(ExprType::L2Normalize)),
445 ("subvector", guard_by_len(|_, [vector_expr, start_expr, len_expr]| {
446 let dimensions = if let DataType::Vector(length) = vector_expr.return_type() {
447 length as i32
448 } else {
449 return Err(ErrorCode::BindError("subvector expects `vector(dim)` input".into()).into());
450 };
451 let start = start_expr
452 .try_fold_const()
453 .transpose()?
454 .and_then(|datum| match datum {
455 Some(ScalarImpl::Int32(v)) => Some(v),
456 _ => None,
457 })
458 .ok_or_else(|| ErrorCode::ExprError("`start` must be an Int32 constant".into()))?;
459
460 let len = len_expr
461 .try_fold_const()
462 .transpose()?
463 .and_then(|datum| match datum {
464 Some(ScalarImpl::Int32(v)) => Some(v),
465 _ => None,
466 })
467 .ok_or_else(|| ErrorCode::ExprError("`count` must be an Int32 constant".into()))?;
468 if len < 1 || len > DataType::VEC_MAX_SIZE as i32 {
469 return Err(ErrorCode::InvalidParameterValue(format!("Invalid vector size: expected 1..={}, got {}", DataType::VEC_MAX_SIZE, len)).into());
470 }
471
472 let end = start + len - 1;
473
474 if start < 1 || end > dimensions {
475 return Err(ErrorCode::InvalidParameterValue(format!(
476 "vector slice range out of bounds: start={}, end={}, valid range is [1, {}]",
477 start,
478 end,
479 dimensions
480 )).into());
481 }
482
483 Ok(FunctionCall::new_unchecked(ExprType::Subvector, vec![vector_expr, start_expr, len_expr], DataType::Vector(len as usize)).into())
484 })),
485 ("pi", pi()),
487 ("greatest", raw_call(ExprType::Greatest)),
489 ("least", raw_call(ExprType::Least)),
490 (
492 "pg_typeof",
493 guard_by_len(|_binder, [input]| {
494 let v = match input.is_untyped() {
495 true => "unknown".into(),
496 false => input.return_type().to_string(),
497 };
498 Ok(ExprImpl::literal_varchar(v))
499 }),
500 ),
501 ("current_catalog", current_database()),
502 ("current_database", current_database()),
503 ("current_schema", guard_by_len(|binder, []| {
504 Ok(binder
505 .first_valid_schema()
506 .map(|schema| ExprImpl::literal_varchar(schema.name()))
507 .unwrap_or_else(|_| ExprImpl::literal_null(DataType::Varchar)))
508 })),
509 ("current_schemas", raw(|binder, mut inputs| {
510 let no_match_err = ErrorCode::ExprError(
511 "No function matches the given name and argument types. You might need to add explicit type casts.".into()
512 );
513 if inputs.len() != 1 {
514 return Err(no_match_err.into());
515 }
516 let input = inputs
517 .pop()
518 .unwrap()
519 .enforce_bool_clause("current_schemas")
520 .map_err(|_| no_match_err)?;
521
522 let ExprImpl::Literal(literal) = &input else {
523 bail_not_implemented!("Only boolean literals are supported in `current_schemas`.");
524 };
525
526 let Some(bool) = literal.get_data().as_ref().map(|bool| bool.clone().into_bool()) else {
527 return Ok(ExprImpl::literal_null(DataType::Varchar.list()));
528 };
529
530 let paths = if bool {
531 binder.search_path.path()
532 } else {
533 binder.search_path.real_path()
534 };
535
536 let mut schema_names = vec![];
537 for path in paths {
538 let mut schema_name = path;
539 if schema_name == USER_NAME_WILD_CARD {
540 schema_name = &binder.auth_context.user_name;
541 }
542
543 if binder
544 .catalog
545 .get_schema_by_name(&binder.db_name, schema_name)
546 .is_ok()
547 {
548 schema_names.push(schema_name.as_str());
549 }
550 }
551
552 Ok(ExprImpl::literal_list(
553 ListValue::from_iter(schema_names),
554 DataType::Varchar,
555 ))
556 })),
557 ("session_user", session_user()),
558 ("current_role", current_user()),
559 ("current_user", current_user()),
560 ("user", current_user()),
561 ("pg_get_userbyid", raw_call(ExprType::PgGetUserbyid)),
562 ("pg_get_indexdef", raw_call(ExprType::PgGetIndexdef)),
563 ("pg_get_viewdef", raw_call(ExprType::PgGetViewdef)),
564 ("pg_index_column_has_property", raw_call(ExprType::PgIndexColumnHasProperty)),
565 ("pg_relation_size", raw(|_binder, mut inputs| {
566 if inputs.is_empty() {
567 return Err(ErrorCode::ExprError(
568 "function pg_relation_size() does not exist".into(),
569 )
570 .into());
571 }
572 inputs[0].cast_to_regclass_mut()?;
573 Ok(FunctionCall::new(ExprType::PgRelationSize, inputs)?.into())
574 })),
575 ("pg_get_serial_sequence", raw_literal(ExprImpl::literal_null(DataType::Varchar))),
576 ("pg_table_size", guard_by_len(|_binder, [mut input]| {
577 input.cast_to_regclass_mut()?;
578 Ok(FunctionCall::new(ExprType::PgRelationSize, vec![input])?.into())
579 })),
580 ("pg_indexes_size", guard_by_len(|_binder, [mut input]| {
581 input.cast_to_regclass_mut()?;
582 Ok(FunctionCall::new(ExprType::PgIndexesSize, vec![input])?.into())
583 })),
584 ("pg_get_expr", raw(|_binder, inputs| {
585 if inputs.len() == 2 || inputs.len() == 3 {
586 Ok(ExprImpl::literal_varchar("".into()))
588 } else {
589 Err(ErrorCode::ExprError(
590 "Too many/few arguments for pg_catalog.pg_get_expr()".into(),
591 )
592 .into())
593 }
594 })),
595 ("pg_my_temp_schema", guard_by_len(|_binder, []| {
596 Ok(ExprImpl::literal_int(
598 0,
600 ))
601 })),
602 ("current_setting", guard_by_len(|binder, [input]| {
603 let input = if let ExprImpl::Literal(literal) = &input &&
604 let Some(ScalarImpl::Utf8(input)) = literal.get_data()
605 {
606 input
607 } else {
608 return Err(ErrorCode::ExprError(
609 "Only literal is supported in `setting_name`.".into(),
610 )
611 .into());
612 };
613 let session_config = binder.session_config.read();
614 Ok(ExprImpl::literal_varchar(session_config.get(input.as_ref())?))
615 })),
616 ("set_config", guard_by_len(|binder, [arg0, arg1, arg2]| {
617 let setting_name = if let ExprImpl::Literal(literal) = &arg0 && let Some(ScalarImpl::Utf8(input)) = literal.get_data() {
618 input
619 } else {
620 return Err(ErrorCode::ExprError(
621 "Only string literal is supported in `setting_name`.".into(),
622 )
623 .into());
624 };
625
626 let new_value = if let ExprImpl::Literal(literal) = &arg1 && let Some(ScalarImpl::Utf8(input)) = literal.get_data() {
627 input
628 } else {
629 return Err(ErrorCode::ExprError(
630 "Only string literal is supported in `setting_name`.".into(),
631 )
632 .into());
633 };
634
635 let is_local = if let ExprImpl::Literal(literal) = &arg2 && let Some(ScalarImpl::Bool(input)) = literal.get_data() {
636 input
637 } else {
638 return Err(ErrorCode::ExprError(
639 "Only bool literal is supported in `is_local`.".into(),
640 )
641 .into());
642 };
643
644 if *is_local {
645 return Err(ErrorCode::ExprError(
646 "`is_local = true` is not supported now.".into(),
647 )
648 .into());
649 }
650
651 let mut session_config = binder.session_config.write();
652
653 session_config.set(setting_name, new_value.to_string(), &mut ())?;
655
656 Ok(ExprImpl::literal_varchar(new_value.to_string()))
657 })),
658 ("format_type", raw_call(ExprType::FormatType)),
659 ("pg_table_is_visible", raw_call(ExprType::PgTableIsVisible)),
660 ("pg_type_is_visible", raw_literal(ExprImpl::literal_bool(true))),
661 ("pg_get_constraintdef", raw_literal(ExprImpl::literal_null(DataType::Varchar))),
662 ("pg_get_partkeydef", raw_literal(ExprImpl::literal_null(DataType::Varchar))),
663 ("pg_encoding_to_char", raw_literal(ExprImpl::literal_varchar("UTF8".into()))),
664 ("has_database_privilege", raw(|binder, mut inputs| {
665 if inputs.len() == 2 {
666 inputs.insert(0, ExprImpl::literal_varchar(binder.auth_context.user_name.clone()));
667 }
668 if inputs.len() == 3 {
669 Ok(FunctionCall::new(ExprType::HasDatabasePrivilege, inputs)?.into())
670 } else {
671 Err(ErrorCode::ExprError(
672 "Too many/few arguments for pg_catalog.has_database_privilege()".into(),
673 )
674 .into())
675 }
676 })),
677 ("has_table_privilege", raw(|binder, mut inputs| {
678 if inputs.len() == 2 {
679 inputs.insert(0, ExprImpl::literal_varchar(binder.auth_context.user_name.clone()));
680 }
681 if inputs.len() == 3 {
682 if inputs[1].return_type() == DataType::Varchar {
683 inputs[1].cast_to_regclass_mut()?;
684 }
685 Ok(FunctionCall::new(ExprType::HasTablePrivilege, inputs)?.into())
686 } else {
687 Err(ErrorCode::ExprError(
688 "Too many/few arguments for pg_catalog.has_table_privilege()".into(),
689 )
690 .into())
691 }
692 })),
693 ("has_any_column_privilege", raw(|binder, mut inputs| {
694 if inputs.len() == 2 {
695 inputs.insert(0, ExprImpl::literal_varchar(binder.auth_context.user_name.clone()));
696 }
697 if inputs.len() == 3 {
698 if inputs[1].return_type() == DataType::Varchar {
699 inputs[1].cast_to_regclass_mut()?;
700 }
701 Ok(FunctionCall::new(ExprType::HasAnyColumnPrivilege, inputs)?.into())
702 } else {
703 Err(ErrorCode::ExprError(
704 "Too many/few arguments for pg_catalog.has_any_column_privilege()".into(),
705 )
706 .into())
707 }
708 })),
709 ("has_schema_privilege", raw(|binder, mut inputs| {
710 if inputs.len() == 2 {
711 inputs.insert(0, ExprImpl::literal_varchar(binder.auth_context.user_name.clone()));
712 }
713 if inputs.len() == 3 {
714 Ok(FunctionCall::new(ExprType::HasSchemaPrivilege, inputs)?.into())
715 } else {
716 Err(ErrorCode::ExprError(
717 "Too many/few arguments for pg_catalog.has_schema_privilege()".into(),
718 )
719 .into())
720 }
721 })),
722 ("has_function_privilege", raw(|binder, mut inputs| {
723 if inputs.len() == 2 {
724 inputs.insert(0, ExprImpl::literal_varchar(binder.auth_context.user_name.clone()));
725 }
726 if inputs.len() == 3 {
727 Ok(FunctionCall::new(ExprType::HasFunctionPrivilege, inputs)?.into())
728 } else {
729 Err(ErrorCode::ExprError(
730 "Too many/few arguments for pg_catalog.has_function_privilege()".into(),
731 )
732 .into())
733 }
734 })),
735 ("pg_stat_get_numscans", raw_literal(ExprImpl::literal_bigint(0))),
736 ("pg_backend_pid", raw(|binder, _inputs| {
737 Ok(ExprImpl::literal_int(binder.session_id.0))
739 })),
740 ("pg_cancel_backend", guard_by_len(|_binder, [_input]| {
741 Ok(ExprImpl::literal_bool(false))
743 })),
744 ("pg_terminate_backend", guard_by_len(|_binder, [_input]| {
745 Ok(ExprImpl::literal_bool(false))
748 })),
749 ("pg_tablespace_location", guard_by_len(|_binder, [_input]| {
750 Ok(ExprImpl::literal_null(DataType::Varchar))
751 })),
752 ("pg_postmaster_start_time", guard_by_len(|_binder, []| {
753 let server_start_time = risingwave_variables::get_server_start_time();
754 let datum = server_start_time.map(Timestamptz::from).map(ScalarImpl::from);
755 let literal = Literal::new(datum, DataType::Timestamptz);
756 Ok(literal.into())
757 })),
758 ("col_description", raw_call(ExprType::ColDescription)),
762 ("obj_description", raw_literal(ExprImpl::literal_varchar("".to_owned()))),
763 ("shobj_description", raw_literal(ExprImpl::literal_varchar("".to_owned()))),
764 ("pg_is_in_recovery", raw_call(ExprType::PgIsInRecovery)),
765 ("rw_recovery_status", raw_call(ExprType::RwRecoveryStatus)),
766 ("rw_cluster_id", raw_call(ExprType::RwClusterId)),
767 ("rw_epoch_to_ts", raw_call(ExprType::RwEpochToTs)),
768 ("rw_fragment_vnodes", raw_call(ExprType::RwFragmentVnodes)),
769 ("rw_actor_vnodes", raw_call(ExprType::RwActorVnodes)),
770 ("rw_vnode", raw_call(ExprType::VnodeUser)),
772 ("rw_license", raw_call(ExprType::License)),
773 ("rw_test_paid_tier", raw_call(ExprType::TestFeature)), ("rw_test_feature", raw_call(ExprType::TestFeature)), ("version", raw_literal(ExprImpl::literal_varchar(current_cluster_version()))),
777 ("now", now()),
779 ("current_timestamp", now()),
780 ("proctime", proctime()),
781 ("pg_sleep", raw_call(ExprType::PgSleep)),
782 ("pg_sleep_for", raw_call(ExprType::PgSleepFor)),
783 ("random", raw_call(ExprType::Random)),
784 ("date", guard_by_len(|_binder, [input]| {
790 input.cast_explicit(&DataType::Date).map_err(Into::into)
791 })),
792
793 ("openai_embedding", guard_by_len(|_binder, [arg0, arg1]| {
795 if let ExprImpl::Literal(config) = &arg0 && let Some(ScalarImpl::Jsonb(_config)) = config.get_data() {
797 Ok(FunctionCall::new(ExprType::OpenaiEmbedding, vec![arg0, arg1])?.into())
798 } else {
799 Err(ErrorCode::InvalidInputSyntax(
800 "`embedding_config` must be constant jsonb".to_owned(),
801 ).into())
802 }
803 })),
804 ]
805 .into_iter()
806 .collect()
807 });
808
809 static FUNCTIONS_BKTREE: LazyLock<BKTree<&str>> = LazyLock::new(|| {
810 let mut tree = BKTree::new(metrics::Levenshtein);
811
812 for k in HANDLES.keys() {
814 tree.add(*k);
815 }
816
817 tree
818 });
819
820 if variadic {
821 let func = match function_name {
822 "format" => ExprType::FormatVariadic,
823 "concat" => ExprType::ConcatVariadic,
824 "concat_ws" => ExprType::ConcatWsVariadic,
825 "jsonb_build_array" => ExprType::JsonbBuildArrayVariadic,
826 "jsonb_build_object" => ExprType::JsonbBuildObjectVariadic,
827 "jsonb_extract_path" => ExprType::JsonbExtractPathVariadic,
828 "jsonb_extract_path_text" => ExprType::JsonbExtractPathTextVariadic,
829 _ => {
830 return Err(ErrorCode::BindError(format!(
831 "VARIADIC argument is not allowed in function \"{}\"",
832 function_name
833 ))
834 .into());
835 }
836 };
837 return Ok(FunctionCall::new(func, inputs)?.into());
838 }
839
840 match HANDLES.get(function_name) {
842 Some(handle) => handle(self, inputs),
843 None => {
844 let allowed_distance = if function_name.len() > 3 { 2 } else { 1 };
845
846 let candidates = FUNCTIONS_BKTREE
847 .find(function_name, allowed_distance)
848 .map(|(_idx, c)| c)
849 .join(" or ");
850
851 Err(no_function!(
852 candidates = (!candidates.is_empty()).then_some(candidates),
853 "{}({})",
854 function_name,
855 inputs.iter().map(|e| e.return_type()).join(", ")
856 )
857 .into())
858 }
859 }
860 }
861
862 fn ensure_now_function_allowed(&self) -> Result<()> {
863 if self.is_for_stream()
864 && !matches!(
865 self.context.clause,
866 Some(Clause::Where)
867 | Some(Clause::Having)
868 | Some(Clause::JoinOn)
869 | Some(Clause::From)
870 )
871 {
872 return Err(ErrorCode::InvalidInputSyntax(format!(
873 "For streaming queries, `NOW()` function is only allowed in `WHERE`, `HAVING`, `ON` and `FROM`. Found in clause: {:?}. \
874 Please refer to https://docs.risingwave.com/processing/sql/temporal-filters for more information",
875 self.context.clause
876 ))
877 .into());
878 }
879 if matches!(self.context.clause, Some(Clause::GeneratedColumn)) {
880 return Err(ErrorCode::InvalidInputSyntax(
881 "Cannot use `NOW()` function in generated columns. Do you want `PROCTIME()`?"
882 .to_owned(),
883 )
884 .into());
885 }
886 Ok(())
887 }
888
889 fn ensure_proctime_function_allowed(&self) -> Result<()> {
890 if !self.is_for_ddl() {
891 return Err(ErrorCode::InvalidInputSyntax(
892 "Function `PROCTIME()` is only allowed in CREATE TABLE/SOURCE. Is `NOW()` what you want?".to_owned(),
893 )
894 .into());
895 }
896 Ok(())
897 }
898}
899
900fn rewrite_concat_to_concat_ws(inputs: Vec<ExprImpl>) -> Result<Vec<ExprImpl>> {
901 if inputs.is_empty() {
902 Err(ErrorCode::BindError(
903 "Function `concat` takes at least 1 arguments (0 given)".to_owned(),
904 )
905 .into())
906 } else {
907 let inputs = std::iter::once(ExprImpl::literal_varchar("".to_owned()))
908 .chain(inputs)
909 .collect();
910 Ok(inputs)
911 }
912}
913
914fn rewrite_nullif_to_case_when(inputs: Vec<ExprImpl>) -> Result<Vec<ExprImpl>> {
917 if inputs.len() != 2 {
918 Err(ErrorCode::BindError("Function `nullif` must contain 2 arguments".to_owned()).into())
919 } else {
920 let inputs = vec![
921 FunctionCall::new(ExprType::Equal, inputs.clone())?.into(),
922 Literal::new(None, inputs[0].return_type()).into(),
923 inputs[0].clone(),
924 ];
925 Ok(inputs)
926 }
927}
928
929fn rewrite_two_bool_inputs(mut inputs: Vec<ExprImpl>) -> Result<Vec<ExprImpl>> {
930 if inputs.len() != 2 {
931 return Err(
932 ErrorCode::BindError("function must contain only 2 arguments".to_owned()).into(),
933 );
934 }
935 let left = inputs.pop().unwrap();
936 let right = inputs.pop().unwrap();
937 Ok(vec![
938 left.cast_implicit(&DataType::Boolean)?,
939 right.cast_implicit(&DataType::Boolean)?,
940 ])
941}