risingwave_expr_impl/scalar/
array_sum.rs1use risingwave_common::array::{ArrayError, ListRef};
16use risingwave_common::types::{CheckedAdd, Decimal, ScalarRefImpl};
17use risingwave_expr::{ExprError, Result, function};
18
19#[function("array_sum(int2[]) -> int8")]
20fn array_sum_int2(list: ListRef<'_>) -> Result<Option<i64>> {
21    array_sum_general::<i16, i64>(list)
22}
23
24#[function("array_sum(int4[]) -> int8")]
25fn array_sum_int4(list: ListRef<'_>) -> Result<Option<i64>> {
26    array_sum_general::<i32, i64>(list)
27}
28
29#[function("array_sum(int8[]) -> decimal")]
30fn array_sum_int8(list: ListRef<'_>) -> Result<Option<Decimal>> {
31    array_sum_general::<i64, Decimal>(list)
32}
33
34#[function("array_sum(float4[]) -> float4")]
35#[function("array_sum(float8[]) -> float8")]
36#[function("array_sum(decimal[]) -> decimal")]
37#[function("array_sum(interval[]) -> interval")]
38fn array_sum<T>(list: ListRef<'_>) -> Result<Option<T>>
39where
40    T: for<'a> TryFrom<ScalarRefImpl<'a>, Error = ArrayError>,
41    T: Default + From<T> + CheckedAdd<Output = T>,
42{
43    array_sum_general::<T, T>(list)
44}
45
46fn array_sum_general<S, T>(list: ListRef<'_>) -> Result<Option<T>>
47where
48    S: for<'a> TryFrom<ScalarRefImpl<'a>, Error = ArrayError>,
49    T: Default + From<S> + CheckedAdd<Output = T>,
50{
51    if list.iter().flatten().next().is_none() {
52        return Ok(None);
53    }
54    let mut sum = T::default();
55    for e in list.iter().flatten() {
56        let v: S = e.try_into()?;
57        sum = sum
58            .checked_add(v.into())
59            .ok_or_else(|| ExprError::NumericOutOfRange)?;
60    }
61    Ok(Some(sum))
62}