risingwave_expr/expr/
expr_udf.rs

// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::{Arc, LazyLock};
17
18use anyhow::Context;
19use await_tree::InstrumentAwait;
20use prometheus::{Registry, exponential_buckets};
21use risingwave_common::array::arrow::arrow_schema_udf::{Fields, Schema, SchemaRef};
22use risingwave_common::array::arrow::{UdfArrowConvert, UdfFromArrow, UdfToArrow};
23use risingwave_common::array::{Array, ArrayRef, DataChunk};
24use risingwave_common::metrics::*;
25use risingwave_common::monitor::GLOBAL_METRICS_REGISTRY;
26use risingwave_common::row::OwnedRow;
27use risingwave_common::types::{DataType, Datum};
28use risingwave_expr::expr_context::FRAGMENT_ID;
29use risingwave_pb::expr::ExprNode;
30
31use super::{BoxedExpression, Build};
32use crate::expr::Expression;
33use crate::sig::{BuildOptions, UdfImpl, UdfKind};
34use crate::{ExprError, Result, bail};
35
36#[derive(Debug)]
37pub struct UserDefinedFunction {
38    children: Vec<BoxedExpression>,
39    arg_types: Vec<DataType>,
40    return_type: DataType,
41    arg_schema: SchemaRef,
42    runtime: Box<dyn UdfImpl>,
43    arrow_convert: UdfArrowConvert,
44    span: await_tree::Span,
45    metrics: Metrics,
46}
47
48#[async_trait::async_trait]
49impl Expression for UserDefinedFunction {
50    fn return_type(&self) -> DataType {
51        self.return_type.clone()
52    }
53
54    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
55        if input.cardinality() == 0 {
56            // early return for empty input
57            let mut builder = self.return_type.create_array_builder(input.capacity());
58            builder.append_n_null(input.capacity());
59            return Ok(builder.finish().into_ref());
60        }
61        let mut columns = Vec::with_capacity(self.children.len());
62        for child in &self.children {
63            let array = child.eval(input).await?;
64            columns.push(array);
65        }
66        let chunk = DataChunk::new(columns, input.visibility().clone());
67        self.eval_inner(&chunk).await
68    }
69
70    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
71        let mut columns = Vec::with_capacity(self.children.len());
72        for child in &self.children {
73            let datum = child.eval_row(input).await?;
74            columns.push(datum);
75        }
76        let arg_row = OwnedRow::new(columns);
77        let chunk = DataChunk::from_rows(std::slice::from_ref(&arg_row), &self.arg_types);
78        let output_array = self.eval_inner(&chunk).await?;
79        Ok(output_array.to_datum())
80    }
81}
82
83impl UserDefinedFunction {
84    async fn eval_inner(&self, input: &DataChunk) -> Result<ArrayRef> {
85        // this will drop invisible rows
86        let arrow_input = self
87            .arrow_convert
88            .to_record_batch(self.arg_schema.clone(), input)?;
89
90        // metrics
91        self.metrics
92            .input_chunk_rows
93            .observe(arrow_input.num_rows() as f64);
94        self.metrics
95            .input_rows
96            .inc_by(arrow_input.num_rows() as u64);
97        self.metrics
98            .input_bytes
99            .inc_by(arrow_input.get_array_memory_size() as u64);
100        let timer = self.metrics.latency.start_timer();
101
102        let arrow_output_result = self
103            .runtime
104            .call(&arrow_input)
105            .instrument_await(self.span.clone())
106            .await;
107
108        timer.stop_and_record();
109        if arrow_output_result.is_ok() {
110            &self.metrics.success_count
111        } else {
112            &self.metrics.failure_count
113        }
114        .inc();
115        // update memory usage
116        self.metrics
117            .memory_usage_bytes
118            .set(self.runtime.memory_usage() as i64);
119
120        let arrow_output = arrow_output_result?;
121
122        if arrow_output.num_rows() != input.cardinality() {
123            bail!(
124                "UDF returned {} rows, but expected {}",
125                arrow_output.num_rows(),
126                input.cardinality(),
127            );
128        }
129
130        let output = self.arrow_convert.from_record_batch(&arrow_output)?;
131        let output = output.expand_vis(input.visibility().clone());
132
133        let Some(array) = output.columns().first() else {
134            bail!("UDF returned no columns");
135        };
136        if !array.data_type().equals_datatype(&self.return_type) {
137            bail!(
138                "UDF returned {:?}, but expected {:?}",
139                array.data_type(),
140                self.return_type,
141            );
142        }
143
144        // handle optional error column
145        if let Some(errors) = output.columns().get(1) {
146            if errors.data_type() != DataType::Varchar {
147                bail!(
148                    "UDF returned errors column with invalid type: {:?}",
149                    errors.data_type()
150                );
151            }
152            let errors = errors
153                .as_utf8()
154                .iter()
155                .filter_map(|msg| msg.map(|s| ExprError::Custom(s.into())))
156                .collect();
157            return Err(crate::ExprError::Multiple(array.clone(), errors));
158        }
159
160        Ok(array.clone())
161    }
162}
163
164impl Build for UserDefinedFunction {
165    fn build(
166        prost: &ExprNode,
167        build_child: impl Fn(&ExprNode) -> Result<BoxedExpression>,
168    ) -> Result<Self> {
169        let return_type = DataType::from(prost.get_return_type().unwrap());
170        let udf = prost.get_rex_node().unwrap().as_udf().unwrap();
171        let name = udf.get_name();
172        let arg_types = udf.arg_types.iter().map(|t| t.into()).collect::<Vec<_>>();
173
174        let language = udf.language.as_str();
175        let runtime = udf.runtime.as_deref();
176        let link = udf.link.as_deref();
177
178        let name_in_runtime = udf
179            .name_in_runtime()
180            .expect("SQL UDF won't get here, other UDFs must have `name_in_runtime`");
181
182        // lookup UDF builder
183        let build_fn = crate::sig::find_udf_impl(language, runtime, link)?.build_fn;
184        let runtime = build_fn(BuildOptions {
185            kind: UdfKind::Scalar,
186            body: udf.body.as_deref(),
187            compressed_binary: udf.compressed_binary.as_deref(),
188            link: udf.link.as_deref(),
189            name_in_runtime,
190            arg_names: &udf.arg_names,
191            arg_types: &arg_types,
192            return_type: &return_type,
193            always_retry_on_network_error: udf.always_retry_on_network_error,
194            language,
195            is_async: udf.is_async,
196            is_batched: udf.is_batched,
197        })
198        .context("failed to build UDF runtime")?;
199
200        let arrow_convert = UdfArrowConvert {
201            legacy: runtime.is_legacy(),
202        };
203
204        let arg_schema = Arc::new(Schema::new(
205            udf.arg_types
206                .iter()
207                .map(|t| arrow_convert.to_arrow_field("", &DataType::from(t)))
208                .try_collect::<Fields>()?,
209        ));
210
211        let metrics = GLOBAL_METRICS.with_label_values(
212            link.unwrap_or(""),
213            language,
214            name,
215            // batch query does not have a fragment_id
216            &FRAGMENT_ID::try_with(ToOwned::to_owned)
217                .unwrap_or(0.into())
218                .to_string(),
219        );
220
221        let children: Vec<BoxedExpression> = udf.children.iter().map(build_child).try_collect()?;
222
223        Ok(Self {
224            children,
225            arg_types,
226            return_type,
227            arg_schema,
228            runtime,
229            arrow_convert,
230            span: await_tree::span!("udf_call({})", name),
231            metrics,
232        })
233    }
234}
235
236/// Monitor metrics for UDF.
237#[derive(Debug, Clone)]
238struct MetricsVec {
239    /// Number of successful UDF calls.
240    success_count: LabelGuardedIntCounterVec,
241    /// Number of failed UDF calls.
242    failure_count: LabelGuardedIntCounterVec,
243    /// Total number of retried UDF calls.
244    retry_count: LabelGuardedIntCounterVec,
245    /// Input chunk rows of UDF calls.
246    input_chunk_rows: LabelGuardedHistogramVec,
247    /// The latency of UDF calls in seconds.
248    latency: LabelGuardedHistogramVec,
249    /// Total number of input rows of UDF calls.
250    input_rows: LabelGuardedIntCounterVec,
251    /// Total number of input bytes of UDF calls.
252    input_bytes: LabelGuardedIntCounterVec,
253    /// Total memory usage of UDF runtime in bytes.
254    memory_usage_bytes: LabelGuardedIntGaugeVec,
255}
256
257/// Monitor metrics for UDF.
258#[derive(Debug, Clone)]
259struct Metrics {
260    /// Number of successful UDF calls.
261    success_count: LabelGuardedIntCounter,
262    /// Number of failed UDF calls.
263    failure_count: LabelGuardedIntCounter,
264    /// Total number of retried UDF calls.
265    #[allow(dead_code)]
266    retry_count: LabelGuardedIntCounter,
267    /// Input chunk rows of UDF calls.
268    input_chunk_rows: LabelGuardedHistogram,
269    /// The latency of UDF calls in seconds.
270    latency: LabelGuardedHistogram,
271    /// Total number of input rows of UDF calls.
272    input_rows: LabelGuardedIntCounter,
273    /// Total number of input bytes of UDF calls.
274    input_bytes: LabelGuardedIntCounter,
275    /// Total memory usage of UDF runtime in bytes.
276    memory_usage_bytes: LabelGuardedIntGauge,
277}
278
279/// Global UDF metrics.
280static GLOBAL_METRICS: LazyLock<MetricsVec> =
281    LazyLock::new(|| MetricsVec::new(&GLOBAL_METRICS_REGISTRY));
282
283impl MetricsVec {
284    fn new(registry: &Registry) -> Self {
285        let labels = &["link", "language", "name", "fragment_id"];
286        let labels5 = &["link", "language", "name", "fragment_id", "instance_id"];
287        let success_count = register_guarded_int_counter_vec_with_registry!(
288            "udf_success_count",
289            "Total number of successful UDF calls",
290            labels,
291            registry
292        )
293        .unwrap();
294        let failure_count = register_guarded_int_counter_vec_with_registry!(
295            "udf_failure_count",
296            "Total number of failed UDF calls",
297            labels,
298            registry
299        )
300        .unwrap();
301        let retry_count = register_guarded_int_counter_vec_with_registry!(
302            "udf_retry_count",
303            "Total number of retried UDF calls",
304            labels,
305            registry
306        )
307        .unwrap();
308        let input_chunk_rows = register_guarded_histogram_vec_with_registry!(
309            "udf_input_chunk_rows",
310            "Input chunk rows of UDF calls",
311            labels,
312            exponential_buckets(1.0, 2.0, 10).unwrap(), // 1 to 1024
313            registry
314        )
315        .unwrap();
316        let latency = register_guarded_histogram_vec_with_registry!(
317            "udf_latency",
318            "The latency(s) of UDF calls",
319            labels,
320            exponential_buckets(0.000001, 2.0, 30).unwrap(), // 1us to 1000s
321            registry
322        )
323        .unwrap();
324        let input_rows = register_guarded_int_counter_vec_with_registry!(
325            "udf_input_rows",
326            "Total number of input rows of UDF calls",
327            labels,
328            registry
329        )
330        .unwrap();
331        let input_bytes = register_guarded_int_counter_vec_with_registry!(
332            "udf_input_bytes",
333            "Total number of input bytes of UDF calls",
334            labels,
335            registry
336        )
337        .unwrap();
338        let memory_usage_bytes = register_guarded_int_gauge_vec_with_registry!(
339            "udf_memory_usage",
340            "Total memory usage of UDF runtime in bytes",
341            labels5,
342            registry
343        )
344        .unwrap();
345
346        MetricsVec {
347            success_count,
348            failure_count,
349            retry_count,
350            input_chunk_rows,
351            latency,
352            input_rows,
353            input_bytes,
354            memory_usage_bytes,
355        }
356    }
357
358    fn with_label_values(
359        &self,
360        link: &str,
361        language: &str,
362        name: &str,
363        fragment_id: &str,
364    ) -> Metrics {
365        // generate an unique id for each instance
366        static NEXT_INSTANCE_ID: AtomicU64 = AtomicU64::new(0);
367        let instance_id = NEXT_INSTANCE_ID.fetch_add(1, Ordering::Relaxed).to_string();
368
369        let labels = &[link, language, name, fragment_id];
370        let labels5 = &[link, language, name, fragment_id, &instance_id];
371
372        Metrics {
373            success_count: self.success_count.with_guarded_label_values(labels),
374            failure_count: self.failure_count.with_guarded_label_values(labels),
375            retry_count: self.retry_count.with_guarded_label_values(labels),
376            input_chunk_rows: self.input_chunk_rows.with_guarded_label_values(labels),
377            latency: self.latency.with_guarded_label_values(labels),
378            input_rows: self.input_rows.with_guarded_label_values(labels),
379            input_bytes: self.input_bytes.with_guarded_label_values(labels),
380            memory_usage_bytes: self.memory_usage_bytes.with_guarded_label_values(labels5),
381        }
382    }
383}