risingwave_expr_impl/scalar/
array_sum.rs1use risingwave_common::array::{ArrayError, ListRef};
16use risingwave_common::types::{CheckedAdd, Decimal, ScalarRefImpl};
17use risingwave_expr::{ExprError, Result, function};
18
19#[function("array_sum(int2[]) -> int8")]
20fn array_sum_int2(list: ListRef<'_>) -> Result<Option<i64>> {
21 array_sum_general::<i16, i64>(list)
22}
23
24#[function("array_sum(int4[]) -> int8")]
25fn array_sum_int4(list: ListRef<'_>) -> Result<Option<i64>> {
26 array_sum_general::<i32, i64>(list)
27}
28
29#[function("array_sum(int8[]) -> decimal")]
30fn array_sum_int8(list: ListRef<'_>) -> Result<Option<Decimal>> {
31 array_sum_general::<i64, Decimal>(list)
32}
33
34#[function("array_sum(float4[]) -> float4")]
35#[function("array_sum(float8[]) -> float8")]
36#[function("array_sum(decimal[]) -> decimal")]
37#[function("array_sum(interval[]) -> interval")]
38fn array_sum<T>(list: ListRef<'_>) -> Result<Option<T>>
39where
40 T: for<'a> TryFrom<ScalarRefImpl<'a>, Error = ArrayError>,
41 T: Default + From<T> + CheckedAdd<Output = T>,
42{
43 array_sum_general::<T, T>(list)
44}
45
46fn array_sum_general<S, T>(list: ListRef<'_>) -> Result<Option<T>>
47where
48 S: for<'a> TryFrom<ScalarRefImpl<'a>, Error = ArrayError>,
49 T: Default + From<S> + CheckedAdd<Output = T>,
50{
51 if list.iter().flatten().next().is_none() {
52 return Ok(None);
53 }
54 let mut sum = T::default();
55 for e in list.iter().flatten() {
56 let v: S = e.try_into()?;
57 sum = sum
58 .checked_add(v.into())
59 .ok_or_else(|| ExprError::NumericOutOfRange)?;
60 }
61 Ok(Some(sum))
62}