1use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::{Arc, LazyLock};
17
18use anyhow::Context;
19use await_tree::InstrumentAwait;
20use prometheus::{Registry, exponential_buckets};
21use risingwave_common::array::arrow::arrow_schema_udf::{Fields, Schema, SchemaRef};
22use risingwave_common::array::arrow::{UdfArrowConvert, UdfFromArrow, UdfToArrow};
23use risingwave_common::array::{Array, ArrayRef, DataChunk};
24use risingwave_common::metrics::*;
25use risingwave_common::monitor::GLOBAL_METRICS_REGISTRY;
26use risingwave_common::row::OwnedRow;
27use risingwave_common::types::{DataType, Datum};
28use risingwave_expr::expr_context::FRAGMENT_ID;
29use risingwave_pb::expr::ExprNode;
30
31use super::{BoxedExpression, Build};
32use crate::expr::Expression;
33use crate::sig::{BuildOptions, UdfImpl, UdfKind};
34use crate::{ExprError, Result, bail};
35
36#[derive(Debug)]
37pub struct UserDefinedFunction {
38 children: Vec<BoxedExpression>,
39 arg_types: Vec<DataType>,
40 return_type: DataType,
41 arg_schema: SchemaRef,
42 runtime: Box<dyn UdfImpl>,
43 arrow_convert: UdfArrowConvert,
44 span: await_tree::Span,
45 metrics: Metrics,
46}
47
48#[async_trait::async_trait]
49impl Expression for UserDefinedFunction {
50 fn return_type(&self) -> DataType {
51 self.return_type.clone()
52 }
53
54 async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
55 if input.cardinality() == 0 {
56 let mut builder = self.return_type.create_array_builder(input.capacity());
58 builder.append_n_null(input.capacity());
59 return Ok(builder.finish().into_ref());
60 }
61 let mut columns = Vec::with_capacity(self.children.len());
62 for child in &self.children {
63 let array = child.eval(input).await?;
64 columns.push(array);
65 }
66 let chunk = DataChunk::new(columns, input.visibility().clone());
67 self.eval_inner(&chunk).await
68 }
69
70 async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
71 let mut columns = Vec::with_capacity(self.children.len());
72 for child in &self.children {
73 let datum = child.eval_row(input).await?;
74 columns.push(datum);
75 }
76 let arg_row = OwnedRow::new(columns);
77 let chunk = DataChunk::from_rows(std::slice::from_ref(&arg_row), &self.arg_types);
78 let output_array = self.eval_inner(&chunk).await?;
79 Ok(output_array.to_datum())
80 }
81}
82
83impl UserDefinedFunction {
84 async fn eval_inner(&self, input: &DataChunk) -> Result<ArrayRef> {
85 let arrow_input = self
87 .arrow_convert
88 .to_record_batch(self.arg_schema.clone(), input)?;
89
90 self.metrics
92 .input_chunk_rows
93 .observe(arrow_input.num_rows() as f64);
94 self.metrics
95 .input_rows
96 .inc_by(arrow_input.num_rows() as u64);
97 self.metrics
98 .input_bytes
99 .inc_by(arrow_input.get_array_memory_size() as u64);
100 let timer = self.metrics.latency.start_timer();
101
102 let arrow_output_result = self
103 .runtime
104 .call(&arrow_input)
105 .instrument_await(self.span.clone())
106 .await;
107
108 timer.stop_and_record();
109 if arrow_output_result.is_ok() {
110 &self.metrics.success_count
111 } else {
112 &self.metrics.failure_count
113 }
114 .inc();
115 self.metrics
117 .memory_usage_bytes
118 .set(self.runtime.memory_usage() as i64);
119
120 let arrow_output = arrow_output_result?;
121
122 if arrow_output.num_rows() != input.cardinality() {
123 bail!(
124 "UDF returned {} rows, but expected {}",
125 arrow_output.num_rows(),
126 input.cardinality(),
127 );
128 }
129
130 let output = self.arrow_convert.from_record_batch(&arrow_output)?;
131 let output = output.uncompact(input.visibility().clone());
132
133 let Some(array) = output.columns().first() else {
134 bail!("UDF returned no columns");
135 };
136 if !array.data_type().equals_datatype(&self.return_type) {
137 bail!(
138 "UDF returned {:?}, but expected {:?}",
139 array.data_type(),
140 self.return_type,
141 );
142 }
143
144 if let Some(errors) = output.columns().get(1) {
146 if errors.data_type() != DataType::Varchar {
147 bail!(
148 "UDF returned errors column with invalid type: {:?}",
149 errors.data_type()
150 );
151 }
152 let errors = errors
153 .as_utf8()
154 .iter()
155 .filter_map(|msg| msg.map(|s| ExprError::Custom(s.into())))
156 .collect();
157 return Err(crate::ExprError::Multiple(array.clone(), errors));
158 }
159
160 Ok(array.clone())
161 }
162}
163
164impl Build for UserDefinedFunction {
165 fn build(
166 prost: &ExprNode,
167 build_child: impl Fn(&ExprNode) -> Result<BoxedExpression>,
168 ) -> Result<Self> {
169 let return_type = DataType::from(prost.get_return_type().unwrap());
170 let udf = prost.get_rex_node().unwrap().as_udf().unwrap();
171 let name = udf.get_name();
172 let arg_types = udf.arg_types.iter().map(|t| t.into()).collect::<Vec<_>>();
173
174 let language = udf.language.as_str();
175 let runtime = udf.runtime.as_deref();
176 let link = udf.link.as_deref();
177
178 let name_in_runtime = udf
179 .name_in_runtime()
180 .expect("SQL UDF won't get here, other UDFs must have `name_in_runtime`");
181
182 let build_fn = crate::sig::find_udf_impl(language, runtime, link)?.build_fn;
184 let runtime = build_fn(BuildOptions {
185 kind: UdfKind::Scalar,
186 body: udf.body.as_deref(),
187 compressed_binary: udf.compressed_binary.as_deref(),
188 link: udf.link.as_deref(),
189 name_in_runtime,
190 arg_names: &udf.arg_names,
191 arg_types: &arg_types,
192 return_type: &return_type,
193 always_retry_on_network_error: udf.always_retry_on_network_error,
194 language,
195 is_async: udf.is_async,
196 is_batched: udf.is_batched,
197 })
198 .context("failed to build UDF runtime")?;
199
200 let arrow_convert = UdfArrowConvert {
201 legacy: runtime.is_legacy(),
202 };
203
204 let arg_schema = Arc::new(Schema::new(
205 udf.arg_types
206 .iter()
207 .map(|t| arrow_convert.to_arrow_field("", &DataType::from(t)))
208 .try_collect::<Fields>()?,
209 ));
210
211 let metrics = GLOBAL_METRICS.with_label_values(
212 link.unwrap_or(""),
213 language,
214 name,
215 &FRAGMENT_ID::try_with(ToOwned::to_owned)
217 .unwrap_or(0)
218 .to_string(),
219 );
220
221 Ok(Self {
222 children: udf.children.iter().map(build_child).try_collect()?,
223 arg_types,
224 return_type,
225 arg_schema,
226 runtime,
227 arrow_convert,
228 span: await_tree::span!("udf_call({})", name),
229 metrics,
230 })
231 }
232}
233
234#[derive(Debug, Clone)]
236struct MetricsVec {
237 success_count: LabelGuardedIntCounterVec<4>,
239 failure_count: LabelGuardedIntCounterVec<4>,
241 retry_count: LabelGuardedIntCounterVec<4>,
243 input_chunk_rows: LabelGuardedHistogramVec<4>,
245 latency: LabelGuardedHistogramVec<4>,
247 input_rows: LabelGuardedIntCounterVec<4>,
249 input_bytes: LabelGuardedIntCounterVec<4>,
251 memory_usage_bytes: LabelGuardedIntGaugeVec<5>,
253}
254
255#[derive(Debug, Clone)]
257struct Metrics {
258 success_count: LabelGuardedIntCounter<4>,
260 failure_count: LabelGuardedIntCounter<4>,
262 #[allow(dead_code)]
264 retry_count: LabelGuardedIntCounter<4>,
265 input_chunk_rows: LabelGuardedHistogram<4>,
267 latency: LabelGuardedHistogram<4>,
269 input_rows: LabelGuardedIntCounter<4>,
271 input_bytes: LabelGuardedIntCounter<4>,
273 memory_usage_bytes: LabelGuardedIntGauge<5>,
275}
276
277static GLOBAL_METRICS: LazyLock<MetricsVec> =
279 LazyLock::new(|| MetricsVec::new(&GLOBAL_METRICS_REGISTRY));
280
281impl MetricsVec {
282 fn new(registry: &Registry) -> Self {
283 let labels = &["link", "language", "name", "fragment_id"];
284 let labels5 = &["link", "language", "name", "fragment_id", "instance_id"];
285 let success_count = register_guarded_int_counter_vec_with_registry!(
286 "udf_success_count",
287 "Total number of successful UDF calls",
288 labels,
289 registry
290 )
291 .unwrap();
292 let failure_count = register_guarded_int_counter_vec_with_registry!(
293 "udf_failure_count",
294 "Total number of failed UDF calls",
295 labels,
296 registry
297 )
298 .unwrap();
299 let retry_count = register_guarded_int_counter_vec_with_registry!(
300 "udf_retry_count",
301 "Total number of retried UDF calls",
302 labels,
303 registry
304 )
305 .unwrap();
306 let input_chunk_rows = register_guarded_histogram_vec_with_registry!(
307 "udf_input_chunk_rows",
308 "Input chunk rows of UDF calls",
309 labels,
310 exponential_buckets(1.0, 2.0, 10).unwrap(), registry
312 )
313 .unwrap();
314 let latency = register_guarded_histogram_vec_with_registry!(
315 "udf_latency",
316 "The latency(s) of UDF calls",
317 labels,
318 exponential_buckets(0.000001, 2.0, 30).unwrap(), registry
320 )
321 .unwrap();
322 let input_rows = register_guarded_int_counter_vec_with_registry!(
323 "udf_input_rows",
324 "Total number of input rows of UDF calls",
325 labels,
326 registry
327 )
328 .unwrap();
329 let input_bytes = register_guarded_int_counter_vec_with_registry!(
330 "udf_input_bytes",
331 "Total number of input bytes of UDF calls",
332 labels,
333 registry
334 )
335 .unwrap();
336 let memory_usage_bytes = register_guarded_int_gauge_vec_with_registry!(
337 "udf_memory_usage",
338 "Total memory usage of UDF runtime in bytes",
339 labels5,
340 registry
341 )
342 .unwrap();
343
344 MetricsVec {
345 success_count,
346 failure_count,
347 retry_count,
348 input_chunk_rows,
349 latency,
350 input_rows,
351 input_bytes,
352 memory_usage_bytes,
353 }
354 }
355
356 fn with_label_values(
357 &self,
358 link: &str,
359 language: &str,
360 name: &str,
361 fragment_id: &str,
362 ) -> Metrics {
363 static NEXT_INSTANCE_ID: AtomicU64 = AtomicU64::new(0);
365 let instance_id = NEXT_INSTANCE_ID.fetch_add(1, Ordering::Relaxed).to_string();
366
367 let labels = &[link, language, name, fragment_id];
368 let labels5 = &[link, language, name, fragment_id, &instance_id];
369
370 Metrics {
371 success_count: self.success_count.with_guarded_label_values(labels),
372 failure_count: self.failure_count.with_guarded_label_values(labels),
373 retry_count: self.retry_count.with_guarded_label_values(labels),
374 input_chunk_rows: self.input_chunk_rows.with_guarded_label_values(labels),
375 latency: self.latency.with_guarded_label_values(labels),
376 input_rows: self.input_rows.with_guarded_label_values(labels),
377 input_bytes: self.input_bytes.with_guarded_label_values(labels),
378 memory_usage_bytes: self.memory_usage_bytes.with_guarded_label_values(labels5),
379 }
380 }
381}