risingwave_expr/expr/
expr_udf.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::{Arc, LazyLock};
17
18use anyhow::Context;
19use await_tree::InstrumentAwait;
20use prometheus::{Registry, exponential_buckets};
21use risingwave_common::array::arrow::arrow_schema_udf::{Fields, Schema, SchemaRef};
22use risingwave_common::array::arrow::{UdfArrowConvert, UdfFromArrow, UdfToArrow};
23use risingwave_common::array::{Array, ArrayRef, DataChunk};
24use risingwave_common::metrics::*;
25use risingwave_common::monitor::GLOBAL_METRICS_REGISTRY;
26use risingwave_common::row::OwnedRow;
27use risingwave_common::types::{DataType, Datum};
28use risingwave_expr::expr_context::FRAGMENT_ID;
29use risingwave_pb::expr::ExprNode;
30
31use super::{BoxedExpression, Build};
32use crate::expr::Expression;
33use crate::sig::{BuildOptions, UdfImpl, UdfKind};
34use crate::{ExprError, Result, bail};
35
/// A scalar user-defined function expression.
///
/// Evaluation computes the argument columns from the child expressions, sends them to the
/// UDF runtime as an Arrow record batch, and converts the returned batch back into an array.
#[derive(Debug)]
pub struct UserDefinedFunction {
    /// Expressions producing the argument columns, evaluated in order.
    children: Vec<BoxedExpression>,
    /// Data types of the arguments; used to build a one-row chunk in `eval_row`.
    arg_types: Vec<DataType>,
    /// Data type the UDF result is checked against.
    return_type: DataType,
    /// Arrow schema of the argument record batch passed to the runtime.
    arg_schema: SchemaRef,
    /// The runtime that actually executes the function.
    runtime: Box<dyn UdfImpl>,
    /// Converter between data chunks and Arrow record batches.
    arrow_convert: UdfArrowConvert,
    /// Await-tree span attached to every runtime call.
    span: await_tree::Span,
    /// Metric handles labelled for this expression instance.
    metrics: Metrics,
}
47
48#[async_trait::async_trait]
49impl Expression for UserDefinedFunction {
50    fn return_type(&self) -> DataType {
51        self.return_type.clone()
52    }
53
54    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
55        if input.cardinality() == 0 {
56            // early return for empty input
57            let mut builder = self.return_type.create_array_builder(input.capacity());
58            builder.append_n_null(input.capacity());
59            return Ok(builder.finish().into_ref());
60        }
61        let mut columns = Vec::with_capacity(self.children.len());
62        for child in &self.children {
63            let array = child.eval(input).await?;
64            columns.push(array);
65        }
66        let chunk = DataChunk::new(columns, input.visibility().clone());
67        self.eval_inner(&chunk).await
68    }
69
70    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
71        let mut columns = Vec::with_capacity(self.children.len());
72        for child in &self.children {
73            let datum = child.eval_row(input).await?;
74            columns.push(datum);
75        }
76        let arg_row = OwnedRow::new(columns);
77        let chunk = DataChunk::from_rows(std::slice::from_ref(&arg_row), &self.arg_types);
78        let output_array = self.eval_inner(&chunk).await?;
79        Ok(output_array.to_datum())
80    }
81}
82
83impl UserDefinedFunction {
84    async fn eval_inner(&self, input: &DataChunk) -> Result<ArrayRef> {
85        // this will drop invisible rows
86        let arrow_input = self
87            .arrow_convert
88            .to_record_batch(self.arg_schema.clone(), input)?;
89
90        // metrics
91        self.metrics
92            .input_chunk_rows
93            .observe(arrow_input.num_rows() as f64);
94        self.metrics
95            .input_rows
96            .inc_by(arrow_input.num_rows() as u64);
97        self.metrics
98            .input_bytes
99            .inc_by(arrow_input.get_array_memory_size() as u64);
100        let timer = self.metrics.latency.start_timer();
101
102        let arrow_output_result = self
103            .runtime
104            .call(&arrow_input)
105            .instrument_await(self.span.clone())
106            .await;
107
108        timer.stop_and_record();
109        if arrow_output_result.is_ok() {
110            &self.metrics.success_count
111        } else {
112            &self.metrics.failure_count
113        }
114        .inc();
115        // update memory usage
116        self.metrics
117            .memory_usage_bytes
118            .set(self.runtime.memory_usage() as i64);
119
120        let arrow_output = arrow_output_result?;
121
122        if arrow_output.num_rows() != input.cardinality() {
123            bail!(
124                "UDF returned {} rows, but expected {}",
125                arrow_output.num_rows(),
126                input.cardinality(),
127            );
128        }
129
130        let output = self.arrow_convert.from_record_batch(&arrow_output)?;
131        let output = output.uncompact(input.visibility().clone());
132
133        let Some(array) = output.columns().first() else {
134            bail!("UDF returned no columns");
135        };
136        if !array.data_type().equals_datatype(&self.return_type) {
137            bail!(
138                "UDF returned {:?}, but expected {:?}",
139                array.data_type(),
140                self.return_type,
141            );
142        }
143
144        // handle optional error column
145        if let Some(errors) = output.columns().get(1) {
146            if errors.data_type() != DataType::Varchar {
147                bail!(
148                    "UDF returned errors column with invalid type: {:?}",
149                    errors.data_type()
150                );
151            }
152            let errors = errors
153                .as_utf8()
154                .iter()
155                .filter_map(|msg| msg.map(|s| ExprError::Custom(s.into())))
156                .collect();
157            return Err(crate::ExprError::Multiple(array.clone(), errors));
158        }
159
160        Ok(array.clone())
161    }
162}
163
164impl Build for UserDefinedFunction {
165    fn build(
166        prost: &ExprNode,
167        build_child: impl Fn(&ExprNode) -> Result<BoxedExpression>,
168    ) -> Result<Self> {
169        let return_type = DataType::from(prost.get_return_type().unwrap());
170        let udf = prost.get_rex_node().unwrap().as_udf().unwrap();
171        let name = udf.get_name();
172        let arg_types = udf.arg_types.iter().map(|t| t.into()).collect::<Vec<_>>();
173
174        let language = udf.language.as_str();
175        let runtime = udf.runtime.as_deref();
176        let link = udf.link.as_deref();
177
178        let name_in_runtime = udf
179            .name_in_runtime()
180            .expect("SQL UDF won't get here, other UDFs must have `name_in_runtime`");
181
182        // lookup UDF builder
183        let build_fn = crate::sig::find_udf_impl(language, runtime, link)?.build_fn;
184        let runtime = build_fn(BuildOptions {
185            kind: UdfKind::Scalar,
186            body: udf.body.as_deref(),
187            compressed_binary: udf.compressed_binary.as_deref(),
188            link: udf.link.as_deref(),
189            name_in_runtime,
190            arg_names: &udf.arg_names,
191            arg_types: &arg_types,
192            return_type: &return_type,
193            always_retry_on_network_error: udf.always_retry_on_network_error,
194            language,
195            is_async: udf.is_async,
196            is_batched: udf.is_batched,
197        })
198        .context("failed to build UDF runtime")?;
199
200        let arrow_convert = UdfArrowConvert {
201            legacy: runtime.is_legacy(),
202        };
203
204        let arg_schema = Arc::new(Schema::new(
205            udf.arg_types
206                .iter()
207                .map(|t| arrow_convert.to_arrow_field("", &DataType::from(t)))
208                .try_collect::<Fields>()?,
209        ));
210
211        let metrics = GLOBAL_METRICS.with_label_values(
212            link.unwrap_or(""),
213            language,
214            name,
215            // batch query does not have a fragment_id
216            &FRAGMENT_ID::try_with(ToOwned::to_owned)
217                .unwrap_or(0)
218                .to_string(),
219        );
220
221        Ok(Self {
222            children: udf.children.iter().map(build_child).try_collect()?,
223            arg_types,
224            return_type,
225            arg_schema,
226            runtime,
227            arrow_convert,
228            span: await_tree::span!("udf_call({})", name),
229            metrics,
230        })
231    }
232}
233
/// Metric families for UDF monitoring.
///
/// Each field is a labelled metric vector; `with_label_values` binds concrete label values
/// and yields a [`Metrics`] for one UDF expression instance.
#[derive(Debug, Clone)]
struct MetricsVec {
    /// Number of successful UDF calls.
    success_count: LabelGuardedIntCounterVec<4>,
    /// Number of failed UDF calls.
    failure_count: LabelGuardedIntCounterVec<4>,
    /// Total number of retried UDF calls.
    retry_count: LabelGuardedIntCounterVec<4>,
    /// Input chunk rows of UDF calls.
    input_chunk_rows: LabelGuardedHistogramVec<4>,
    /// The latency of UDF calls in seconds.
    latency: LabelGuardedHistogramVec<4>,
    /// Total number of input rows of UDF calls.
    input_rows: LabelGuardedIntCounterVec<4>,
    /// Total number of input bytes of UDF calls.
    input_bytes: LabelGuardedIntCounterVec<4>,
    /// Total memory usage of UDF runtime in bytes.
    ///
    /// Carries one more label than the other metrics (an instance id).
    memory_usage_bytes: LabelGuardedIntGaugeVec<5>,
}
254
/// Metric handles for a single UDF expression instance.
///
/// These are the members of [`MetricsVec`] with their label values already bound.
#[derive(Debug, Clone)]
struct Metrics {
    /// Number of successful UDF calls.
    success_count: LabelGuardedIntCounter<4>,
    /// Number of failed UDF calls.
    failure_count: LabelGuardedIntCounter<4>,
    /// Total number of retried UDF calls.
    // Created but never updated in the code visible in this file, hence the allowance.
    #[allow(dead_code)]
    retry_count: LabelGuardedIntCounter<4>,
    /// Input chunk rows of UDF calls.
    input_chunk_rows: LabelGuardedHistogram<4>,
    /// The latency of UDF calls in seconds.
    latency: LabelGuardedHistogram<4>,
    /// Total number of input rows of UDF calls.
    input_rows: LabelGuardedIntCounter<4>,
    /// Total number of input bytes of UDF calls.
    input_bytes: LabelGuardedIntCounter<4>,
    /// Total memory usage of UDF runtime in bytes.
    ///
    /// Carries one more label than the other metrics (an instance id).
    memory_usage_bytes: LabelGuardedIntGauge<5>,
}
276
/// Global UDF metrics.
///
/// Created on first access by registering all UDF metric families in
/// `GLOBAL_METRICS_REGISTRY`.
static GLOBAL_METRICS: LazyLock<MetricsVec> =
    LazyLock::new(|| MetricsVec::new(&GLOBAL_METRICS_REGISTRY));
280
281impl MetricsVec {
282    fn new(registry: &Registry) -> Self {
283        let labels = &["link", "language", "name", "fragment_id"];
284        let labels5 = &["link", "language", "name", "fragment_id", "instance_id"];
285        let success_count = register_guarded_int_counter_vec_with_registry!(
286            "udf_success_count",
287            "Total number of successful UDF calls",
288            labels,
289            registry
290        )
291        .unwrap();
292        let failure_count = register_guarded_int_counter_vec_with_registry!(
293            "udf_failure_count",
294            "Total number of failed UDF calls",
295            labels,
296            registry
297        )
298        .unwrap();
299        let retry_count = register_guarded_int_counter_vec_with_registry!(
300            "udf_retry_count",
301            "Total number of retried UDF calls",
302            labels,
303            registry
304        )
305        .unwrap();
306        let input_chunk_rows = register_guarded_histogram_vec_with_registry!(
307            "udf_input_chunk_rows",
308            "Input chunk rows of UDF calls",
309            labels,
310            exponential_buckets(1.0, 2.0, 10).unwrap(), // 1 to 1024
311            registry
312        )
313        .unwrap();
314        let latency = register_guarded_histogram_vec_with_registry!(
315            "udf_latency",
316            "The latency(s) of UDF calls",
317            labels,
318            exponential_buckets(0.000001, 2.0, 30).unwrap(), // 1us to 1000s
319            registry
320        )
321        .unwrap();
322        let input_rows = register_guarded_int_counter_vec_with_registry!(
323            "udf_input_rows",
324            "Total number of input rows of UDF calls",
325            labels,
326            registry
327        )
328        .unwrap();
329        let input_bytes = register_guarded_int_counter_vec_with_registry!(
330            "udf_input_bytes",
331            "Total number of input bytes of UDF calls",
332            labels,
333            registry
334        )
335        .unwrap();
336        let memory_usage_bytes = register_guarded_int_gauge_vec_with_registry!(
337            "udf_memory_usage",
338            "Total memory usage of UDF runtime in bytes",
339            labels5,
340            registry
341        )
342        .unwrap();
343
344        MetricsVec {
345            success_count,
346            failure_count,
347            retry_count,
348            input_chunk_rows,
349            latency,
350            input_rows,
351            input_bytes,
352            memory_usage_bytes,
353        }
354    }
355
356    fn with_label_values(
357        &self,
358        link: &str,
359        language: &str,
360        name: &str,
361        fragment_id: &str,
362    ) -> Metrics {
363        // generate an unique id for each instance
364        static NEXT_INSTANCE_ID: AtomicU64 = AtomicU64::new(0);
365        let instance_id = NEXT_INSTANCE_ID.fetch_add(1, Ordering::Relaxed).to_string();
366
367        let labels = &[link, language, name, fragment_id];
368        let labels5 = &[link, language, name, fragment_id, &instance_id];
369
370        Metrics {
371            success_count: self.success_count.with_guarded_label_values(labels),
372            failure_count: self.failure_count.with_guarded_label_values(labels),
373            retry_count: self.retry_count.with_guarded_label_values(labels),
374            input_chunk_rows: self.input_chunk_rows.with_guarded_label_values(labels),
375            latency: self.latency.with_guarded_label_values(labels),
376            input_rows: self.input_rows.with_guarded_label_values(labels),
377            input_bytes: self.input_bytes.with_guarded_label_values(labels),
378            memory_usage_bytes: self.memory_usage_bytes.with_guarded_label_values(labels5),
379        }
380    }
381}