risingwave_expr/aggregate/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::ops::Range;
17
18use anyhow::anyhow;
19use downcast_rs::{Downcast, impl_downcast};
20use itertools::Itertools;
21use risingwave_common::array::StreamChunk;
22use risingwave_common::types::{DataType, Datum};
23use risingwave_common_estimate_size::EstimateSize;
24
25use crate::expr::build_from_prost;
26use crate::sig::FuncBuilder;
27use crate::{ExprError, Result};
28
29// aggregate definition
30mod def;
31mod scalar_wrapper;
32// user defined aggregate function
33mod user_defined;
34
35pub use self::def::*;
36
37/// A trait over all aggregate functions.
38#[async_trait::async_trait]
39pub trait AggregateFunction: Send + Sync + 'static {
40    /// Returns the return type of the aggregate function.
41    fn return_type(&self) -> DataType;
42
43    /// Creates an initial state of the aggregate function.
44    fn create_state(&self) -> Result<AggregateState> {
45        Ok(AggregateState::Datum(None))
46    }
47
48    /// Update the state with multiple rows.
49    async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()>;
50
51    /// Update the state with a range of rows.
52    async fn update_range(
53        &self,
54        state: &mut AggregateState,
55        input: &StreamChunk,
56        range: Range<usize>,
57    ) -> Result<()>;
58
59    /// Get aggregate result from the state.
60    async fn get_result(&self, state: &AggregateState) -> Result<Datum>;
61
62    /// Encode the state into a datum that can be stored in state table.
63    fn encode_state(&self, state: &AggregateState) -> Result<Datum> {
64        match state {
65            AggregateState::Datum(d) => Ok(d.clone()),
66            AggregateState::Any(_) => Err(ExprError::Internal(anyhow!("cannot encode state"))),
67        }
68    }
69
70    /// Decode the state from a datum in state table.
71    fn decode_state(&self, datum: Datum) -> Result<AggregateState> {
72        Ok(AggregateState::Datum(datum))
73    }
74}
75
76/// Intermediate state of an aggregate function.
77#[derive(Debug)]
78pub enum AggregateState {
79    /// A scalar value.
80    Datum(Datum),
81    /// A state of any type.
82    Any(Box<dyn AggStateDyn>),
83}
84
85impl EstimateSize for AggregateState {
86    fn estimated_heap_size(&self) -> usize {
87        match self {
88            Self::Datum(d) => d.estimated_heap_size(),
89            Self::Any(a) => std::mem::size_of_val(&**a) + a.estimated_heap_size(),
90        }
91    }
92}
93
94pub trait AggStateDyn: Send + Sync + Debug + EstimateSize + Downcast {}
95
96impl_downcast!(AggStateDyn);
97
98impl AggregateState {
99    pub fn as_datum(&self) -> &Datum {
100        match self {
101            Self::Datum(d) => d,
102            Self::Any(_) => panic!("not datum"),
103        }
104    }
105
106    pub fn as_datum_mut(&mut self) -> &mut Datum {
107        match self {
108            Self::Datum(d) => d,
109            Self::Any(_) => panic!("not datum"),
110        }
111    }
112
113    pub fn downcast_ref<T: AggStateDyn>(&self) -> &T {
114        match self {
115            Self::Datum(_) => panic!("cannot downcast scalar"),
116            Self::Any(a) => a.downcast_ref::<T>().expect("cannot downcast"),
117        }
118    }
119
120    pub fn downcast_mut<T: AggStateDyn>(&mut self) -> &mut T {
121        match self {
122            Self::Datum(_) => panic!("cannot downcast scalar"),
123            Self::Any(a) => a.downcast_mut::<T>().expect("cannot downcast"),
124        }
125    }
126}
127
128pub type BoxedAggregateFunction = Box<dyn AggregateFunction>;
129
130/// Build an append-only `Aggregator` from `AggCall`.
131pub fn build_append_only(agg: &AggCall) -> Result<BoxedAggregateFunction> {
132    build(agg, true)
133}
134
135/// Build a retractable `Aggregator` from `AggCall`.
136pub fn build_retractable(agg: &AggCall) -> Result<BoxedAggregateFunction> {
137    build(agg, false)
138}
139
140/// Build an aggregate function.
141///
142/// If `prefer_append_only` is true, and both append-only and retractable implementations exist,
143/// the append-only version will be used.
144///
145/// NOTE: This function ignores argument indices, `column_orders`, `filter` and `distinct` in
146/// `AggCall`. Such operations should be done in batch or streaming executors.
147pub fn build(agg: &AggCall, prefer_append_only: bool) -> Result<BoxedAggregateFunction> {
148    // handle special kinds
149    let kind = match &agg.agg_type {
150        AggType::UserDefined(udf) => {
151            return user_defined::new_user_defined(&agg.return_type, udf);
152        }
153        AggType::WrapScalar(scalar) => {
154            return Ok(Box::new(scalar_wrapper::ScalarWrapper::new(
155                agg.args.arg_types()[0].clone(),
156                build_from_prost(scalar)?,
157            )));
158        }
159        AggType::Builtin(kind) => kind,
160    };
161
162    // find the signature for builtin aggregation
163    let sig = crate::sig::FUNCTION_REGISTRY.get(*kind, agg.args.arg_types(), &agg.return_type)?;
164
165    if let FuncBuilder::Aggregate {
166        append_only: Some(f),
167        ..
168    } = sig.build
169        && prefer_append_only
170    {
171        return f(agg);
172    }
173    sig.build_aggregate(agg)
174}