risingwave_expr/sig/udf.rs
1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! UDF implementation interface.
16//!
//! To support a new language or runtime for UDFs, implement the interfaces in this module
//! and register the implementation in [`UDF_IMPLS`].
//!
//! See `expr/impl/src/udf` for the existing implementations.
20
21use anyhow::{Context, Result, bail};
22use educe::Educe;
23use enum_as_inner::EnumAsInner;
24use futures::stream::BoxStream;
25use risingwave_common::array::arrow::arrow_array_udf::{ArrayRef, BooleanArray, RecordBatch};
26use risingwave_common::types::DataType;
27
/// The global registry of UDF implementations.
///
/// Entries are gathered by `linkme` from every `#[linkme::distributed_slice(UDF_IMPLS)]`
/// static in the binary. [`find_udf_impl`] scans them and picks the one whose `match_fn`
/// accepts the function being created or built.
///
/// To register a new UDF implementation:
///
/// ```ignore
/// #[linkme::distributed_slice(UDF_IMPLS)]
/// static MY_UDF_LANGUAGE: UdfImplDescriptor = UdfImplDescriptor {...};
/// ```
#[linkme::distributed_slice]
pub static UDF_IMPLS: [UdfImplDescriptor];
38
39/// Find a UDF implementation by language.
40pub fn find_udf_impl(
41 language: &str,
42 runtime: Option<&str>,
43 link: Option<&str>,
44) -> Result<&'static UdfImplDescriptor> {
45 let mut impls = UDF_IMPLS
46 .iter()
47 .filter(|desc| (desc.match_fn)(language, runtime, link));
48 let impl_ = impls.next().context(
49 "language not found.\nHINT: UDF feature flag may not be enabled during compilation",
50 )?;
51 if impls.next().is_some() {
52 bail!("multiple UDF implementations found for language: {language}");
53 }
54 Ok(impl_)
55}
56
/// UDF implementation descriptor.
///
/// Every UDF implementation should provide 3 functions:
///
/// - [`match_fn`](Self::match_fn): decides whether this implementation handles a function.
/// - [`create_fn`](Self::create_fn): verifies the options of a `CREATE FUNCTION` statement on
///   the frontend.
/// - [`build_fn`](Self::build_fn): builds the runtime used to execute the UDF on the backend.
///
/// Descriptors are registered in [`UDF_IMPLS`] and looked up with [`find_udf_impl`].
pub struct UdfImplDescriptor {
    /// Returns whether a function matches the implementation.
    ///
    /// This function is used to determine which implementation to use for a UDF. It receives
    /// the function's language plus its optional runtime and link. [`find_udf_impl`] fails
    /// unless exactly one registered descriptor returns `true`.
    pub match_fn: fn(language: &str, runtime: Option<&str>, link: Option<&str>) -> bool,

    /// Creates a function from options.
    ///
    /// This function will be called when `create function` statement is executed on the frontend.
    /// It should verify the [`CreateOptions`] and describe the result in a [`CreateFunctionOutput`].
    pub create_fn: fn(opts: CreateOptions<'_>) -> Result<CreateFunctionOutput>,

    /// Builds UDF runtime from verified options.
    ///
    /// This function will be called before the UDF is executed on the backend.
    pub build_fn: fn(opts: BuildOptions<'_>) -> Result<Box<dyn UdfImpl>>,
}
76
/// Options for creating a function.
///
/// This information is parsed from the `CREATE FUNCTION` statement.
/// Implementations should verify the options and return a [`CreateFunctionOutput`] in `create_fn`.
pub struct CreateOptions<'a> {
    /// Whether a scalar, table, or aggregate function is being created.
    pub kind: UdfKind,
    /// The function name registered in RisingWave.
    pub name: &'a str,
    /// Names of the function's arguments.
    pub arg_names: &'a [String],
    /// Types of the function's arguments.
    pub arg_types: &'a [DataType],
    /// The function's declared return type.
    pub return_type: &'a DataType,
    /// The function name on the remote side / in the source code, currently only used for external UDF.
    pub as_: Option<&'a str>,
    /// NOTE(review): presumably the address from a `USING LINK` clause, if present — confirm
    /// against the frontend.
    pub using_link: Option<&'a str>,
    /// NOTE(review): presumably the payload of a `USING BASE64` clause, already decoded from
    /// base64 — confirm against the frontend.
    pub using_base64_decoded: Option<&'a [u8]>,
}
93
/// Output of creating a function.
///
/// Returned by a descriptor's `create_fn` once the [`CreateOptions`] have been verified.
pub struct CreateFunctionOutput {
    /// The name for identifying the function in the UDF runtime.
    pub name_in_runtime: String,
    /// Optional function body produced by the implementation.
    /// NOTE(review): presumably later supplied to `build_fn` as `BuildOptions::body` — confirm.
    pub body: Option<String>,
    /// Optional compressed binary produced by the implementation.
    /// NOTE(review): presumably later supplied to `build_fn` as
    /// `BuildOptions::compressed_binary` — confirm.
    pub compressed_binary: Option<Vec<u8>>,
}
101
/// Options for building a UDF runtime.
///
/// Passed to a descriptor's `build_fn`. The derived `Debug` output omits `compressed_binary`.
#[derive(Educe)]
#[educe(Debug)]
pub struct BuildOptions<'a> {
    /// Whether a scalar, table, or aggregate function is being built.
    pub kind: UdfKind,
    /// Optional function body.
    pub body: Option<&'a str>,
    /// Optional compressed binary of the function.
    /// Excluded from `Debug` output (presumably because it can be large — confirm).
    #[educe(Debug(ignore))]
    pub compressed_binary: Option<&'a [u8]>,
    /// Optional link associated with the function.
    pub link: Option<&'a str>,
    /// The name identifying the function in the UDF runtime.
    pub name_in_runtime: &'a str,
    /// Names of the function's arguments.
    pub arg_names: &'a [String],
    /// Types of the function's arguments.
    pub arg_types: &'a [DataType],
    /// The function's return type.
    pub return_type: &'a DataType,
    /// NOTE(review): presumably tells network-backed implementations to keep retrying on
    /// network errors — confirm with the implementations.
    pub always_retry_on_network_error: bool,
    /// The language the UDF is written in.
    pub language: &'a str,
    /// NOTE(review): presumably `None` means the `async` option was not specified — confirm.
    pub is_async: Option<bool>,
    /// NOTE(review): presumably `None` means the `batch` option was not specified — confirm.
    pub is_batched: Option<bool>,
}
120
/// The kind of a user-defined function.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumAsInner)]
pub enum UdfKind {
    /// A scalar function. Corresponds to [`UdfImpl::call`].
    Scalar,
    /// A table function. Corresponds to [`UdfImpl::call_table_function`].
    Table,
    /// An aggregate function. Corresponds to the `call_agg_*` methods of [`UdfImpl`].
    Aggregate,
}
127
/// UDF implementation.
///
/// A runtime for one user-defined function, produced by the `build_fn` of a
/// [`UdfImplDescriptor`]. Data is exchanged as Arrow [`RecordBatch`]es and arrays.
///
/// Only [`call`](Self::call) and [`call_table_function`](Self::call_table_function) have no
/// default implementation. The defaults of the three aggregate methods always return an error,
/// so an implementation that supports aggregate functions must override all of them.
#[async_trait::async_trait]
pub trait UdfImpl: std::fmt::Debug + Send + Sync {
    /// Call the scalar function.
    ///
    /// NOTE(review): presumably `input` holds one column per argument and the result holds
    /// one row per input row — confirm against the implementations.
    async fn call(&self, input: &RecordBatch) -> Result<RecordBatch>;

    /// Call the table function.
    ///
    /// Returns a stream of output batches. Both `self` and `input` share the lifetime `'a`
    /// with the stream, so the stream may borrow from either of them.
    async fn call_table_function<'a>(
        &'a self,
        input: &'a RecordBatch,
    ) -> Result<BoxStream<'a, Result<RecordBatch>>>;

    /// For aggregate function, create the initial state.
    ///
    /// # Errors
    ///
    /// The default implementation always fails with "aggregate function is not supported".
    async fn call_agg_create_state(&self) -> Result<ArrayRef> {
        bail!("aggregate function is not supported");
    }

    /// For aggregate function, accumulate or retract the state.
    ///
    /// NOTE(review): presumably `ops` holds one accumulate/retract flag per row of `input`,
    /// and the returned array is the updated state; which boolean value means "retract" is
    /// not visible here — confirm with the implementations.
    ///
    /// # Errors
    ///
    /// The default implementation always fails with "aggregate function is not supported".
    async fn call_agg_accumulate_or_retract(
        &self,
        _state: &ArrayRef,
        _ops: &BooleanArray,
        _input: &RecordBatch,
    ) -> Result<ArrayRef> {
        bail!("aggregate function is not supported");
    }

    /// For aggregate function, get aggregate result from the state.
    ///
    /// # Errors
    ///
    /// The default implementation always fails with "aggregate function is not supported".
    async fn call_agg_finish(&self, _state: &ArrayRef) -> Result<ArrayRef> {
        bail!("aggregate function is not supported");
    }

    /// Whether the UDF talks in legacy mode.
    ///
    /// If true, decimal and jsonb types are mapped to Arrow `LargeBinary` and `LargeUtf8` types.
    /// Otherwise, they are mapped to Arrow extension types.
    /// See <https://github.com/risingwavelabs/arrow-udf/tree/main#extension-types>.
    ///
    /// Defaults to `false`.
    fn is_legacy(&self) -> bool {
        false
    }

    /// Return the memory size consumed by UDF runtime in bytes.
    ///
    /// If not available, return 0. The default implementation always returns 0.
    fn memory_usage(&self) -> usize {
        0
    }
}
175}