// risingwave_expr/aggregate/user_defined.rs
use std::sync::Arc;
use anyhow::Context;
use risingwave_common::array::arrow::arrow_array_udf::ArrayRef;
use risingwave_common::array::arrow::arrow_schema_udf::{Field, Fields, Schema, SchemaRef};
use risingwave_common::array::arrow::{UdfArrowConvert, UdfFromArrow, UdfToArrow};
use risingwave_common::array::Op;
use risingwave_common::bitmap::Bitmap;
use risingwave_pb::expr::PbUserDefinedFunctionMetadata;
use super::*;
use crate::sig::{UdfImpl, UdfKind, UdfOptions};
/// An aggregate function whose logic is delegated to a user-defined function runtime.
///
/// Instances are built by `new_user_defined` from a
/// `PbUserDefinedFunctionMetadata` and a return type.
#[derive(Debug)]
pub struct UserDefinedAggregateFunction {
/// Arrow schema of the arguments; used to convert input chunks into record batches.
arg_schema: SchemaRef,
/// Return type reported by `AggregateFunction::return_type`.
return_type: DataType,
/// Arrow field used to convert the runtime's finished output back into an array.
return_field: Field,
/// Arrow field describing the state array when encoding and decoding the state.
state_field: Field,
/// The UDF implementation that creates, updates and finishes the aggregate state.
runtime: Box<dyn UdfImpl>,
}
// `AggregateFunction` implementation that forwards every step of the aggregation
// (state creation, accumulate/retract, finish) to the wrapped UDF runtime and
// converts between chunk/datum values and Arrow arrays at the boundary.
#[async_trait::async_trait]
impl AggregateFunction for UserDefinedAggregateFunction {
    /// Returns a copy of the configured return type.
    fn return_type(&self) -> DataType {
        self.return_type.clone()
    }

    /// Asks the runtime for a fresh state and wraps it in `AggregateState::Any`.
    fn create_state(&self) -> Result<AggregateState> {
        let initial = self.runtime.call_agg_create_state()?;
        let boxed = Box::new(State(initial));
        Ok(AggregateState::Any(boxed))
    }

    /// Feeds `input` to the runtime and replaces the stored state with the one it returns.
    async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
        let slot = &mut state.downcast_mut::<State>().0;

        // One flag per index yielded by `visibility().iter_ones()`:
        // `true` for `Delete` / `UpdateDelete`, `false` for every other op.
        let ops = input
            .visibility()
            .iter_ones()
            .map(|row| {
                let is_retract = matches!(input.ops()[row], Op::Delete | Op::UpdateDelete);
                Some(is_retract)
            })
            .collect();

        let schema = self.arg_schema.clone();
        let batch = UdfArrowConvert::default().to_record_batch(schema, input.data_chunk())?;

        *slot = self
            .runtime
            .call_agg_accumulate_or_retract(slot, &ops, &batch)?;
        Ok(())
    }

    /// Same as `update`, but with the chunk's visibility additionally restricted to `range`.
    async fn update_range(
        &self,
        state: &mut AggregateState,
        input: &StreamChunk,
        range: Range<usize>,
    ) -> Result<()> {
        let range_mask = Bitmap::from_range(input.capacity(), range);
        let masked = input.clone_with_vis(input.visibility() & range_mask);
        self.update(state, &masked).await
    }

    /// Lets the runtime finish the aggregation and returns element 0 of the converted output.
    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
        let finished = self
            .runtime
            .call_agg_finish(&state.downcast_ref::<State>().0)?;
        let converted = UdfArrowConvert::default().from_array(&self.return_field, &finished)?;
        Ok(converted.datum_at(0))
    }

    /// Converts the Arrow state array via `state_field` and returns its element 0 as a datum.
    fn encode_state(&self, state: &AggregateState) -> Result<Datum> {
        let raw = &state.downcast_ref::<State>().0;
        let encoded = UdfArrowConvert::default().from_array(&self.state_field, raw)?;
        Ok(encoded.datum_at(0))
    }

    /// Rebuilds a state from `datum` by appending it to a `Bytea` array and converting
    /// that array to Arrow using the data type of `state_field`.
    fn decode_state(&self, datum: Datum) -> Result<AggregateState> {
        let mut builder = DataType::Bytea.create_array_builder(1);
        builder.append(datum);
        let bytea_array = builder.finish();

        let arrow_state =
            UdfArrowConvert::default().to_array(self.state_field.data_type(), &bytea_array)?;
        Ok(AggregateState::Any(Box::new(State(arrow_state))))
    }
}
/// Aggregate state of a user-defined aggregate function, held as an Arrow array.
///
/// The wrapped array comes either from the UDF runtime or from decoding a stored datum.
#[derive(Debug)]
struct State(ArrayRef);
impl EstimateSize for State {
    /// Reports the value of `get_array_memory_size` for the wrapped Arrow array.
    fn estimated_heap_size(&self) -> usize {
        let Self(array) = self;
        array.get_array_memory_size()
    }
}
// Empty impl with no overridden items.
// NOTE(review): presumably this is what allows `State` to be boxed into
// `AggregateState::Any` and recovered with `downcast_ref` / `downcast_mut` — confirm
// against the `AggStateDyn` definition.
impl AggStateDyn for State {}
/// Builds a boxed [`UserDefinedAggregateFunction`] from UDF metadata.
///
/// The UDF implementation is looked up by language, runtime and link, and its
/// `build_fn` is invoked with `UdfKind::Aggregate`. Arrow fields for the
/// arguments and the return value are derived from `udf.arg_types` and
/// `return_type`; the state field is a nullable `Binary` field named `"state"`.
///
/// # Errors
///
/// Returns an error if the identifier is missing, no matching UDF
/// implementation is found, the runtime fails to build, or a type cannot be
/// converted to an Arrow field.
pub fn new_user_defined(
    return_type: &DataType,
    udf: &PbUserDefinedFunctionMetadata,
) -> Result<BoxedAggregateFunction> {
    let identifier = udf.get_identifier()?;
    let link = udf.link.as_deref();

    let udf_impl =
        crate::sig::find_udf_impl(udf.language.as_str(), udf.runtime.as_deref(), link)?;
    let build_fn = udf_impl.build_fn;

    let options = UdfOptions {
        kind: UdfKind::Aggregate,
        body: udf.body.as_deref(),
        compressed_binary: udf.compressed_binary.as_deref(),
        link,
        identifier,
        arg_names: &udf.arg_names,
        return_type,
        always_retry_on_network_error: false,
    };
    let udf_runtime = build_fn(options).context("failed to build UDF runtime")?;

    let converter = UdfArrowConvert::default();

    // Argument fields are created with empty names.
    let arg_fields = udf
        .arg_types
        .iter()
        .map(|pb_type| converter.to_arrow_field("", &DataType::from(pb_type)))
        .try_collect::<_, Fields, _>()?;
    let arg_schema = Arc::new(Schema::new(arg_fields));

    let return_field = converter.to_arrow_field("", return_type)?;
    let state_field = Field::new(
        "state",
        risingwave_common::array::arrow::arrow_schema_udf::DataType::Binary,
        true,
    );

    let function = UserDefinedAggregateFunction {
        arg_schema,
        return_type: return_type.clone(),
        return_field,
        state_field,
        runtime: udf_runtime,
    };
    Ok(Box::new(function))
}