1use std::borrow::Cow;
18use std::collections::HashMap;
19use std::fmt;
20use std::sync::LazyLock;
21
22use itertools::Itertools;
23use risingwave_common::types::DataType;
24use risingwave_pb::expr::agg_call::PbKind as PbAggKind;
25use risingwave_pb::expr::expr_node::PbType as ScalarFunctionType;
26use risingwave_pb::expr::table_function::PbType as TableFunctionType;
27
28use crate::ExprError;
29use crate::aggregate::{AggCall, BoxedAggregateFunction};
30use crate::error::Result;
31use crate::expr::BoxedExpression;
32use crate::table_function::BoxedTableFunction;
33
34mod udf;
35
36pub use self::udf::*;
37
38pub static FUNCTION_REGISTRY: LazyLock<FunctionRegistry> = LazyLock::new(|| {
40 let mut map = FunctionRegistry::default();
41 tracing::info!("found {} functions", FUNCTIONS.len());
42 for f in FUNCTIONS {
43 map.insert(f());
44 }
45 map
46});
47
48#[derive(Default, Clone, Debug)]
50pub struct FunctionRegistry(HashMap<FuncName, Vec<FuncSign>>);
51
52impl FunctionRegistry {
53 pub fn insert(&mut self, sig: FuncSign) {
55 let list = self.0.entry(sig.name.clone()).or_default();
56 if sig.is_aggregate() {
57 if let Some(existing) = list
59 .iter_mut()
60 .find(|d| d.inputs_type == sig.inputs_type && d.ret_type == sig.ret_type)
61 {
62 let (
63 FuncBuilder::Aggregate {
64 retractable,
65 append_only,
66 retractable_state_type,
67 append_only_state_type,
68 },
69 FuncBuilder::Aggregate {
70 retractable: r1,
71 append_only: a1,
72 retractable_state_type: rs1,
73 append_only_state_type: as1,
74 },
75 ) = (&mut existing.build, sig.build)
76 else {
77 panic!("expected aggregate function")
78 };
79 if let Some(f) = r1 {
80 *retractable = Some(f);
81 *retractable_state_type = rs1;
82 }
83 if let Some(f) = a1 {
84 *append_only = Some(f);
85 *append_only_state_type = as1;
86 }
87 return;
88 }
89 }
90 list.push(sig);
91 }
92
93 pub fn remove(&mut self, sig: FuncSign) -> Option<FuncSign> {
95 let pos = self
96 .0
97 .get_mut(&sig.name)?
98 .iter()
99 .positions(|s| s.inputs_type == sig.inputs_type && s.ret_type == sig.ret_type)
100 .rev()
101 .collect_vec();
102 let mut ret = None;
103 for p in pos {
104 ret = Some(self.0.get_mut(&sig.name)?.swap_remove(p));
105 }
106 ret
107 }
108
109 pub fn get(
112 &self,
113 name: impl Into<FuncName>,
114 args: &[DataType],
115 ret: &DataType,
116 ) -> Result<&FuncSign, ExprError> {
117 let name = name.into();
118 let err = |candidates: &Vec<FuncSign>| {
119 ExprError::UnsupportedFunction(format!(
122 "{}({}) -> {}{}",
123 name,
124 args.iter().format(", "),
125 ret,
126 if candidates.is_empty() {
127 "".to_owned()
128 } else {
129 format!(
130 "\nHINT: Supported functions:\n{}",
131 candidates
132 .iter()
133 .map(|d| format!(
134 " {}({}) -> {}",
135 d.name,
136 d.inputs_type.iter().format(", "),
137 d.ret_type
138 ))
139 .format("\n")
140 )
141 }
142 ))
143 };
144 let v = self.0.get(&name).ok_or_else(|| err(&vec![]))?;
145 v.iter()
146 .find(|d| d.match_args_ret(args, ret))
147 .ok_or_else(|| err(v))
148 }
149
150 pub fn get_with_arg_nums(&self, name: impl Into<FuncName>, nargs: usize) -> Vec<&FuncSign> {
153 match self.0.get(&name.into()) {
154 Some(v) => v
155 .iter()
156 .filter(|d| d.match_number_of_args(nargs) && !d.deprecated)
157 .collect(),
158 None => vec![],
159 }
160 }
161
162 pub fn get_return_type(
165 &self,
166 name: impl Into<FuncName>,
167 args: &[DataType],
168 ) -> Result<DataType> {
169 let name = name.into();
170 let v = self
171 .0
172 .get(&name)
173 .ok_or_else(|| ExprError::UnsupportedFunction(name.to_string()))?;
174 let sig = v
175 .iter()
176 .find(|d| d.match_args(args) && !d.deprecated)
177 .ok_or_else(|| ExprError::UnsupportedFunction(name.to_string()))?;
178 (sig.type_infer)(args)
179 }
180
181 pub fn iter(&self) -> impl Iterator<Item = &FuncSign> {
183 self.0.values().flatten()
184 }
185
186 pub fn iter_scalars(&self) -> impl Iterator<Item = &FuncSign> {
188 self.iter().filter(|d| d.is_scalar())
189 }
190
191 pub fn iter_aggregates(&self) -> impl Iterator<Item = &FuncSign> {
193 self.iter().filter(|d| d.is_aggregate())
194 }
195}
196
197#[derive(Clone)]
199pub struct FuncSign {
200 pub name: FuncName,
202
203 pub inputs_type: Vec<SigDataType>,
205
206 pub variadic: bool,
208
209 pub ret_type: SigDataType,
211
212 pub build: FuncBuilder,
214
215 pub type_infer: fn(args: &[DataType]) -> Result<DataType>,
217
218 pub deprecated: bool,
221}
222
223impl fmt::Debug for FuncSign {
224 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
225 write!(
226 f,
227 "{}({}{}) -> {}{}",
228 self.name.as_str_name().to_ascii_lowercase(),
229 self.inputs_type.iter().format(", "),
230 if self.variadic {
231 if self.inputs_type.is_empty() {
232 "..."
233 } else {
234 ", ..."
235 }
236 } else {
237 ""
238 },
239 if self.name.is_table() { "setof " } else { "" },
240 self.ret_type,
241 )?;
242 if self.deprecated {
243 write!(f, " [deprecated]")?;
244 }
245 Ok(())
246 }
247}
248
249impl FuncSign {
250 pub fn match_args(&self, args: &[DataType]) -> bool {
252 if !self.match_number_of_args(args.len()) {
253 return false;
254 }
255 #[allow(clippy::disallowed_methods)]
257 self.inputs_type
258 .iter()
259 .zip(args.iter())
260 .all(|(matcher, arg)| matcher.matches(arg))
261 }
262
263 fn match_args_ret(&self, args: &[DataType], ret: &DataType) -> bool {
265 self.match_args(args) && self.ret_type.matches(ret)
266 }
267
268 fn match_number_of_args(&self, n: usize) -> bool {
270 if self.variadic {
271 n >= self.inputs_type.len()
272 } else {
273 n == self.inputs_type.len()
274 }
275 }
276
277 pub const fn is_scalar(&self) -> bool {
279 matches!(self.name, FuncName::Scalar(_))
280 }
281
282 pub const fn is_table_function(&self) -> bool {
284 matches!(self.name, FuncName::Table(_))
285 }
286
287 pub const fn is_aggregate(&self) -> bool {
289 matches!(self.name, FuncName::Aggregate(_))
290 }
291
292 pub const fn is_append_only(&self) -> bool {
294 matches!(
295 self.build,
296 FuncBuilder::Aggregate {
297 retractable: None,
298 ..
299 }
300 )
301 }
302
303 pub const fn is_retractable(&self) -> bool {
305 matches!(
306 self.build,
307 FuncBuilder::Aggregate {
308 retractable: Some(_),
309 ..
310 }
311 )
312 }
313
314 pub fn build_scalar(
316 &self,
317 return_type: DataType,
318 children: Vec<BoxedExpression>,
319 ) -> Result<BoxedExpression> {
320 match self.build {
321 FuncBuilder::Scalar(f) => f(return_type, children),
322 _ => panic!("Expected a scalar function"),
323 }
324 }
325
326 pub fn build_table(
328 &self,
329 return_type: DataType,
330 chunk_size: usize,
331 children: Vec<BoxedExpression>,
332 ) -> Result<BoxedTableFunction> {
333 match self.build {
334 FuncBuilder::Table(f) => f(return_type, chunk_size, children),
335 _ => panic!("Expected a table function"),
336 }
337 }
338
339 pub fn build_aggregate(&self, agg: &AggCall) -> Result<BoxedAggregateFunction> {
342 match self.build {
343 FuncBuilder::Aggregate {
344 retractable,
345 append_only,
346 ..
347 } => retractable.or(append_only).unwrap()(agg),
348 _ => panic!("Expected an aggregate function"),
349 }
350 }
351}
352
353#[derive(Debug, Clone, PartialEq, Eq, Hash)]
354pub enum FuncName {
355 Scalar(ScalarFunctionType),
356 Table(TableFunctionType),
357 Aggregate(PbAggKind),
358 Udf(String),
359}
360
361impl From<ScalarFunctionType> for FuncName {
362 fn from(ty: ScalarFunctionType) -> Self {
363 Self::Scalar(ty)
364 }
365}
366
367impl From<TableFunctionType> for FuncName {
368 fn from(ty: TableFunctionType) -> Self {
369 Self::Table(ty)
370 }
371}
372
373impl From<PbAggKind> for FuncName {
374 fn from(ty: PbAggKind) -> Self {
375 Self::Aggregate(ty)
376 }
377}
378
379impl fmt::Display for FuncName {
380 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381 write!(f, "{}", self.as_str_name().to_ascii_lowercase())
382 }
383}
384
385impl FuncName {
386 pub fn as_str_name(&self) -> Cow<'static, str> {
388 match self {
389 Self::Scalar(ty) => ty.as_str_name().into(),
390 Self::Table(ty) => ty.as_str_name().into(),
391 Self::Aggregate(ty) => ty.as_str_name().into(),
392 Self::Udf(name) => name.clone().into(),
393 }
394 }
395
396 const fn is_table(&self) -> bool {
398 matches!(self, Self::Table(_))
399 }
400
401 pub fn as_scalar(&self) -> ScalarFunctionType {
402 match self {
403 Self::Scalar(ty) => *ty,
404 _ => panic!("Expected a scalar function"),
405 }
406 }
407
408 pub fn as_aggregate(&self) -> PbAggKind {
409 match self {
410 Self::Aggregate(kind) => *kind,
411 _ => panic!("Expected an aggregate function"),
412 }
413 }
414}
415
416#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
418pub enum SigDataType {
419 Exact(DataType),
421 Any,
423 AnyArray,
425 AnyStruct,
427 AnyMap,
429}
430
431impl From<DataType> for SigDataType {
432 fn from(dt: DataType) -> Self {
433 SigDataType::Exact(dt)
434 }
435}
436
437impl std::fmt::Display for SigDataType {
438 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
439 match self {
440 Self::Exact(dt) => write!(f, "{}", dt),
441 Self::Any => write!(f, "any"),
442 Self::AnyArray => write!(f, "anyarray"),
443 Self::AnyStruct => write!(f, "anystruct"),
444 Self::AnyMap => write!(f, "anymap"),
445 }
446 }
447}
448
449impl SigDataType {
450 pub fn matches(&self, dt: &DataType) -> bool {
452 match self {
453 Self::Exact(ty) => ty == dt,
454 Self::Any => true,
455 Self::AnyArray => dt.is_array(),
456 Self::AnyStruct => dt.is_struct(),
457 Self::AnyMap => dt.is_map(),
458 }
459 }
460
461 pub fn as_exact(&self) -> &DataType {
463 match self {
464 Self::Exact(ty) => ty,
465 t => panic!("expected data type, but got: {t}"),
466 }
467 }
468
469 pub fn is_exact(&self) -> bool {
471 matches!(self, Self::Exact(_))
472 }
473}
474
475#[derive(Clone)]
476pub enum FuncBuilder {
477 Scalar(fn(return_type: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression>),
478 Table(
479 fn(
480 return_type: DataType,
481 chunk_size: usize,
482 children: Vec<BoxedExpression>,
483 ) -> Result<BoxedTableFunction>,
484 ),
485 Aggregate {
487 retractable: Option<fn(agg: &AggCall) -> Result<BoxedAggregateFunction>>,
488 append_only: Option<fn(agg: &AggCall) -> Result<BoxedAggregateFunction>>,
489 retractable_state_type: Option<DataType>,
492 append_only_state_type: Option<DataType>,
495 },
496 Udf,
497}
498
499#[linkme::distributed_slice]
501pub static FUNCTIONS: [fn() -> FuncSign];