risingwave_expr/aggregate/user_defined.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use anyhow::Context;
use risingwave_common::array::Op;
use risingwave_common::array::arrow::arrow_array_udf::ArrayRef;
use risingwave_common::array::arrow::arrow_schema_udf::{Field, Fields, Schema, SchemaRef};
use risingwave_common::array::arrow::{UdfArrowConvert, UdfFromArrow, UdfToArrow};
use risingwave_common::bitmap::Bitmap;
use risingwave_pb::expr::PbUserDefinedFunctionMetadata;

use super::*;
use crate::sig::{BuildOptions, UdfImpl, UdfKind};

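/// An aggregate function backed by a user-defined function (UDAF) runtime.
///
/// `arg_schema` is the Arrow schema of the input arguments, `return_field` and
/// `state_field` are the Arrow representations of the output value and the
/// intermediate state, and `runtime` is the UDF runtime that executes the calls.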
#[derive(Debug)]
pub struct UserDefinedAggregateFunction {
    arg_schema: SchemaRef,
    return_type: DataType,
    return_field: Field,
    state_field: Field,
    runtime: Box<dyn UdfImpl>,
}

#[async_trait::async_trait]
impl AggregateFunction for UserDefinedAggregateFunction {
    fn return_type(&self) -> DataType {
        self.return_type.clone()
    }

    /// Creates an initial state of the aggregate function.
    fn create_state(&self) -> Result<AggregateState> {
        // FIXME(eric): This is bad. Let's make `create_state` async if someday we allow async UDAF
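        // `create_state` is synchronous in the `AggregateFunction` trait, so we have to
        // block on the async runtime call here.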
        futures::executor::block_on(async {
            let state = self.runtime.call_agg_create_state().await?;
            Ok(AggregateState::Any(Box::new(State(state))))
        })
    }

    /// Update the state with multiple rows.
    async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
        let state = &mut state.downcast_mut::<State>().0;
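        // For each visible row, record whether it is a retraction (`Delete` / `UpdateDelete`).
        // The runtime uses these flags to decide whether to accumulate or retract the row.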
        let ops = input
            .visibility()
            .iter_ones()
            .map(|i| Some(matches!(input.ops()[i], Op::Delete | Op::UpdateDelete)))
            .collect();
        // this will drop invisible rows
        let arrow_input = UdfArrowConvert::default()
            .to_record_batch(self.arg_schema.clone(), input.data_chunk())?;
        let new_state = self
            .runtime
            .call_agg_accumulate_or_retract(state, &ops, &arrow_input)
            .await?;
        *state = new_state;
        Ok(())
    }

    /// Update the state with a range of rows.
    async fn update_range(
        &self,
        state: &mut AggregateState,
        input: &StreamChunk,
        range: Range<usize>,
    ) -> Result<()> {
        // XXX(runji): this may be inefficient
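        // Restrict visibility to the given range by intersecting it with a range bitmap,
        // then delegate to `update`, which drops the now-invisible rows.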
        let vis = input.visibility() & Bitmap::from_range(input.capacity(), range);
        let input = input.clone_with_vis(vis);
        self.update(state, &input).await
    }

    /// Get the aggregate result from the state.
    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
        let state = &state.downcast_ref::<State>().0;
        let arrow_output = self.runtime.call_agg_finish(state).await?;
        let output = UdfArrowConvert::default().from_array(&self.return_field, &arrow_output)?;
        Ok(output.datum_at(0))
    }

    /// Encode the state into a datum that can be stored in the state table.
    fn encode_state(&self, state: &AggregateState) -> Result<Datum> {
        let state = &state.downcast_ref::<State>().0;
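        // The state is a single-element Arrow array; convert it back to a RisingWave
        // array and take the sole datum.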
        let state = UdfArrowConvert::default().from_array(&self.state_field, state)?;
        Ok(state.datum_at(0))
    }

    /// Decode the state from a datum in the state table.
    fn decode_state(&self, datum: Datum) -> Result<AggregateState> {
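        // Wrap the datum in a one-element `Bytea` array, then convert it to the Arrow
        // representation expected by the UDF runtime.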
        let array = {
            let mut builder = DataType::Bytea.create_array_builder(1);
            builder.append(datum);
            builder.finish()
        };
        let state = UdfArrowConvert::default().to_array(self.state_field.data_type(), &array)?;
        Ok(AggregateState::Any(Box::new(State(state))))
    }
}

// In arrow-udf, aggregate state is represented as an `ArrayRef`.
// To avoid unnecessary conversion between `ArrayRef` and `Datum`,
// we store `ArrayRef` directly in our `AggregateState`.
#[derive(Debug)]
struct State(ArrayRef);

impl EstimateSize for State {
    fn estimated_heap_size(&self) -> usize {
        self.0.get_array_memory_size()
    }
}

impl AggStateDyn for State {}

/// Create a new user-defined aggregate function.
pub fn new_user_defined(
    return_type: &DataType,
    udf: &PbUserDefinedFunctionMetadata,
) -> Result<BoxedAggregateFunction> {
    let arg_types = udf.arg_types.iter().map(|t| t.into()).collect::<Vec<_>>();
    let language = udf.language.as_str();
    let runtime = udf.runtime.as_deref();
    let link = udf.link.as_deref();

    let name_in_runtime = udf
        .name_in_runtime()
        .expect("SQL UDF won't get here, other UDFs must have `name_in_runtime`");

    let build_fn = crate::sig::find_udf_impl(language, runtime, link)?.build_fn;
    let runtime = build_fn(BuildOptions {
        kind: UdfKind::Aggregate,
        body: udf.body.as_deref(),
        compressed_binary: udf.compressed_binary.as_deref(),
        link: udf.link.as_deref(),
        name_in_runtime,
        arg_names: &udf.arg_names,
        arg_types: &arg_types,
        return_type,
        always_retry_on_network_error: false,
        language,
        is_async: None,
        is_batched: None,
    })
    .context("failed to build UDF runtime")?;

    // legacy UDF runtimes do not support aggregate functions,
    // so we can assume that the runtime is not legacy
    let arrow_convert = UdfArrowConvert::default();
    let arg_schema = Arc::new(Schema::new(
        arg_types
            .iter()
            .map(|t| arrow_convert.to_arrow_field("", t))
            .try_collect::<_, Fields, _>()?,
    ));

    Ok(Box::new(UserDefinedAggregateFunction {
        return_field: arrow_convert.to_arrow_field("", return_type)?,
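        // The intermediate state is stored in the state table as a single `Bytea` value,
        // so it is represented as a nullable `Binary` field on the Arrow side.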
        state_field: Field::new(
            "state",
            risingwave_common::array::arrow::arrow_schema_udf::DataType::Binary,
            true,
        ),
        return_type: return_type.clone(),
        arg_schema,
        runtime,
    }))
}