risingwave_frontend/binder/expr/function/builtin_scalar.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::sync::LazyLock;
17
18use bk_tree::{BKTree, metrics};
19use itertools::Itertools;
20use risingwave_common::session_config::USER_NAME_WILD_CARD;
21use risingwave_common::types::{DataType, ListValue, ScalarImpl, Timestamptz};
22use risingwave_common::{bail_not_implemented, current_cluster_version, no_function};
23use thiserror_ext::AsReport;
24
25use crate::Binder;
26use crate::binder::Clause;
27use crate::error::{ErrorCode, Result};
28use crate::expr::{CastContext, Expr, ExprImpl, ExprType, FunctionCall, Literal, Now};
29
30impl Binder {
31    pub(super) fn bind_builtin_scalar_function(
32        &mut self,
33        function_name: &str,
34        inputs: Vec<ExprImpl>,
35        variadic: bool,
36    ) -> Result<ExprImpl> {
37        type Inputs = Vec<ExprImpl>;
38
39        type Handle = Box<dyn Fn(&mut Binder, Inputs) -> Result<ExprImpl> + Sync + Send>;
40
41        fn rewrite(r#type: ExprType, rewriter: fn(Inputs) -> Result<Inputs>) -> Handle {
42            Box::new(move |_binder, mut inputs| {
43                inputs = (rewriter)(inputs)?;
44                Ok(FunctionCall::new(r#type, inputs)?.into())
45            })
46        }
47
48        fn raw_call(r#type: ExprType) -> Handle {
49            rewrite(r#type, Ok)
50        }
51
52        fn guard_by_len(expected_len: usize, handle: Handle) -> Handle {
53            Box::new(move |binder, inputs| {
54                if inputs.len() == expected_len {
55                    handle(binder, inputs)
56                } else {
57                    Err(ErrorCode::ExprError("unexpected arguments number".into()).into())
58                }
59            })
60        }
61
62        fn raw<F: Fn(&mut Binder, Inputs) -> Result<ExprImpl> + Sync + Send + 'static>(
63            f: F,
64        ) -> Handle {
65            Box::new(f)
66        }
67
68        fn dispatch_by_len(mapping: Vec<(usize, Handle)>) -> Handle {
69            Box::new(move |binder, inputs| {
70                for (len, handle) in &mapping {
71                    if inputs.len() == *len {
72                        return handle(binder, inputs);
73                    }
74                }
75                Err(ErrorCode::ExprError("unexpected arguments number".into()).into())
76            })
77        }
78
79        fn raw_literal(literal: ExprImpl) -> Handle {
80            Box::new(move |_binder, _inputs| Ok(literal.clone()))
81        }
82
83        fn now() -> Handle {
84            guard_by_len(
85                0,
86                raw(move |binder, _inputs| {
87                    binder.ensure_now_function_allowed()?;
88                    // NOTE: this will be further transformed during optimization. See the
89                    // documentation of `Now`.
90                    Ok(Now.into())
91                }),
92            )
93        }
94
95        fn pi() -> Handle {
96            raw_literal(ExprImpl::literal_f64(std::f64::consts::PI))
97        }
98
99        fn proctime() -> Handle {
100            Box::new(move |binder, inputs| {
101                binder.ensure_proctime_function_allowed()?;
102                raw_call(ExprType::Proctime)(binder, inputs)
103            })
104        }
105
106        // `SESSION_USER` is the user name of the user that is connected to the database.
107        fn session_user() -> Handle {
108            guard_by_len(
109                0,
110                raw(|binder, _inputs| {
111                    Ok(ExprImpl::literal_varchar(
112                        binder.auth_context.user_name.clone(),
113                    ))
114                }),
115            )
116        }
117
118        // `CURRENT_USER` is the user name of the user that is executing the command,
119        // `CURRENT_ROLE`, `USER` are synonyms for `CURRENT_USER`. Since we don't support
120        // `SET ROLE xxx` for now, they will all returns session user name.
121        fn current_user() -> Handle {
122            guard_by_len(
123                0,
124                raw(|binder, _inputs| {
125                    Ok(ExprImpl::literal_varchar(
126                        binder.auth_context.user_name.clone(),
127                    ))
128                }),
129            )
130        }
131
132        // `CURRENT_DATABASE` is the name of the database you are currently connected to.
133        // `CURRENT_CATALOG` is a synonym for `CURRENT_DATABASE`.
134        fn current_database() -> Handle {
135            guard_by_len(
136                0,
137                raw(|binder, _inputs| Ok(ExprImpl::literal_varchar(binder.db_name.clone()))),
138            )
139        }
140
141        // XXX: can we unify this with FUNC_SIG_MAP?
142        // For raw_call here, it seems unnecessary to declare it again here.
143        // For some functions, we have validation logic here. Is it still useful now?
144        static HANDLES: LazyLock<HashMap<&'static str, Handle>> = LazyLock::new(|| {
145            [
146                (
147                    "booleq",
148                    rewrite(ExprType::Equal, rewrite_two_bool_inputs),
149                ),
150                (
151                    "boolne",
152                    rewrite(ExprType::NotEqual, rewrite_two_bool_inputs),
153                ),
154                ("coalesce", rewrite(ExprType::Coalesce, |inputs| {
155                    if inputs.iter().any(ExprImpl::has_table_function) {
156                        return Err(ErrorCode::BindError("table functions are not allowed in COALESCE".into()).into());
157                    }
158                    Ok(inputs)
159                })),
160                (
161                    "nullif",
162                    rewrite(ExprType::Case, rewrite_nullif_to_case_when),
163                ),
164                (
165                    "round",
166                    dispatch_by_len(vec![
167                        (2, raw_call(ExprType::RoundDigit)),
168                        (1, raw_call(ExprType::Round)),
169                    ]),
170                ),
171                ("pow", raw_call(ExprType::Pow)),
172                // "power" is the function name used in PG.
173                ("power", raw_call(ExprType::Pow)),
174                ("ceil", raw_call(ExprType::Ceil)),
175                ("ceiling", raw_call(ExprType::Ceil)),
176                ("floor", raw_call(ExprType::Floor)),
177                ("trunc", raw_call(ExprType::Trunc)),
178                ("abs", raw_call(ExprType::Abs)),
179                ("exp", raw_call(ExprType::Exp)),
180                ("ln", raw_call(ExprType::Ln)),
181                ("log", raw_call(ExprType::Log10)),
182                ("log10", raw_call(ExprType::Log10)),
183                ("mod", raw_call(ExprType::Modulus)),
184                ("sin", raw_call(ExprType::Sin)),
185                ("cos", raw_call(ExprType::Cos)),
186                ("tan", raw_call(ExprType::Tan)),
187                ("cot", raw_call(ExprType::Cot)),
188                ("asin", raw_call(ExprType::Asin)),
189                ("acos", raw_call(ExprType::Acos)),
190                ("atan", raw_call(ExprType::Atan)),
191                ("atan2", raw_call(ExprType::Atan2)),
192                ("sind", raw_call(ExprType::Sind)),
193                ("cosd", raw_call(ExprType::Cosd)),
194                ("cotd", raw_call(ExprType::Cotd)),
195                ("tand", raw_call(ExprType::Tand)),
196                ("sinh", raw_call(ExprType::Sinh)),
197                ("cosh", raw_call(ExprType::Cosh)),
198                ("tanh", raw_call(ExprType::Tanh)),
199                ("coth", raw_call(ExprType::Coth)),
200                ("asinh", raw_call(ExprType::Asinh)),
201                ("acosh", raw_call(ExprType::Acosh)),
202                ("atanh", raw_call(ExprType::Atanh)),
203                ("asind", raw_call(ExprType::Asind)),
204                ("acosd", raw_call(ExprType::Acosd)),
205                ("atand", raw_call(ExprType::Atand)),
206                ("atan2d", raw_call(ExprType::Atan2d)),
207                ("degrees", raw_call(ExprType::Degrees)),
208                ("radians", raw_call(ExprType::Radians)),
209                ("sqrt", raw_call(ExprType::Sqrt)),
210                ("cbrt", raw_call(ExprType::Cbrt)),
211                ("sign", raw_call(ExprType::Sign)),
212                ("scale", raw_call(ExprType::Scale)),
213                ("min_scale", raw_call(ExprType::MinScale)),
214                ("trim_scale", raw_call(ExprType::TrimScale)),
215                // date and time
216                (
217                    "to_timestamp",
218                    dispatch_by_len(vec![
219                        (1, raw_call(ExprType::SecToTimestamptz)),
220                        (2, raw_call(ExprType::CharToTimestamptz)),
221                    ]),
222                ),
223                ("date_trunc", raw_call(ExprType::DateTrunc)),
224                ("date_bin", raw_call(ExprType::DateBin)),
225                ("date_part", raw_call(ExprType::DatePart)),
226                ("make_date", raw_call(ExprType::MakeDate)),
227                ("make_time", raw_call(ExprType::MakeTime)),
228                ("make_timestamp", raw_call(ExprType::MakeTimestamp)),
229                ("make_timestamptz", raw_call(ExprType::MakeTimestamptz)),
230                ("timezone", rewrite(ExprType::AtTimeZone, |mut inputs| {
231                    if inputs.len() == 2 {
232                        inputs.swap(0, 1);
233                        Ok(inputs)
234                    } else {
235                        Err(ErrorCode::ExprError("unexpected arguments number".into()).into())
236                    }
237                })),
238                ("to_date", raw_call(ExprType::CharToDate)),
239                // string
240                ("substr", raw_call(ExprType::Substr)),
241                ("length", raw_call(ExprType::Length)),
242                ("upper", raw_call(ExprType::Upper)),
243                ("lower", raw_call(ExprType::Lower)),
244                ("trim", raw_call(ExprType::Trim)),
245                ("replace", raw_call(ExprType::Replace)),
246                ("overlay", raw_call(ExprType::Overlay)),
247                ("btrim", raw_call(ExprType::Trim)),
248                ("ltrim", raw_call(ExprType::Ltrim)),
249                ("rtrim", raw_call(ExprType::Rtrim)),
250                ("md5", raw_call(ExprType::Md5)),
251                ("to_char", raw_call(ExprType::ToChar)),
252                (
253                    "concat",
254                    rewrite(ExprType::ConcatWs, rewrite_concat_to_concat_ws),
255                ),
256                ("concat_ws", raw_call(ExprType::ConcatWs)),
257                ("format", raw_call(ExprType::Format)),
258                ("translate", raw_call(ExprType::Translate)),
259                ("split_part", raw_call(ExprType::SplitPart)),
260                ("char_length", raw_call(ExprType::CharLength)),
261                ("character_length", raw_call(ExprType::CharLength)),
262                ("repeat", raw_call(ExprType::Repeat)),
263                ("ascii", raw_call(ExprType::Ascii)),
264                ("octet_length", raw_call(ExprType::OctetLength)),
265                ("bit_length", raw_call(ExprType::BitLength)),
266                ("regexp_match", raw_call(ExprType::RegexpMatch)),
267                ("regexp_replace", raw_call(ExprType::RegexpReplace)),
268                ("regexp_count", raw_call(ExprType::RegexpCount)),
269                ("regexp_split_to_array", raw_call(ExprType::RegexpSplitToArray)),
270                ("chr", raw_call(ExprType::Chr)),
271                ("starts_with", raw_call(ExprType::StartsWith)),
272                ("initcap", raw_call(ExprType::Initcap)),
273                ("lpad", raw_call(ExprType::Lpad)),
274                ("rpad", raw_call(ExprType::Rpad)),
275                ("reverse", raw_call(ExprType::Reverse)),
276                ("strpos", raw_call(ExprType::Position)),
277                ("to_ascii", raw_call(ExprType::ToAscii)),
278                ("to_hex", raw_call(ExprType::ToHex)),
279                ("quote_ident", raw_call(ExprType::QuoteIdent)),
280                ("quote_literal", guard_by_len(1, raw(|_binder, mut inputs| {
281                    if inputs[0].return_type() != DataType::Varchar {
282                        // Support `quote_literal(any)` by converting it to `quote_literal(any::text)`
283                        // Ref. https://github.com/postgres/postgres/blob/REL_16_1/src/include/catalog/pg_proc.dat#L4641
284                        FunctionCall::cast_mut(&mut inputs[0], &DataType::Varchar, CastContext::Explicit)?;
285                    }
286                    Ok(FunctionCall::new_unchecked(ExprType::QuoteLiteral, inputs, DataType::Varchar).into())
287                }))),
288                ("quote_nullable", guard_by_len(1, raw(|_binder, mut inputs| {
289                    if inputs[0].return_type() != DataType::Varchar {
290                        // Support `quote_nullable(any)` by converting it to `quote_nullable(any::text)`
291                        // Ref. https://github.com/postgres/postgres/blob/REL_16_1/src/include/catalog/pg_proc.dat#L4650
292                        FunctionCall::cast_mut(&mut inputs[0], &DataType::Varchar, CastContext::Explicit)?;
293                    }
294                    Ok(FunctionCall::new_unchecked(ExprType::QuoteNullable, inputs, DataType::Varchar).into())
295                }))),
296                ("string_to_array", raw_call(ExprType::StringToArray)),
297                ("encode", raw_call(ExprType::Encode)),
298                ("decode", raw_call(ExprType::Decode)),
299                ("convert_from", raw_call(ExprType::ConvertFrom)),
300                ("convert_to", raw_call(ExprType::ConvertTo)),
301                ("sha1", raw_call(ExprType::Sha1)),
302                ("sha224", raw_call(ExprType::Sha224)),
303                ("sha256", raw_call(ExprType::Sha256)),
304                ("sha384", raw_call(ExprType::Sha384)),
305                ("sha512", raw_call(ExprType::Sha512)),
306                ("encrypt", raw_call(ExprType::Encrypt)),
307                ("decrypt", raw_call(ExprType::Decrypt)),
308                ("hmac", raw_call(ExprType::Hmac)),
309                ("secure_compare", raw_call(ExprType::SecureCompare)),
310                ("left", raw_call(ExprType::Left)),
311                ("right", raw_call(ExprType::Right)),
312                ("inet_aton", raw_call(ExprType::InetAton)),
313                ("inet_ntoa", raw_call(ExprType::InetNtoa)),
314                ("int8send", raw_call(ExprType::PgwireSend)),
315                ("int8recv", guard_by_len(1, raw(|_binder, mut inputs| {
316                    // Similar to `cast` from string, return type is set explicitly rather than inferred.
317                    let hint = if !inputs[0].is_untyped() && inputs[0].return_type() == DataType::Varchar {
318                        " Consider `decode` or cast."
319                    } else {
320                        ""
321                    };
322                    inputs[0].cast_implicit_mut(&DataType::Bytea).map_err(|e| {
323                        ErrorCode::BindError(format!("{} in `recv`.{hint}", e.as_report()))
324                    })?;
325                    Ok(FunctionCall::new_unchecked(ExprType::PgwireRecv, inputs, DataType::Int64).into())
326                }))),
327                // array
328                ("array_cat", raw_call(ExprType::ArrayCat)),
329                ("array_append", raw_call(ExprType::ArrayAppend)),
330                ("array_join", raw_call(ExprType::ArrayToString)),
331                ("array_prepend", raw_call(ExprType::ArrayPrepend)),
332                ("array_to_string", raw_call(ExprType::ArrayToString)),
333                ("array_distinct", raw_call(ExprType::ArrayDistinct)),
334                ("array_min", raw_call(ExprType::ArrayMin)),
335                ("array_sort", raw_call(ExprType::ArraySort)),
336                ("array_length", raw_call(ExprType::ArrayLength)),
337                ("cardinality", raw_call(ExprType::Cardinality)),
338                ("array_remove", raw_call(ExprType::ArrayRemove)),
339                ("array_replace", raw_call(ExprType::ArrayReplace)),
340                ("array_max", raw_call(ExprType::ArrayMax)),
341                ("array_sum", raw_call(ExprType::ArraySum)),
342                ("array_position", raw_call(ExprType::ArrayPosition)),
343                ("array_positions", raw_call(ExprType::ArrayPositions)),
344                ("array_contains", raw_call(ExprType::ArrayContains)),
345                ("arraycontains", raw_call(ExprType::ArrayContains)),
346                ("array_contained", raw_call(ExprType::ArrayContained)),
347                ("arraycontained", raw_call(ExprType::ArrayContained)),
348                ("array_flatten", guard_by_len(1, raw(|_binder, inputs| {
349                    inputs[0].ensure_array_type().map_err(|_| ErrorCode::BindError("array_flatten expects `any[][]` input".into()))?;
350                    let return_type = inputs[0].return_type().into_list_element_type();
351                    if !return_type.is_array() {
352                        return Err(ErrorCode::BindError("array_flatten expects `any[][]` input".into()).into());
353
354                    }
355                    Ok(FunctionCall::new_unchecked(ExprType::ArrayFlatten, inputs, return_type).into())
356                }))),
357                ("trim_array", raw_call(ExprType::TrimArray)),
358                (
359                    "array_ndims",
360                    guard_by_len(1, raw(|_binder, inputs| {
361                        inputs[0].ensure_array_type()?;
362
363                        let n = inputs[0].return_type().array_ndims()
364                            .try_into().map_err(|_| ErrorCode::BindError("array_ndims integer overflow".into()))?;
365                        Ok(ExprImpl::literal_int(n))
366                    })),
367                ),
368                (
369                    "array_lower",
370                    guard_by_len(2, raw(|binder, inputs| {
371                        let (arg0, arg1) = inputs.into_iter().next_tuple().unwrap();
372                        // rewrite into `CASE WHEN 0 < arg1 AND arg1 <= array_ndims(arg0) THEN 1 END`
373                        let ndims_expr = binder.bind_builtin_scalar_function("array_ndims", vec![arg0], false)?;
374                        let arg1 = arg1.cast_implicit(&DataType::Int32)?;
375
376                        FunctionCall::new(
377                            ExprType::Case,
378                            vec![
379                                FunctionCall::new(
380                                    ExprType::And,
381                                    vec![
382                                        FunctionCall::new(ExprType::LessThan, vec![ExprImpl::literal_int(0), arg1.clone()])?.into(),
383                                        FunctionCall::new(ExprType::LessThanOrEqual, vec![arg1, ndims_expr])?.into(),
384                                    ],
385                                )?.into(),
386                                ExprImpl::literal_int(1),
387                            ],
388                        ).map(Into::into)
389                    })),
390                ),
391                ("array_upper", raw_call(ExprType::ArrayLength)), // `lower == 1` implies `upper == length`
392                ("array_dims", raw_call(ExprType::ArrayDims)),
393                // int256
394                ("hex_to_int256", raw_call(ExprType::HexToInt256)),
395                // jsonb
396                ("jsonb_object_field", raw_call(ExprType::JsonbAccess)),
397                ("jsonb_array_element", raw_call(ExprType::JsonbAccess)),
398                ("jsonb_object_field_text", raw_call(ExprType::JsonbAccessStr)),
399                ("jsonb_array_element_text", raw_call(ExprType::JsonbAccessStr)),
400                ("jsonb_extract_path", raw_call(ExprType::JsonbExtractPath)),
401                ("jsonb_extract_path_text", raw_call(ExprType::JsonbExtractPathText)),
402                ("jsonb_typeof", raw_call(ExprType::JsonbTypeof)),
403                ("jsonb_array_length", raw_call(ExprType::JsonbArrayLength)),
404                ("jsonb_concat", raw_call(ExprType::JsonbConcat)),
405                ("jsonb_object", raw_call(ExprType::JsonbObject)),
406                ("jsonb_pretty", raw_call(ExprType::JsonbPretty)),
407                ("jsonb_contains", raw_call(ExprType::JsonbContains)),
408                ("jsonb_contained", raw_call(ExprType::JsonbContained)),
409                ("jsonb_exists", raw_call(ExprType::JsonbExists)),
410                ("jsonb_exists_any", raw_call(ExprType::JsonbExistsAny)),
411                ("jsonb_exists_all", raw_call(ExprType::JsonbExistsAll)),
412                ("jsonb_delete", raw_call(ExprType::Subtract)),
413                ("jsonb_delete_path", raw_call(ExprType::JsonbDeletePath)),
414                ("jsonb_strip_nulls", raw_call(ExprType::JsonbStripNulls)),
415                ("to_jsonb", raw_call(ExprType::ToJsonb)),
416                ("jsonb_build_array", raw_call(ExprType::JsonbBuildArray)),
417                ("jsonb_build_object", raw_call(ExprType::JsonbBuildObject)),
418                ("jsonb_populate_record", raw_call(ExprType::JsonbPopulateRecord)),
419                ("jsonb_path_match", raw_call(ExprType::JsonbPathMatch)),
420                ("jsonb_path_exists", raw_call(ExprType::JsonbPathExists)),
421                ("jsonb_path_query_array", raw_call(ExprType::JsonbPathQueryArray)),
422                ("jsonb_path_query_first", raw_call(ExprType::JsonbPathQueryFirst)),
423                ("jsonb_set", raw_call(ExprType::JsonbSet)),
424                ("jsonb_populate_map", raw_call(ExprType::JsonbPopulateMap)),
425                ("jsonb_to_array", raw_call(ExprType::JsonbToArray)),
426                // map
427                ("map_from_entries", raw_call(ExprType::MapFromEntries)),
428                ("map_access", raw_call(ExprType::MapAccess)),
429                ("map_keys", raw_call(ExprType::MapKeys)),
430                ("map_values", raw_call(ExprType::MapValues)),
431                ("map_entries", raw_call(ExprType::MapEntries)),
432                ("map_from_key_values", raw_call(ExprType::MapFromKeyValues)),
433                ("map_cat", raw_call(ExprType::MapCat)),
434                ("map_contains", raw_call(ExprType::MapContains)),
435                ("map_delete", raw_call(ExprType::MapDelete)),
436                ("map_insert", raw_call(ExprType::MapInsert)),
437                ("map_length", raw_call(ExprType::MapLength)),
438                // vector
439                ("l2_distance", raw_call(ExprType::L2Distance)),
440                ("cosine_distance", raw_call(ExprType::CosineDistance)),
441                ("l1_distance", raw_call(ExprType::L1Distance)),
442                ("inner_product", raw_call(ExprType::InnerProduct)),
443                ("vector_norm", raw_call(ExprType::L2Norm)),
444                ("l2_normalize", raw_call(ExprType::L2Normalize)),
445                // Functions that return a constant value
446                ("pi", pi()),
447                // greatest and least
448                ("greatest", raw_call(ExprType::Greatest)),
449                ("least", raw_call(ExprType::Least)),
450                // System information operations.
451                (
452                    "pg_typeof",
453                    guard_by_len(1, raw(|_binder, inputs| {
454                        let input = &inputs[0];
455                        let v = match input.is_untyped() {
456                            true => "unknown".into(),
457                            false => input.return_type().to_string(),
458                        };
459                        Ok(ExprImpl::literal_varchar(v))
460                    })),
461                ),
462                ("current_catalog", current_database()),
463                ("current_database", current_database()),
464                ("current_schema", guard_by_len(0, raw(|binder, _inputs| {
465                    Ok(binder
466                        .first_valid_schema()
467                        .map(|schema| ExprImpl::literal_varchar(schema.name()))
468                        .unwrap_or_else(|_| ExprImpl::literal_null(DataType::Varchar)))
469                }))),
470                ("current_schemas", raw(|binder, mut inputs| {
471                    let no_match_err = ErrorCode::ExprError(
472                        "No function matches the given name and argument types. You might need to add explicit type casts.".into()
473                    );
474                    if inputs.len() != 1 {
475                        return Err(no_match_err.into());
476                    }
477                    let input = inputs
478                        .pop()
479                        .unwrap()
480                        .enforce_bool_clause("current_schemas")
481                        .map_err(|_| no_match_err)?;
482
483                    let ExprImpl::Literal(literal) = &input else {
484                        bail_not_implemented!("Only boolean literals are supported in `current_schemas`.");
485                    };
486
487                    let Some(bool) = literal.get_data().as_ref().map(|bool| bool.clone().into_bool()) else {
488                        return Ok(ExprImpl::literal_null(DataType::List(Box::new(DataType::Varchar))));
489                    };
490
491                    let paths = if bool {
492                        binder.search_path.path()
493                    } else {
494                        binder.search_path.real_path()
495                    };
496
497                    let mut schema_names = vec![];
498                    for path in paths {
499                        let mut schema_name = path;
500                        if schema_name == USER_NAME_WILD_CARD {
501                            schema_name = &binder.auth_context.user_name;
502                        }
503
504                        if binder
505                            .catalog
506                            .get_schema_by_name(&binder.db_name, schema_name)
507                            .is_ok()
508                        {
509                            schema_names.push(schema_name.as_str());
510                        }
511                    }
512
513                    Ok(ExprImpl::literal_list(
514                        ListValue::from_iter(schema_names),
515                        DataType::Varchar,
516                    ))
517                })),
518                ("session_user", session_user()),
519                ("current_role", current_user()),
520                ("current_user", current_user()),
521                ("user", current_user()),
522                ("pg_get_userbyid", raw_call(ExprType::PgGetUserbyid)),
523                ("pg_get_indexdef", raw_call(ExprType::PgGetIndexdef)),
524                ("pg_get_viewdef", raw_call(ExprType::PgGetViewdef)),
525                ("pg_index_column_has_property", raw_call(ExprType::PgIndexColumnHasProperty)),
526                ("pg_relation_size", raw(|_binder, mut inputs| {
527                    if inputs.is_empty() {
528                        return Err(ErrorCode::ExprError(
529                            "function pg_relation_size() does not exist".into(),
530                        )
531                            .into());
532                    }
533                    inputs[0].cast_to_regclass_mut()?;
534                    Ok(FunctionCall::new(ExprType::PgRelationSize, inputs)?.into())
535                })),
536                ("pg_get_serial_sequence", raw_literal(ExprImpl::literal_null(DataType::Varchar))),
537                ("pg_table_size", guard_by_len(1, raw(|_binder, mut inputs| {
538                    inputs[0].cast_to_regclass_mut()?;
539                    Ok(FunctionCall::new(ExprType::PgRelationSize, inputs)?.into())
540                }))),
541                ("pg_indexes_size", guard_by_len(1, raw(|_binder, mut inputs| {
542                    inputs[0].cast_to_regclass_mut()?;
543                    Ok(FunctionCall::new(ExprType::PgIndexesSize, inputs)?.into())
544                }))),
545                ("pg_get_expr", raw(|_binder, inputs| {
546                    if inputs.len() == 2 || inputs.len() == 3 {
547                        // TODO: implement pg_get_expr rather than just return empty as an workaround.
548                        Ok(ExprImpl::literal_varchar("".into()))
549                    } else {
550                        Err(ErrorCode::ExprError(
551                            "Too many/few arguments for pg_catalog.pg_get_expr()".into(),
552                        )
553                            .into())
554                    }
555                })),
556                ("pg_my_temp_schema", guard_by_len(0, raw(|_binder, _inputs| {
557                    // Returns the OID of the current session's temporary schema, or zero if it has none (because it has not created any temporary tables).
558                    Ok(ExprImpl::literal_int(
559                        // always return 0, as we haven't supported temporary tables nor temporary schema yet
560                        0,
561                    ))
562                }))),
563                ("current_setting", guard_by_len(1, raw(|binder, inputs| {
564                    let input = &inputs[0];
565                    let input = if let ExprImpl::Literal(literal) = input &&
566                        let Some(ScalarImpl::Utf8(input)) = literal.get_data()
567                    {
568                        input
569                    } else {
570                        return Err(ErrorCode::ExprError(
571                            "Only literal is supported in `setting_name`.".into(),
572                        )
573                            .into());
574                    };
575                    let session_config = binder.session_config.read();
576                    Ok(ExprImpl::literal_varchar(session_config.get(input.as_ref())?))
577                }))),
578                ("set_config", guard_by_len(3, raw(|binder, inputs| {
579                    let setting_name = if let ExprImpl::Literal(literal) = &inputs[0] && let Some(ScalarImpl::Utf8(input)) = literal.get_data() {
580                        input
581                    } else {
582                        return Err(ErrorCode::ExprError(
583                            "Only string literal is supported in `setting_name`.".into(),
584                        )
585                            .into());
586                    };
587
588                    let new_value = if let ExprImpl::Literal(literal) = &inputs[1] && let Some(ScalarImpl::Utf8(input)) = literal.get_data() {
589                        input
590                    } else {
591                        return Err(ErrorCode::ExprError(
592                            "Only string literal is supported in `setting_name`.".into(),
593                        )
594                            .into());
595                    };
596
597                    let is_local = if let ExprImpl::Literal(literal) = &inputs[2] && let Some(ScalarImpl::Bool(input)) = literal.get_data() {
598                        input
599                    } else {
600                        return Err(ErrorCode::ExprError(
601                            "Only bool literal is supported in `is_local`.".into(),
602                        )
603                            .into());
604                    };
605
606                    if *is_local {
607                        return Err(ErrorCode::ExprError(
608                            "`is_local = true` is not supported now.".into(),
609                        )
610                            .into());
611                    }
612
613                    let mut session_config = binder.session_config.write();
614
615                    // TODO: report session config changes if necessary.
616                    session_config.set(setting_name, new_value.to_string(), &mut ())?;
617
618                    Ok(ExprImpl::literal_varchar(new_value.to_string()))
619                }))),
620                ("format_type", raw_call(ExprType::FormatType)),
621                ("pg_table_is_visible", raw_call(ExprType::PgTableIsVisible)),
622                ("pg_type_is_visible", raw_literal(ExprImpl::literal_bool(true))),
623                ("pg_get_constraintdef", raw_literal(ExprImpl::literal_null(DataType::Varchar))),
624                ("pg_get_partkeydef", raw_literal(ExprImpl::literal_null(DataType::Varchar))),
625                ("pg_encoding_to_char", raw_literal(ExprImpl::literal_varchar("UTF8".into()))),
626                ("has_database_privilege", raw(|binder, mut inputs| {
627                    if inputs.len() == 2 {
628                        inputs.insert(0, ExprImpl::literal_varchar(binder.auth_context.user_name.clone()));
629                    }
630                    if inputs.len() == 3 {
631                        Ok(FunctionCall::new(ExprType::HasDatabasePrivilege, inputs)?.into())
632                    } else {
633                        Err(ErrorCode::ExprError(
634                            "Too many/few arguments for pg_catalog.has_database_privilege()".into(),
635                        )
636                            .into())
637                    }
638                })),
639                ("has_table_privilege", raw(|binder, mut inputs| {
640                    if inputs.len() == 2 {
641                        inputs.insert(0, ExprImpl::literal_varchar(binder.auth_context.user_name.clone()));
642                    }
643                    if inputs.len() == 3 {
644                        if inputs[1].return_type() == DataType::Varchar {
645                            inputs[1].cast_to_regclass_mut()?;
646                        }
647                        Ok(FunctionCall::new(ExprType::HasTablePrivilege, inputs)?.into())
648                    } else {
649                        Err(ErrorCode::ExprError(
650                            "Too many/few arguments for pg_catalog.has_table_privilege()".into(),
651                        )
652                            .into())
653                    }
654                })),
655                ("has_any_column_privilege", raw(|binder, mut inputs| {
656                    if inputs.len() == 2 {
657                        inputs.insert(0, ExprImpl::literal_varchar(binder.auth_context.user_name.clone()));
658                    }
659                    if inputs.len() == 3 {
660                        if inputs[1].return_type() == DataType::Varchar {
661                            inputs[1].cast_to_regclass_mut()?;
662                        }
663                        Ok(FunctionCall::new(ExprType::HasAnyColumnPrivilege, inputs)?.into())
664                    } else {
665                        Err(ErrorCode::ExprError(
666                            "Too many/few arguments for pg_catalog.has_any_column_privilege()".into(),
667                        )
668                            .into())
669                    }
670                })),
671                ("has_schema_privilege", raw(|binder, mut inputs| {
672                    if inputs.len() == 2 {
673                        inputs.insert(0, ExprImpl::literal_varchar(binder.auth_context.user_name.clone()));
674                    }
675                    if inputs.len() == 3 {
676                        Ok(FunctionCall::new(ExprType::HasSchemaPrivilege, inputs)?.into())
677                    } else {
678                        Err(ErrorCode::ExprError(
679                            "Too many/few arguments for pg_catalog.has_schema_privilege()".into(),
680                        )
681                            .into())
682                    }
683                })),
684                ("has_function_privilege", raw(|binder, mut inputs| {
685                    if inputs.len() == 2 {
686                        inputs.insert(0, ExprImpl::literal_varchar(binder.auth_context.user_name.clone()));
687                    }
688                    if inputs.len() == 3 {
689                        Ok(FunctionCall::new(ExprType::HasFunctionPrivilege, inputs)?.into())
690                    } else {
691                        Err(ErrorCode::ExprError(
692                            "Too many/few arguments for pg_catalog.has_function_privilege()".into(),
693                        )
694                            .into())
695                    }
696                })),
697                ("pg_stat_get_numscans", raw_literal(ExprImpl::literal_bigint(0))),
698                ("pg_backend_pid", raw(|binder, _inputs| {
                    // FIXME: the session id is not globally unique in a multi-frontend environment.
700                    Ok(ExprImpl::literal_int(binder.session_id.0))
701                })),
702                ("pg_cancel_backend", guard_by_len(1, raw(|_binder, _inputs| {
                    // TODO: implement real cancel rather than just returning false as a workaround.
704                    Ok(ExprImpl::literal_bool(false))
705                }))),
706                ("pg_terminate_backend", guard_by_len(1, raw(|_binder, _inputs| {
                    // TODO: implement real terminate rather than just returning false as a
                    // workaround.
709                    Ok(ExprImpl::literal_bool(false))
710                }))),
711                ("pg_tablespace_location", guard_by_len(1, raw_literal(ExprImpl::literal_null(DataType::Varchar)))),
712                ("pg_postmaster_start_time", guard_by_len(0, raw(|_binder, _inputs| {
713                    let server_start_time = risingwave_variables::get_server_start_time();
714                    let datum = server_start_time.map(Timestamptz::from).map(ScalarImpl::from);
715                    let literal = Literal::new(datum, DataType::Timestamptz);
716                    Ok(literal.into())
717                }))),
718                // TODO: really implement them.
719                // https://www.postgresql.org/docs/9.5/functions-info.html#FUNCTIONS-INFO-COMMENT-TABLE
720                // WARN: Hacked in [`Binder::bind_function`]!!!
721                ("col_description", raw_call(ExprType::ColDescription)),
722                ("obj_description", raw_literal(ExprImpl::literal_varchar("".to_owned()))),
723                ("shobj_description", raw_literal(ExprImpl::literal_varchar("".to_owned()))),
724                ("pg_is_in_recovery", raw_call(ExprType::PgIsInRecovery)),
725                ("rw_recovery_status", raw_call(ExprType::RwRecoveryStatus)),
726                ("rw_epoch_to_ts", raw_call(ExprType::RwEpochToTs)),
727                // internal
728                ("rw_vnode", raw_call(ExprType::VnodeUser)),
729                ("rw_license", raw_call(ExprType::License)),
730                ("rw_test_paid_tier", raw_call(ExprType::TestFeature)), // deprecated, kept for compatibility
731                ("rw_test_feature", raw_call(ExprType::TestFeature)), // for testing purposes
732                // TODO: choose which pg version we should return.
733                ("version", raw_literal(ExprImpl::literal_varchar(current_cluster_version()))),
734                // non-deterministic
735                ("now", now()),
736                ("current_timestamp", now()),
737                ("proctime", proctime()),
738                ("pg_sleep", raw_call(ExprType::PgSleep)),
739                ("pg_sleep_for", raw_call(ExprType::PgSleepFor)),
740                ("random", raw_call(ExprType::Random)),
741                // TODO: implement pg_sleep_until
742                // ("pg_sleep_until", raw_call(ExprType::PgSleepUntil)),
743
744                // cast functions
745                // only functions required by the existing PostgreSQL tool are implemented
746                ("date", guard_by_len(1, raw(|_binder, inputs| {
747                    inputs[0].clone().cast_explicit(&DataType::Date).map_err(Into::into)
748                }))),
749
750                // AI model functions
751                ("openai_embedding", guard_by_len(3, raw(|_binder, inputs| {
752                    // check if the first two arguments are constants
753                    if let ExprImpl::Literal(api_key) = &inputs[0] && let Some(ScalarImpl::Utf8(_api_key)) = api_key.get_data()
754                    && let ExprImpl::Literal(model) = &inputs[1] && let Some(ScalarImpl::Utf8(_model)) = model.get_data() {
755                        Ok(FunctionCall::new(ExprType::OpenaiEmbedding, inputs)?.into())
756                    } else {
757                        Err(ErrorCode::InvalidInputSyntax(
758                            "`api_key` and `model` must be constant strings".to_owned(),
759                        ).into())
760                    }
761                }))),
762            ]
763                .into_iter()
764                .collect()
765        });
766
767        static FUNCTIONS_BKTREE: LazyLock<BKTree<&str>> = LazyLock::new(|| {
768            let mut tree = BKTree::new(metrics::Levenshtein);
769
            // TODO: Also hint other functions, e.g., Agg or UDF.
771            for k in HANDLES.keys() {
772                tree.add(*k);
773            }
774
775            tree
776        });
777
778        if variadic {
779            let func = match function_name {
780                "format" => ExprType::FormatVariadic,
781                "concat" => ExprType::ConcatVariadic,
782                "concat_ws" => ExprType::ConcatWsVariadic,
783                "jsonb_build_array" => ExprType::JsonbBuildArrayVariadic,
784                "jsonb_build_object" => ExprType::JsonbBuildObjectVariadic,
785                "jsonb_extract_path" => ExprType::JsonbExtractPathVariadic,
786                "jsonb_extract_path_text" => ExprType::JsonbExtractPathTextVariadic,
787                _ => {
788                    return Err(ErrorCode::BindError(format!(
789                        "VARIADIC argument is not allowed in function \"{}\"",
790                        function_name
791                    ))
792                    .into());
793                }
794            };
795            return Ok(FunctionCall::new(func, inputs)?.into());
796        }
797
798        // Note: for raw_call, we only check name here. The type check is done later.
799        match HANDLES.get(function_name) {
800            Some(handle) => handle(self, inputs),
801            None => {
802                let allowed_distance = if function_name.len() > 3 { 2 } else { 1 };
803
804                let candidates = FUNCTIONS_BKTREE
805                    .find(function_name, allowed_distance)
806                    .map(|(_idx, c)| c)
807                    .join(" or ");
808
809                Err(no_function!(
810                    candidates = (!candidates.is_empty()).then_some(candidates),
811                    "{}({})",
812                    function_name,
813                    inputs.iter().map(|e| e.return_type()).join(", ")
814                )
815                .into())
816            }
817        }
818    }
819
820    fn ensure_now_function_allowed(&self) -> Result<()> {
821        if self.is_for_stream()
822            && !matches!(
823                self.context.clause,
824                Some(Clause::Where)
825                    | Some(Clause::Having)
826                    | Some(Clause::JoinOn)
827                    | Some(Clause::From)
828            )
829        {
830            return Err(ErrorCode::InvalidInputSyntax(format!(
831                "For streaming queries, `NOW()` function is only allowed in `WHERE`, `HAVING`, `ON` and `FROM`. Found in clause: {:?}. \
832                Please please refer to https://www.risingwave.dev/docs/current/sql-pattern-temporal-filters/ for more information",
833                self.context.clause
834            ))
835                .into());
836        }
837        if matches!(self.context.clause, Some(Clause::GeneratedColumn)) {
838            return Err(ErrorCode::InvalidInputSyntax(
839                "Cannot use `NOW()` function in generated columns. Do you want `PROCTIME()`?"
840                    .to_owned(),
841            )
842            .into());
843        }
844        Ok(())
845    }
846
847    fn ensure_proctime_function_allowed(&self) -> Result<()> {
848        if !self.is_for_ddl() {
849            return Err(ErrorCode::InvalidInputSyntax(
850                "Function `PROCTIME()` is only allowed in CREATE TABLE/SOURCE. Is `NOW()` what you want?".to_owned(),
851            )
852                .into());
853        }
854        Ok(())
855    }
856}
857
858fn rewrite_concat_to_concat_ws(inputs: Vec<ExprImpl>) -> Result<Vec<ExprImpl>> {
859    if inputs.is_empty() {
860        Err(ErrorCode::BindError(
861            "Function `concat` takes at least 1 arguments (0 given)".to_owned(),
862        )
863        .into())
864    } else {
865        let inputs = std::iter::once(ExprImpl::literal_varchar("".to_owned()))
866            .chain(inputs)
867            .collect();
868        Ok(inputs)
869    }
870}
871
872/// Make sure inputs only have 2 value and rewrite the arguments.
873/// Nullif(expr1,expr2) -> Case(Equal(expr1 = expr2),null,expr1).
874fn rewrite_nullif_to_case_when(inputs: Vec<ExprImpl>) -> Result<Vec<ExprImpl>> {
875    if inputs.len() != 2 {
876        Err(ErrorCode::BindError("Function `nullif` must contain 2 arguments".to_owned()).into())
877    } else {
878        let inputs = vec![
879            FunctionCall::new(ExprType::Equal, inputs.clone())?.into(),
880            Literal::new(None, inputs[0].return_type()).into(),
881            inputs[0].clone(),
882        ];
883        Ok(inputs)
884    }
885}
886
887fn rewrite_two_bool_inputs(mut inputs: Vec<ExprImpl>) -> Result<Vec<ExprImpl>> {
888    if inputs.len() != 2 {
889        return Err(
890            ErrorCode::BindError("function must contain only 2 arguments".to_owned()).into(),
891        );
892    }
893    let left = inputs.pop().unwrap();
894    let right = inputs.pop().unwrap();
895    Ok(vec![
896        left.cast_implicit(&DataType::Boolean)?,
897        right.cast_implicit(&DataType::Boolean)?,
898    ])
899}