// risingwave_expr/aggregate/mod.rs
use std::fmt::Debug;
use std::ops::Range;
use anyhow::anyhow;
use downcast_rs::{impl_downcast, Downcast};
use itertools::Itertools;
use risingwave_common::array::StreamChunk;
use risingwave_common::types::{DataType, Datum};
use risingwave_common_estimate_size::EstimateSize;
use crate::expr::build_from_prost;
use crate::sig::FuncBuilder;
use crate::{ExprError, Result};
mod def;
mod scalar_wrapper;
mod user_defined;
pub use self::def::*;
/// An aggregate function evaluated over [`StreamChunk`]s.
///
/// Every method takes `&self`; intermediate data is not kept in the function object but in an
/// [`AggregateState`] obtained from [`AggregateFunction::create_state`] and passed to each call.
#[async_trait::async_trait]
pub trait AggregateFunction: Send + Sync + 'static {
    /// Data type of the value produced by [`AggregateFunction::get_result`].
    fn return_type(&self) -> DataType;

    /// Creates a fresh state.
    ///
    /// The default implementation returns a scalar state holding `None`.
    fn create_state(&self) -> Result<AggregateState> {
        let empty: Datum = None;
        Ok(AggregateState::Datum(empty))
    }

    /// Updates `state` with the rows of `input`.
    async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()>;

    /// Like [`AggregateFunction::update`], but restricted to `range`
    /// (presumably row indices into `input` — confirm against implementors).
    async fn update_range(
        &self,
        state: &mut AggregateState,
        input: &StreamChunk,
        range: Range<usize>,
    ) -> Result<()>;

    /// Computes the aggregate's output from `state`.
    async fn get_result(&self, state: &AggregateState) -> Result<Datum>;

    /// Converts `state` into a single [`Datum`].
    ///
    /// The default implementation clones a scalar state and returns an internal error for an
    /// [`AggregateState::Any`] state, which it cannot serialize generically.
    fn encode_state(&self, state: &AggregateState) -> Result<Datum> {
        match state {
            AggregateState::Any(_) => Err(ExprError::Internal(anyhow!("cannot encode state"))),
            AggregateState::Datum(datum) => Ok(datum.clone()),
        }
    }

    /// Rebuilds a state from a [`Datum`] produced by [`AggregateFunction::encode_state`].
    ///
    /// The default implementation wraps `datum` as a scalar state.
    fn decode_state(&self, datum: Datum) -> Result<AggregateState> {
        Ok(datum).map(AggregateState::Datum)
    }
}
/// Intermediate state of an [`AggregateFunction`].
#[derive(Debug)]
pub enum AggregateState {
    /// A state that fits in a single scalar value.
    Datum(Datum),
    /// An arbitrary function-specific state object. Recover its concrete type with
    /// [`AggregateState::downcast_ref`] / [`AggregateState::downcast_mut`].
    Any(Box<dyn AggStateDyn>),
}
impl EstimateSize for AggregateState {
fn estimated_heap_size(&self) -> usize {
match self {
Self::Datum(d) => d.estimated_heap_size(),
Self::Any(a) => std::mem::size_of_val(&**a) + a.estimated_heap_size(),
}
}
}
/// Bound for custom aggregate states stored as trait objects in [`AggregateState::Any`].
///
/// Implementors must be `Send + Sync + Debug`, report their memory through [`EstimateSize`],
/// and support downcasting (via `downcast_rs`) so the concrete type can be recovered.
pub trait AggStateDyn: Send + Sync + Debug + EstimateSize + Downcast {}
// Generates the downcasting helpers (e.g. `downcast_ref` / `downcast_mut`) on `dyn AggStateDyn`.
impl_downcast!(AggStateDyn);
impl AggregateState {
    /// Returns a reference to the scalar state.
    ///
    /// # Panics
    ///
    /// Panics if the state is [`AggregateState::Any`]. `#[track_caller]` makes the panic report
    /// the caller's location, which is where the misuse actually happened.
    #[track_caller]
    pub fn as_datum(&self) -> &Datum {
        match self {
            Self::Datum(d) => d,
            Self::Any(_) => panic!("not datum"),
        }
    }

    /// Returns a mutable reference to the scalar state.
    ///
    /// # Panics
    ///
    /// Panics if the state is [`AggregateState::Any`].
    #[track_caller]
    pub fn as_datum_mut(&mut self) -> &mut Datum {
        match self {
            Self::Datum(d) => d,
            Self::Any(_) => panic!("not datum"),
        }
    }

    /// Returns the boxed state downcast to its concrete type `T`.
    ///
    /// # Panics
    ///
    /// Panics if the state is [`AggregateState::Datum`], or if the boxed value is not a `T`.
    #[track_caller]
    pub fn downcast_ref<T: AggStateDyn>(&self) -> &T {
        match self {
            Self::Datum(_) => panic!("cannot downcast scalar"),
            Self::Any(a) => a.downcast_ref::<T>().expect("cannot downcast"),
        }
    }

    /// Returns the boxed state mutably downcast to its concrete type `T`.
    ///
    /// # Panics
    ///
    /// Panics if the state is [`AggregateState::Datum`], or if the boxed value is not a `T`.
    #[track_caller]
    pub fn downcast_mut<T: AggStateDyn>(&mut self) -> &mut T {
        match self {
            Self::Datum(_) => panic!("cannot downcast scalar"),
            Self::Any(a) => a.downcast_mut::<T>().expect("cannot downcast"),
        }
    }
}
/// An owned, dynamically dispatched [`AggregateFunction`].
pub type BoxedAggregateFunction = Box<dyn AggregateFunction>;
/// Builds an aggregate function, preferring an append-only implementation when the matching
/// builtin signature provides one. Equivalent to `build(agg, true)`.
pub fn build_append_only(agg: &AggCall) -> Result<BoxedAggregateFunction> {
    let prefer_append_only = true;
    build(agg, prefer_append_only)
}
/// Builds an aggregate function without preferring an append-only implementation
/// (presumably one able to handle retractions). Equivalent to `build(agg, false)`.
pub fn build_retractable(agg: &AggCall) -> Result<BoxedAggregateFunction> {
    let prefer_append_only = false;
    build(agg, prefer_append_only)
}
pub fn build(agg: &AggCall, prefer_append_only: bool) -> Result<BoxedAggregateFunction> {
let kind = match &agg.agg_type {
AggType::UserDefined(udf) => {
return user_defined::new_user_defined(&agg.return_type, udf);
}
AggType::WrapScalar(scalar) => {
return Ok(Box::new(scalar_wrapper::ScalarWrapper::new(
agg.args.arg_types()[0].clone(),
build_from_prost(scalar)?,
)));
}
AggType::Builtin(kind) => kind,
};
let sig = crate::sig::FUNCTION_REGISTRY.get(*kind, agg.args.arg_types(), &agg.return_type)?;
if let FuncBuilder::Aggregate {
append_only: Some(f),
..
} = sig.build
&& prefer_append_only
{
return f(agg);
}
sig.build_aggregate(agg)
}