risingwave_expr/sig/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Metadata of expressions.
16
17use std::borrow::Cow;
18use std::collections::HashMap;
19use std::fmt;
20use std::sync::LazyLock;
21
22use itertools::Itertools;
23use risingwave_common::types::DataType;
24use risingwave_pb::expr::agg_call::PbKind as PbAggKind;
25use risingwave_pb::expr::expr_node::PbType as ScalarFunctionType;
26use risingwave_pb::expr::table_function::PbType as TableFunctionType;
27
28use crate::ExprError;
29use crate::aggregate::{AggCall, BoxedAggregateFunction};
30use crate::error::Result;
31use crate::expr::BoxedExpression;
32use crate::table_function::BoxedTableFunction;
33
34mod udf;
35
36pub use self::udf::*;
37
38/// The global registry of all function signatures.
39pub static FUNCTION_REGISTRY: LazyLock<FunctionRegistry> = LazyLock::new(|| {
40    let mut map = FunctionRegistry::default();
41    tracing::info!("found {} functions", FUNCTIONS.len());
42    for f in FUNCTIONS {
43        map.insert(f());
44    }
45    map
46});
47
48/// A set of function signatures.
49#[derive(Default, Clone, Debug)]
50pub struct FunctionRegistry(HashMap<FuncName, Vec<FuncSign>>);
51
52impl FunctionRegistry {
53    /// Inserts a function signature.
54    pub fn insert(&mut self, sig: FuncSign) {
55        let list = self.0.entry(sig.name.clone()).or_default();
56        if sig.is_aggregate() {
57            // merge retractable and append-only aggregate
58            if let Some(existing) = list
59                .iter_mut()
60                .find(|d| d.inputs_type == sig.inputs_type && d.ret_type == sig.ret_type)
61            {
62                let (
63                    FuncBuilder::Aggregate {
64                        retractable,
65                        append_only,
66                        retractable_state_type,
67                        append_only_state_type,
68                    },
69                    FuncBuilder::Aggregate {
70                        retractable: r1,
71                        append_only: a1,
72                        retractable_state_type: rs1,
73                        append_only_state_type: as1,
74                    },
75                ) = (&mut existing.build, sig.build)
76                else {
77                    panic!("expected aggregate function")
78                };
79                if let Some(f) = r1 {
80                    *retractable = Some(f);
81                    *retractable_state_type = rs1;
82                }
83                if let Some(f) = a1 {
84                    *append_only = Some(f);
85                    *append_only_state_type = as1;
86                }
87                return;
88            }
89        }
90        list.push(sig);
91    }
92
93    /// Remove a function signature from registry.
94    pub fn remove(&mut self, sig: FuncSign) -> Option<FuncSign> {
95        let pos = self
96            .0
97            .get_mut(&sig.name)?
98            .iter()
99            .positions(|s| s.inputs_type == sig.inputs_type && s.ret_type == sig.ret_type)
100            .rev()
101            .collect_vec();
102        let mut ret = None;
103        for p in pos {
104            ret = Some(self.0.get_mut(&sig.name)?.swap_remove(p));
105        }
106        ret
107    }
108
109    /// Returns a function signature with the same type, argument types and return type.
110    /// Deprecated functions are included.
111    pub fn get(
112        &self,
113        name: impl Into<FuncName>,
114        args: &[DataType],
115        ret: &DataType,
116    ) -> Result<&FuncSign, ExprError> {
117        let name = name.into();
118        let err = |candidates: &Vec<FuncSign>| {
119            // Note: if we return error here, it probably means there is a bug in frontend type inference,
120            // because such error should be caught in the frontend.
121            ExprError::UnsupportedFunction(format!(
122                "{}({}) -> {}{}",
123                name,
124                args.iter().format(", "),
125                ret,
126                if candidates.is_empty() {
127                    "".to_owned()
128                } else {
129                    format!(
130                        "\nHINT: Supported functions:\n{}",
131                        candidates
132                            .iter()
133                            .map(|d| format!(
134                                "  {}({}) -> {}",
135                                d.name,
136                                d.inputs_type.iter().format(", "),
137                                d.ret_type
138                            ))
139                            .format("\n")
140                    )
141                }
142            ))
143        };
144        let v = self.0.get(&name).ok_or_else(|| err(&vec![]))?;
145        v.iter()
146            .find(|d| d.match_args_ret(args, ret))
147            .ok_or_else(|| err(v))
148    }
149
150    /// Returns all function signatures with the same type and number of arguments.
151    /// Deprecated functions are excluded.
152    pub fn get_with_arg_nums(&self, name: impl Into<FuncName>, nargs: usize) -> Vec<&FuncSign> {
153        match self.0.get(&name.into()) {
154            Some(v) => v
155                .iter()
156                .filter(|d| d.match_number_of_args(nargs) && !d.deprecated)
157                .collect(),
158            None => vec![],
159        }
160    }
161
162    /// Returns the return type for the given function and arguments.
163    /// Deprecated functions are excluded.
164    pub fn get_return_type(
165        &self,
166        name: impl Into<FuncName>,
167        args: &[DataType],
168    ) -> Result<DataType> {
169        let name = name.into();
170        let v = self
171            .0
172            .get(&name)
173            .ok_or_else(|| ExprError::UnsupportedFunction(name.to_string()))?;
174        let sig = v
175            .iter()
176            .find(|d| d.match_args(args) && !d.deprecated)
177            .ok_or_else(|| ExprError::UnsupportedFunction(name.to_string()))?;
178        (sig.type_infer)(args)
179    }
180
181    /// Returns an iterator of all function signatures.
182    pub fn iter(&self) -> impl Iterator<Item = &FuncSign> {
183        self.0.values().flatten()
184    }
185
186    /// Returns an iterator of all scalar functions.
187    pub fn iter_scalars(&self) -> impl Iterator<Item = &FuncSign> {
188        self.iter().filter(|d| d.is_scalar())
189    }
190
191    /// Returns an iterator of all aggregate functions.
192    pub fn iter_aggregates(&self) -> impl Iterator<Item = &FuncSign> {
193        self.iter().filter(|d| d.is_aggregate())
194    }
195}
196
197/// A function signature.
198#[derive(Clone)]
199pub struct FuncSign {
200    /// The name of the function.
201    pub name: FuncName,
202
203    /// The argument types.
204    pub inputs_type: Vec<SigDataType>,
205
206    /// Whether the function is variadic.
207    pub variadic: bool,
208
209    /// The return type.
210    pub ret_type: SigDataType,
211
212    /// A function to build the expression.
213    pub build: FuncBuilder,
214
215    /// A function to infer the return type from argument types.
216    pub type_infer: fn(args: &[DataType]) -> Result<DataType>,
217
218    /// Whether the function is deprecated and should not be used in the frontend.
219    /// For backward compatibility, it is still available in the backend.
220    pub deprecated: bool,
221}
222
223impl fmt::Debug for FuncSign {
224    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
225        write!(
226            f,
227            "{}({}{}) -> {}{}",
228            self.name.as_str_name().to_ascii_lowercase(),
229            self.inputs_type.iter().format(", "),
230            if self.variadic {
231                if self.inputs_type.is_empty() {
232                    "..."
233                } else {
234                    ", ..."
235                }
236            } else {
237                ""
238            },
239            if self.name.is_table() { "setof " } else { "" },
240            self.ret_type,
241        )?;
242        if self.deprecated {
243            write!(f, " [deprecated]")?;
244        }
245        Ok(())
246    }
247}
248
249impl FuncSign {
250    /// Returns true if the argument types match the function signature.
251    pub fn match_args(&self, args: &[DataType]) -> bool {
252        if !self.match_number_of_args(args.len()) {
253            return false;
254        }
255        // allow `zip` as the length of `args` may be larger than `inputs_type`
256        #[allow(clippy::disallowed_methods)]
257        self.inputs_type
258            .iter()
259            .zip(args.iter())
260            .all(|(matcher, arg)| matcher.matches(arg))
261    }
262
263    /// Returns true if the argument types match the function signature.
264    fn match_args_ret(&self, args: &[DataType], ret: &DataType) -> bool {
265        self.match_args(args) && self.ret_type.matches(ret)
266    }
267
268    /// Returns true if the number of arguments matches the function signature.
269    fn match_number_of_args(&self, n: usize) -> bool {
270        if self.variadic {
271            n >= self.inputs_type.len()
272        } else {
273            n == self.inputs_type.len()
274        }
275    }
276
277    /// Returns true if the function is a scalar function.
278    pub const fn is_scalar(&self) -> bool {
279        matches!(self.name, FuncName::Scalar(_))
280    }
281
282    /// Returns true if the function is a table function.
283    pub const fn is_table_function(&self) -> bool {
284        matches!(self.name, FuncName::Table(_))
285    }
286
287    /// Returns true if the function is a aggregate function.
288    pub const fn is_aggregate(&self) -> bool {
289        matches!(self.name, FuncName::Aggregate(_))
290    }
291
292    /// Returns true if the aggregate function is append-only.
293    pub const fn is_append_only(&self) -> bool {
294        matches!(
295            self.build,
296            FuncBuilder::Aggregate {
297                retractable: None,
298                ..
299            }
300        )
301    }
302
303    /// Returns true if the aggregate function has a retractable version.
304    pub const fn is_retractable(&self) -> bool {
305        matches!(
306            self.build,
307            FuncBuilder::Aggregate {
308                retractable: Some(_),
309                ..
310            }
311        )
312    }
313
314    /// Builds the scalar function.
315    pub fn build_scalar(
316        &self,
317        return_type: DataType,
318        children: Vec<BoxedExpression>,
319    ) -> Result<BoxedExpression> {
320        match self.build {
321            FuncBuilder::Scalar(f) => f(return_type, children),
322            _ => panic!("Expected a scalar function"),
323        }
324    }
325
326    /// Builds the table function.
327    pub fn build_table(
328        &self,
329        return_type: DataType,
330        chunk_size: usize,
331        children: Vec<BoxedExpression>,
332    ) -> Result<BoxedTableFunction> {
333        match self.build {
334            FuncBuilder::Table(f) => f(return_type, chunk_size, children),
335            _ => panic!("Expected a table function"),
336        }
337    }
338
339    /// Builds the aggregate function. If both retractable and append-only versions exist, the
340    /// retractable version will be built.
341    pub fn build_aggregate(&self, agg: &AggCall) -> Result<BoxedAggregateFunction> {
342        match self.build {
343            FuncBuilder::Aggregate {
344                retractable,
345                append_only,
346                ..
347            } => retractable.or(append_only).unwrap()(agg),
348            _ => panic!("Expected an aggregate function"),
349        }
350    }
351}
352
353#[derive(Debug, Clone, PartialEq, Eq, Hash)]
354pub enum FuncName {
355    Scalar(ScalarFunctionType),
356    Table(TableFunctionType),
357    Aggregate(PbAggKind),
358    Udf(String),
359}
360
361impl From<ScalarFunctionType> for FuncName {
362    fn from(ty: ScalarFunctionType) -> Self {
363        Self::Scalar(ty)
364    }
365}
366
367impl From<TableFunctionType> for FuncName {
368    fn from(ty: TableFunctionType) -> Self {
369        Self::Table(ty)
370    }
371}
372
373impl From<PbAggKind> for FuncName {
374    fn from(ty: PbAggKind) -> Self {
375        Self::Aggregate(ty)
376    }
377}
378
379impl fmt::Display for FuncName {
380    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381        write!(f, "{}", self.as_str_name().to_ascii_lowercase())
382    }
383}
384
385impl FuncName {
386    /// Returns the name of the function in `UPPER_CASE` style.
387    pub fn as_str_name(&self) -> Cow<'static, str> {
388        match self {
389            Self::Scalar(ty) => ty.as_str_name().into(),
390            Self::Table(ty) => ty.as_str_name().into(),
391            Self::Aggregate(ty) => ty.as_str_name().into(),
392            Self::Udf(name) => name.clone().into(),
393        }
394    }
395
396    /// Returns true if the function is a table function.
397    const fn is_table(&self) -> bool {
398        matches!(self, Self::Table(_))
399    }
400
401    pub fn as_scalar(&self) -> ScalarFunctionType {
402        match self {
403            Self::Scalar(ty) => *ty,
404            _ => panic!("Expected a scalar function"),
405        }
406    }
407
408    pub fn as_aggregate(&self) -> PbAggKind {
409        match self {
410            Self::Aggregate(kind) => *kind,
411            _ => panic!("Expected an aggregate function"),
412        }
413    }
414}
415
416/// An extended data type that can be used to declare a function's argument or result type.
417#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
418pub enum SigDataType {
419    /// Exact data type
420    Exact(DataType),
421    /// Accepts any data type
422    Any,
423    /// Accepts any array data type
424    AnyArray,
425    /// Accepts any struct data type
426    AnyStruct,
427    /// TODO: not all type can be used as a map key.
428    AnyMap,
429}
430
431impl From<DataType> for SigDataType {
432    fn from(dt: DataType) -> Self {
433        SigDataType::Exact(dt)
434    }
435}
436
437impl std::fmt::Display for SigDataType {
438    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
439        match self {
440            Self::Exact(dt) => write!(f, "{}", dt),
441            Self::Any => write!(f, "any"),
442            Self::AnyArray => write!(f, "anyarray"),
443            Self::AnyStruct => write!(f, "anystruct"),
444            Self::AnyMap => write!(f, "anymap"),
445        }
446    }
447}
448
449impl SigDataType {
450    /// Returns true if the data type matches.
451    pub fn matches(&self, dt: &DataType) -> bool {
452        match self {
453            Self::Exact(ty) => ty == dt,
454            Self::Any => true,
455            Self::AnyArray => dt.is_array(),
456            Self::AnyStruct => dt.is_struct(),
457            Self::AnyMap => dt.is_map(),
458        }
459    }
460
461    /// Returns the exact data type.
462    pub fn as_exact(&self) -> &DataType {
463        match self {
464            Self::Exact(ty) => ty,
465            t => panic!("expected data type, but got: {t}"),
466        }
467    }
468
469    /// Returns true if the data type is exact.
470    pub fn is_exact(&self) -> bool {
471        matches!(self, Self::Exact(_))
472    }
473}
474
475#[derive(Clone)]
476pub enum FuncBuilder {
477    Scalar(fn(return_type: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression>),
478    Table(
479        fn(
480            return_type: DataType,
481            chunk_size: usize,
482            children: Vec<BoxedExpression>,
483        ) -> Result<BoxedTableFunction>,
484    ),
485    // An aggregate function may contain both or either one of retractable and append-only versions.
486    Aggregate {
487        retractable: Option<fn(agg: &AggCall) -> Result<BoxedAggregateFunction>>,
488        append_only: Option<fn(agg: &AggCall) -> Result<BoxedAggregateFunction>>,
489        /// The state type of the retractable aggregate function.
490        /// `None` means equal to the return type.
491        retractable_state_type: Option<DataType>,
492        /// The state type of the append-only aggregate function.
493        /// `None` means equal to the return type.
494        append_only_state_type: Option<DataType>,
495    },
496    Udf,
497}
498
/// A static distributed slice of functions defined by `#[function]`.
///
/// Each entry is a constructor that returns one [`FuncSign`]. [`FUNCTION_REGISTRY`] calls
/// every entry and inserts the results when it is first accessed.
#[linkme::distributed_slice]
pub static FUNCTIONS: [fn() -> FuncSign];