risingwave_expr/aggregate/
user_defined.rs1use std::sync::Arc;
16
17use anyhow::Context;
18use risingwave_common::array::Op;
19use risingwave_common::array::arrow::arrow_array_udf::ArrayRef;
20use risingwave_common::array::arrow::arrow_schema_udf::{Field, Fields, Schema, SchemaRef};
21use risingwave_common::array::arrow::{UdfArrowConvert, UdfFromArrow, UdfToArrow};
22use risingwave_common::bitmap::Bitmap;
23use risingwave_pb::expr::PbUserDefinedFunctionMetadata;
24
25use super::*;
26use crate::sig::{BuildOptions, UdfImpl, UdfKind};
27
28#[derive(Debug)]
29pub struct UserDefinedAggregateFunction {
30 arg_schema: SchemaRef,
31 return_type: DataType,
32 return_field: Field,
33 state_field: Field,
34 runtime: Box<dyn UdfImpl>,
35}
36
37#[async_trait::async_trait]
38impl AggregateFunction for UserDefinedAggregateFunction {
39 fn return_type(&self) -> DataType {
40 self.return_type.clone()
41 }
42
43 fn create_state(&self) -> Result<AggregateState> {
45 futures::executor::block_on(async {
47 let state = self.runtime.call_agg_create_state().await?;
48 Ok(AggregateState::Any(Box::new(State(state))))
49 })
50 }
51
52 async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
54 let state = &mut state.downcast_mut::<State>().0;
55 let ops = input
56 .visibility()
57 .iter_ones()
58 .map(|i| Some(matches!(input.ops()[i], Op::Delete | Op::UpdateDelete)))
59 .collect();
60 let arrow_input = UdfArrowConvert::default()
62 .to_record_batch(self.arg_schema.clone(), input.data_chunk())?;
63 let new_state = self
64 .runtime
65 .call_agg_accumulate_or_retract(state, &ops, &arrow_input)
66 .await?;
67 *state = new_state;
68 Ok(())
69 }
70
71 async fn update_range(
73 &self,
74 state: &mut AggregateState,
75 input: &StreamChunk,
76 range: Range<usize>,
77 ) -> Result<()> {
78 let vis = input.visibility() & Bitmap::from_range(input.capacity(), range);
80 let input = input.clone_with_vis(vis);
81 self.update(state, &input).await
82 }
83
84 async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
86 let state = &state.downcast_ref::<State>().0;
87 let arrow_output = self.runtime.call_agg_finish(state).await?;
88 let output = UdfArrowConvert::default().from_array(&self.return_field, &arrow_output)?;
89 Ok(output.datum_at(0))
90 }
91
92 fn encode_state(&self, state: &AggregateState) -> Result<Datum> {
94 let state = &state.downcast_ref::<State>().0;
95 let state = UdfArrowConvert::default().from_array(&self.state_field, state)?;
96 Ok(state.datum_at(0))
97 }
98
99 fn decode_state(&self, datum: Datum) -> Result<AggregateState> {
101 let array = {
102 let mut builder = DataType::Bytea.create_array_builder(1);
103 builder.append(datum);
104 builder.finish()
105 };
106 let state = UdfArrowConvert::default().to_array(self.state_field.data_type(), &array)?;
107 Ok(AggregateState::Any(Box::new(State(state))))
108 }
109}
110
111#[derive(Debug)]
115struct State(ArrayRef);
116
117impl EstimateSize for State {
118 fn estimated_heap_size(&self) -> usize {
119 self.0.get_array_memory_size()
120 }
121}
122
123impl AggStateDyn for State {}
124
125pub fn new_user_defined(
127 return_type: &DataType,
128 udf: &PbUserDefinedFunctionMetadata,
129) -> Result<BoxedAggregateFunction> {
130 let arg_types = udf.arg_types.iter().map(|t| t.into()).collect::<Vec<_>>();
131 let language = udf.language.as_str();
132 let runtime = udf.runtime.as_deref();
133 let link = udf.link.as_deref();
134
135 let name_in_runtime = udf
136 .name_in_runtime()
137 .expect("SQL UDF won't get here, other UDFs must have `name_in_runtime`");
138
139 let build_fn = crate::sig::find_udf_impl(language, runtime, link)?.build_fn;
140 let runtime = build_fn(BuildOptions {
141 kind: UdfKind::Aggregate,
142 body: udf.body.as_deref(),
143 compressed_binary: udf.compressed_binary.as_deref(),
144 link: udf.link.as_deref(),
145 name_in_runtime,
146 arg_names: &udf.arg_names,
147 arg_types: &arg_types,
148 return_type,
149 always_retry_on_network_error: false,
150 language,
151 is_async: None,
152 is_batched: None,
153 })
154 .context("failed to build UDF runtime")?;
155
156 let arrow_convert = UdfArrowConvert::default();
159 let arg_schema = Arc::new(Schema::new(
160 arg_types
161 .iter()
162 .map(|t| arrow_convert.to_arrow_field("", t))
163 .try_collect::<_, Fields, _>()?,
164 ));
165
166 Ok(Box::new(UserDefinedAggregateFunction {
167 return_field: arrow_convert.to_arrow_field("", return_type)?,
168 state_field: Field::new(
169 "state",
170 risingwave_common::array::arrow::arrow_schema_udf::DataType::Binary,
171 true,
172 ),
173 return_type: return_type.clone(),
174 arg_schema,
175 runtime,
176 }))
177}