risingwave_expr/table_function/
user_defined.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use anyhow::Context;
18use risingwave_common::array::I32Array;
19use risingwave_common::array::arrow::arrow_schema_udf::{Fields, Schema, SchemaRef};
20use risingwave_common::array::arrow::{UdfArrowConvert, UdfFromArrow, UdfToArrow};
21use risingwave_common::bail;
22
23use super::*;
24use crate::sig::{BuildOptions, UdfImpl, UdfKind};
25
26#[derive(Debug)]
27pub struct UserDefinedTableFunction {
28    children: Vec<BoxedExpression>,
29    arg_schema: SchemaRef,
30    return_type: DataType,
31    runtime: Box<dyn UdfImpl>,
32    arrow_convert: UdfArrowConvert,
33    #[allow(dead_code)]
34    chunk_size: usize,
35}
36
37#[async_trait::async_trait]
38impl TableFunction for UserDefinedTableFunction {
39    fn return_type(&self) -> DataType {
40        self.return_type.clone()
41    }
42
43    async fn eval<'a>(&'a self, input: &'a DataChunk) -> BoxStream<'a, Result<DataChunk>> {
44        self.eval_inner(input)
45    }
46}
47
48impl UserDefinedTableFunction {
49    #[try_stream(boxed, ok = DataChunk, error = ExprError)]
50    async fn eval_inner<'a>(&'a self, input: &'a DataChunk) {
51        // evaluate children expressions
52        let mut columns = Vec::with_capacity(self.children.len());
53        for c in &self.children {
54            let val = c.eval(input).await?;
55            columns.push(val);
56        }
57        let direct_input = DataChunk::new(columns, input.visibility().clone());
58
59        // compact the input chunk and record the row mapping
60        let visible_rows = direct_input.visibility().iter_ones().collect::<Vec<_>>();
61        // this will drop invisible rows
62        let arrow_input = self
63            .arrow_convert
64            .to_record_batch(self.arg_schema.clone(), &direct_input)?;
65
66        // call UDTF
67        #[for_await]
68        for res in self.runtime.call_table_function(&arrow_input).await? {
69            let output = self.arrow_convert.from_record_batch(&res?)?;
70            self.check_output(&output)?;
71
72            // we send the compacted input to UDF, so we need to map the row indices back to the
73            // original input
74            let origin_indices = output
75                .column_at(0)
76                .as_int32()
77                .raw_iter()
78                // we have checked all indices are non-negative
79                .map(|idx| visible_rows[idx as usize] as i32)
80                .collect::<I32Array>();
81
82            let output = DataChunk::new(
83                vec![origin_indices.into_ref(), output.column_at(1).clone()],
84                output.visibility().clone(),
85            );
86            yield output;
87        }
88    }
89
90    /// Check if the output chunk is valid.
91    fn check_output(&self, output: &DataChunk) -> Result<()> {
92        if output.columns().len() != 2 {
93            bail!(
94                "UDF returned {} columns, but expected 2",
95                output.columns().len()
96            );
97        }
98        if output.column_at(0).data_type() != DataType::Int32 {
99            bail!(
100                "UDF returned {:?} at column 0, but expected {:?}",
101                output.column_at(0).data_type(),
102                DataType::Int32,
103            );
104        }
105        if output.column_at(0).as_int32().raw_iter().any(|i| i < 0) {
106            bail!("UDF returned negative row index");
107        }
108        if !output
109            .column_at(1)
110            .data_type()
111            .equals_datatype(&self.return_type)
112        {
113            bail!(
114                "UDF returned {:?} at column 1, but expected {:?}",
115                output.column_at(1).data_type(),
116                &self.return_type,
117            );
118        }
119        Ok(())
120    }
121}
122
123pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<BoxedTableFunction> {
124    let udf = prost.get_udf()?;
125
126    let arg_types = udf.arg_types.iter().map(|t| t.into()).collect::<Vec<_>>();
127    let return_type = DataType::from(prost.get_return_type()?);
128
129    let language = udf.language.as_str();
130    let runtime = udf.runtime.as_deref();
131    let link = udf.link.as_deref();
132
133    let name_in_runtime = udf
134        .name_in_runtime()
135        .expect("SQL UDF won't get here, other UDFs must have `name_in_runtime`");
136
137    let build_fn = crate::sig::find_udf_impl(language, runtime, link)?.build_fn;
138    let runtime = build_fn(BuildOptions {
139        kind: UdfKind::Table,
140        body: udf.body.as_deref(),
141        compressed_binary: udf.compressed_binary.as_deref(),
142        link: udf.link.as_deref(),
143        name_in_runtime,
144        arg_names: &udf.arg_names,
145        arg_types: &arg_types,
146        return_type: &return_type,
147        always_retry_on_network_error: false,
148        language,
149        is_async: None,
150        is_batched: None,
151    })
152    .context("failed to build UDF runtime")?;
153
154    let arrow_convert = UdfArrowConvert {
155        legacy: runtime.is_legacy(),
156    };
157    let arg_schema = Arc::new(Schema::new(
158        arg_types
159            .iter()
160            .map(|t| arrow_convert.to_arrow_field("", t))
161            .try_collect::<Fields>()?,
162    ));
163
164    Ok(UserDefinedTableFunction {
165        children: prost.args.iter().map(expr_build_from_prost).try_collect()?,
166        return_type,
167        arg_schema,
168        runtime,
169        arrow_convert,
170        chunk_size,
171    }
172    .boxed())
173}