// risingwave_expr/sig/udf.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! UDF implementation interface.
//!
//! To support a new language or runtime for UDFs, implement the interface in this module.
//!
//! See `expr/impl/src/udf` for the implementations.

21use anyhow::{Context, Result, bail};
22use educe::Educe;
23use enum_as_inner::EnumAsInner;
24use futures::stream::BoxStream;
25use risingwave_common::array::arrow::arrow_array_udf::{ArrayRef, BooleanArray, RecordBatch};
26use risingwave_common::types::DataType;
27
28/// The global registry of UDF implementations.
29///
30/// To register a new UDF implementation:
31///
32/// ```ignore
33/// #[linkme::distributed_slice(UDF_IMPLS)]
34/// static MY_UDF_LANGUAGE: UdfImplDescriptor = UdfImplDescriptor {...};
35/// ```
36#[linkme::distributed_slice]
37pub static UDF_IMPLS: [UdfImplDescriptor];
38
39/// Find a UDF implementation by language.
40pub fn find_udf_impl(
41    language: &str,
42    runtime: Option<&str>,
43    link: Option<&str>,
44) -> Result<&'static UdfImplDescriptor> {
45    let mut impls = UDF_IMPLS
46        .iter()
47        .filter(|desc| (desc.match_fn)(language, runtime, link));
48    let impl_ = impls.next().context(
49        "language not found.\nHINT: UDF feature flag may not be enabled during compilation",
50    )?;
51    if impls.next().is_some() {
52        bail!("multiple UDF implementations found for language: {language}");
53    }
54    Ok(impl_)
55}
56
57/// UDF implementation descriptor.
58///
59/// Every UDF implementation should provide 3 functions:
60pub struct UdfImplDescriptor {
61    /// Returns if a function matches the implementation.
62    ///
63    /// This function is used to determine which implementation to use for a UDF.
64    pub match_fn: fn(language: &str, runtime: Option<&str>, link: Option<&str>) -> bool,
65
66    /// Creates a function from options.
67    ///
68    /// This function will be called when `create function` statement is executed on the frontend.
69    pub create_fn: fn(opts: CreateOptions<'_>) -> Result<CreateFunctionOutput>,
70
71    /// Builds UDF runtime from verified options.
72    ///
73    /// This function will be called before the UDF is executed on the backend.
74    pub build_fn: fn(opts: BuildOptions<'_>) -> Result<Box<dyn UdfImpl>>,
75}
76
77/// Options for creating a function.
78///
79/// These information are parsed from `CREATE FUNCTION` statement.
80/// Implementations should verify the options and return a `CreateFunctionOutput` in `create_fn`.
81pub struct CreateOptions<'a> {
82    pub kind: UdfKind,
83    /// The function name registered in RisingWave.
84    pub name: &'a str,
85    pub arg_names: &'a [String],
86    pub arg_types: &'a [DataType],
87    pub return_type: &'a DataType,
88    /// The function name on the remote side / in the source code, currently only used for external UDF.
89    pub as_: Option<&'a str>,
90    pub using_link: Option<&'a str>,
91    pub using_base64_decoded: Option<&'a [u8]>,
92}
93
/// Output of creating a function, returned by an implementation's `create_fn`.
pub struct CreateFunctionOutput {
    /// The name for identifying the function in the UDF runtime.
    pub name_in_runtime: String,
    /// Function body to keep, if any — presumably handed back later as
    /// [`BuildOptions::body`].
    pub body: Option<String>,
    /// Compressed binary payload to keep, if any — presumably handed back later as
    /// [`BuildOptions::compressed_binary`].
    pub compressed_binary: Option<Vec<u8>>,
}

102/// Options for building a UDF runtime.
103#[derive(Educe)]
104#[educe(Debug)]
105pub struct BuildOptions<'a> {
106    pub kind: UdfKind,
107    pub body: Option<&'a str>,
108    #[educe(Debug(ignore))]
109    pub compressed_binary: Option<&'a [u8]>,
110    pub link: Option<&'a str>,
111    pub name_in_runtime: &'a str,
112    pub arg_names: &'a [String],
113    pub arg_types: &'a [DataType],
114    pub return_type: &'a DataType,
115    pub always_retry_on_network_error: bool,
116    pub language: &'a str,
117    pub is_async: Option<bool>,
118    pub is_batched: Option<bool>,
119}
120
121#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumAsInner)]
122pub enum UdfKind {
123    Scalar,
124    Table,
125    Aggregate,
126}
127
128/// UDF implementation.
129#[async_trait::async_trait]
130pub trait UdfImpl: std::fmt::Debug + Send + Sync {
131    /// Call the scalar function.
132    async fn call(&self, input: &RecordBatch) -> Result<RecordBatch>;
133
134    /// Call the table function.
135    async fn call_table_function<'a>(
136        &'a self,
137        input: &'a RecordBatch,
138    ) -> Result<BoxStream<'a, Result<RecordBatch>>>;
139
140    /// For aggregate function, create the initial state.
141    async fn call_agg_create_state(&self) -> Result<ArrayRef> {
142        bail!("aggregate function is not supported");
143    }
144
145    /// For aggregate function, accumulate or retract the state.
146    async fn call_agg_accumulate_or_retract(
147        &self,
148        _state: &ArrayRef,
149        _ops: &BooleanArray,
150        _input: &RecordBatch,
151    ) -> Result<ArrayRef> {
152        bail!("aggregate function is not supported");
153    }
154
155    /// For aggregate function, get aggregate result from the state.
156    async fn call_agg_finish(&self, _state: &ArrayRef) -> Result<ArrayRef> {
157        bail!("aggregate function is not supported");
158    }
159
160    /// Whether the UDF talks in legacy mode.
161    ///
162    /// If true, decimal and jsonb types are mapped to Arrow `LargeBinary` and `LargeUtf8` types.
163    /// Otherwise, they are mapped to Arrow extension types.
164    /// See <https://github.com/risingwavelabs/arrow-udf/tree/main#extension-types>.
165    fn is_legacy(&self) -> bool {
166        false
167    }
168
169    /// Return the memory size consumed by UDF runtime in bytes.
170    ///
171    /// If not available, return 0.
172    fn memory_usage(&self) -> usize {
173        0
174    }
175}