risingwave_expr/table_function/
user_defined.rs1use std::sync::Arc;
16
17use anyhow::Context;
18use risingwave_common::array::I32Array;
19use risingwave_common::array::arrow::arrow_schema_udf::{Fields, Schema, SchemaRef};
20use risingwave_common::array::arrow::{UdfArrowConvert, UdfFromArrow, UdfToArrow};
21use risingwave_common::bail;
22
23use super::*;
24use crate::sig::{BuildOptions, UdfImpl, UdfKind};
25
26#[derive(Debug)]
27pub struct UserDefinedTableFunction {
28 children: Vec<BoxedExpression>,
29 arg_schema: SchemaRef,
30 return_type: DataType,
31 runtime: Box<dyn UdfImpl>,
32 arrow_convert: UdfArrowConvert,
33 #[allow(dead_code)]
34 chunk_size: usize,
35}
36
37#[async_trait::async_trait]
38impl TableFunction for UserDefinedTableFunction {
39 fn return_type(&self) -> DataType {
40 self.return_type.clone()
41 }
42
43 async fn eval<'a>(&'a self, input: &'a DataChunk) -> BoxStream<'a, Result<DataChunk>> {
44 self.eval_inner(input)
45 }
46}
47
48impl UserDefinedTableFunction {
49 #[try_stream(boxed, ok = DataChunk, error = ExprError)]
50 async fn eval_inner<'a>(&'a self, input: &'a DataChunk) {
51 let mut columns = Vec::with_capacity(self.children.len());
53 for c in &self.children {
54 let val = c.eval(input).await?;
55 columns.push(val);
56 }
57 let direct_input = DataChunk::new(columns, input.visibility().clone());
58
59 let visible_rows = direct_input.visibility().iter_ones().collect::<Vec<_>>();
61 let arrow_input = self
63 .arrow_convert
64 .to_record_batch(self.arg_schema.clone(), &direct_input)?;
65
66 #[for_await]
68 for res in self.runtime.call_table_function(&arrow_input).await? {
69 let output = self.arrow_convert.from_record_batch(&res?)?;
70 self.check_output(&output)?;
71
72 let origin_indices = output
75 .column_at(0)
76 .as_int32()
77 .raw_iter()
78 .map(|idx| visible_rows[idx as usize] as i32)
80 .collect::<I32Array>();
81
82 let output = DataChunk::new(
83 vec![origin_indices.into_ref(), output.column_at(1).clone()],
84 output.visibility().clone(),
85 );
86 yield output;
87 }
88 }
89
90 fn check_output(&self, output: &DataChunk) -> Result<()> {
92 if output.columns().len() != 2 {
93 bail!(
94 "UDF returned {} columns, but expected 2",
95 output.columns().len()
96 );
97 }
98 if output.column_at(0).data_type() != DataType::Int32 {
99 bail!(
100 "UDF returned {:?} at column 0, but expected {:?}",
101 output.column_at(0).data_type(),
102 DataType::Int32,
103 );
104 }
105 if output.column_at(0).as_int32().raw_iter().any(|i| i < 0) {
106 bail!("UDF returned negative row index");
107 }
108 if !output
109 .column_at(1)
110 .data_type()
111 .equals_datatype(&self.return_type)
112 {
113 bail!(
114 "UDF returned {:?} at column 1, but expected {:?}",
115 output.column_at(1).data_type(),
116 &self.return_type,
117 );
118 }
119 Ok(())
120 }
121}
122
123pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<BoxedTableFunction> {
124 let udf = prost.get_udf()?;
125
126 let arg_types = udf.arg_types.iter().map(|t| t.into()).collect::<Vec<_>>();
127 let return_type = DataType::from(prost.get_return_type()?);
128
129 let language = udf.language.as_str();
130 let runtime = udf.runtime.as_deref();
131 let link = udf.link.as_deref();
132
133 let name_in_runtime = udf
134 .name_in_runtime()
135 .expect("SQL UDF won't get here, other UDFs must have `name_in_runtime`");
136
137 let build_fn = crate::sig::find_udf_impl(language, runtime, link)?.build_fn;
138 let runtime = build_fn(BuildOptions {
139 kind: UdfKind::Table,
140 body: udf.body.as_deref(),
141 compressed_binary: udf.compressed_binary.as_deref(),
142 link: udf.link.as_deref(),
143 name_in_runtime,
144 arg_names: &udf.arg_names,
145 arg_types: &arg_types,
146 return_type: &return_type,
147 always_retry_on_network_error: false,
148 language,
149 is_async: None,
150 is_batched: None,
151 })
152 .context("failed to build UDF runtime")?;
153
154 let arrow_convert = UdfArrowConvert {
155 legacy: runtime.is_legacy(),
156 };
157 let arg_schema = Arc::new(Schema::new(
158 arg_types
159 .iter()
160 .map(|t| arrow_convert.to_arrow_field("", t))
161 .try_collect::<Fields>()?,
162 ));
163
164 Ok(UserDefinedTableFunction {
165 children: prost.args.iter().map(expr_build_from_prost).try_collect()?,
166 return_type,
167 arg_schema,
168 runtime,
169 arrow_convert,
170 chunk_size,
171 }
172 .boxed())
173}