risingwave_expr/aggregate/
mod.rs1use std::fmt::Debug;
16use std::ops::Range;
17
18use anyhow::anyhow;
19use downcast_rs::{Downcast, impl_downcast};
20use itertools::Itertools;
21use risingwave_common::array::StreamChunk;
22use risingwave_common::types::{DataType, Datum};
23use risingwave_common_estimate_size::EstimateSize;
24
25use crate::expr::build_from_prost;
26use crate::sig::FuncBuilder;
27use crate::{ExprError, Result};
28
29mod def;
31mod scalar_wrapper;
32mod user_defined;
34
35pub use self::def::*;
36
37#[async_trait::async_trait]
39pub trait AggregateFunction: Send + Sync + 'static {
40 fn return_type(&self) -> DataType;
42
43 fn create_state(&self) -> Result<AggregateState> {
45 Ok(AggregateState::Datum(None))
46 }
47
48 async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()>;
50
51 async fn update_range(
53 &self,
54 state: &mut AggregateState,
55 input: &StreamChunk,
56 range: Range<usize>,
57 ) -> Result<()>;
58
59 async fn get_result(&self, state: &AggregateState) -> Result<Datum>;
61
62 fn encode_state(&self, state: &AggregateState) -> Result<Datum> {
64 match state {
65 AggregateState::Datum(d) => Ok(d.clone()),
66 AggregateState::Any(_) => Err(ExprError::Internal(anyhow!("cannot encode state"))),
67 }
68 }
69
70 fn decode_state(&self, datum: Datum) -> Result<AggregateState> {
72 Ok(AggregateState::Datum(datum))
73 }
74}
75
76#[derive(Debug)]
78pub enum AggregateState {
79 Datum(Datum),
81 Any(Box<dyn AggStateDyn>),
83}
84
85impl EstimateSize for AggregateState {
86 fn estimated_heap_size(&self) -> usize {
87 match self {
88 Self::Datum(d) => d.estimated_heap_size(),
89 Self::Any(a) => std::mem::size_of_val(&**a) + a.estimated_heap_size(),
90 }
91 }
92}
93
94pub trait AggStateDyn: Send + Sync + Debug + EstimateSize + Downcast {}
95
96impl_downcast!(AggStateDyn);
97
98impl AggregateState {
99 pub fn as_datum(&self) -> &Datum {
100 match self {
101 Self::Datum(d) => d,
102 Self::Any(_) => panic!("not datum"),
103 }
104 }
105
106 pub fn as_datum_mut(&mut self) -> &mut Datum {
107 match self {
108 Self::Datum(d) => d,
109 Self::Any(_) => panic!("not datum"),
110 }
111 }
112
113 pub fn downcast_ref<T: AggStateDyn>(&self) -> &T {
114 match self {
115 Self::Datum(_) => panic!("cannot downcast scalar"),
116 Self::Any(a) => a.downcast_ref::<T>().expect("cannot downcast"),
117 }
118 }
119
120 pub fn downcast_mut<T: AggStateDyn>(&mut self) -> &mut T {
121 match self {
122 Self::Datum(_) => panic!("cannot downcast scalar"),
123 Self::Any(a) => a.downcast_mut::<T>().expect("cannot downcast"),
124 }
125 }
126}
127
128pub type BoxedAggregateFunction = Box<dyn AggregateFunction>;
129
130pub fn build_append_only(agg: &AggCall) -> Result<BoxedAggregateFunction> {
132 build(agg, true)
133}
134
135pub fn build_retractable(agg: &AggCall) -> Result<BoxedAggregateFunction> {
137 build(agg, false)
138}
139
140pub fn build(agg: &AggCall, prefer_append_only: bool) -> Result<BoxedAggregateFunction> {
148 let kind = match &agg.agg_type {
150 AggType::UserDefined(udf) => {
151 return user_defined::new_user_defined(&agg.return_type, udf);
152 }
153 AggType::WrapScalar(scalar) => {
154 return Ok(Box::new(scalar_wrapper::ScalarWrapper::new(
155 agg.args.arg_types()[0].clone(),
156 build_from_prost(scalar)?,
157 )));
158 }
159 AggType::Builtin(kind) => kind,
160 };
161
162 let sig = crate::sig::FUNCTION_REGISTRY.get(*kind, agg.args.arg_types(), &agg.return_type)?;
164
165 if let FuncBuilder::Aggregate {
166 append_only: Some(f),
167 ..
168 } = sig.build
169 && prefer_append_only
170 {
171 return f(agg);
172 }
173 sig.build_aggregate(agg)
174}