risingwave_expr_impl/scalar/
array_sum.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::array::{ArrayError, ListRef};
16use risingwave_common::types::{CheckedAdd, Decimal, ScalarRefImpl};
17use risingwave_expr::{ExprError, Result, function};
18
19#[function("array_sum(int2[]) -> int8")]
20fn array_sum_int2(list: ListRef<'_>) -> Result<Option<i64>> {
21    array_sum_general::<i16, i64>(list)
22}
23
24#[function("array_sum(int4[]) -> int8")]
25fn array_sum_int4(list: ListRef<'_>) -> Result<Option<i64>> {
26    array_sum_general::<i32, i64>(list)
27}
28
29#[function("array_sum(int8[]) -> decimal")]
30fn array_sum_int8(list: ListRef<'_>) -> Result<Option<Decimal>> {
31    array_sum_general::<i64, Decimal>(list)
32}
33
34#[function("array_sum(float4[]) -> float4")]
35#[function("array_sum(float8[]) -> float8")]
36#[function("array_sum(decimal[]) -> decimal")]
37#[function("array_sum(interval[]) -> interval")]
38fn array_sum<T>(list: ListRef<'_>) -> Result<Option<T>>
39where
40    T: for<'a> TryFrom<ScalarRefImpl<'a>, Error = ArrayError>,
41    T: Default + From<T> + CheckedAdd<Output = T>,
42{
43    array_sum_general::<T, T>(list)
44}
45
46fn array_sum_general<S, T>(list: ListRef<'_>) -> Result<Option<T>>
47where
48    S: for<'a> TryFrom<ScalarRefImpl<'a>, Error = ArrayError>,
49    T: Default + From<S> + CheckedAdd<Output = T>,
50{
51    if list.iter().flatten().next().is_none() {
52        return Ok(None);
53    }
54    let mut sum = T::default();
55    for e in list.iter().flatten() {
56        let v: S = e.try_into()?;
57        sum = sum
58            .checked_add(v.into())
59            .ok_or_else(|| ExprError::NumericOutOfRange)?;
60    }
61    Ok(Some(sum))
62}