risingwave_expr_impl/scalar/
round.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::types::{Decimal, F64};
16use risingwave_expr::{ExprError, Result, function};
17
18#[function("round_digit(decimal, int4) -> decimal")]
19pub fn round_digits(input: Decimal, digits: i32) -> Result<Decimal> {
20    if digits < 0 {
21        input
22            .round_left_ties_away(digits.unsigned_abs())
23            .ok_or(ExprError::NumericOverflow)
24    } else {
25        // rust_decimal can only handle up to 28 digits of scale
26        Ok(input.round_dp_ties_away((digits as u32).min(Decimal::MAX_PRECISION.into())))
27    }
28}
29
30#[function("ceil(float8) -> float8")]
31pub fn ceil_f64(input: F64) -> F64 {
32    f64::ceil(input.0).into()
33}
34
35#[function("ceil(decimal) -> decimal")]
36pub fn ceil_decimal(input: Decimal) -> Decimal {
37    input.ceil()
38}
39
40#[function("floor(float8) -> float8")]
41pub fn floor_f64(input: F64) -> F64 {
42    f64::floor(input.0).into()
43}
44
45#[function("floor(decimal) -> decimal")]
46pub fn floor_decimal(input: Decimal) -> Decimal {
47    input.floor()
48}
49
50#[function("trunc(float8) -> float8")]
51pub fn trunc_f64(input: F64) -> F64 {
52    f64::trunc(input.0).into()
53}
54
55#[function("trunc(decimal) -> decimal")]
56pub fn trunc_decimal(input: Decimal) -> Decimal {
57    input.trunc()
58}
59
60// Ties are broken by rounding away from zero
61#[function("round(float8) -> float8")]
62pub fn round_f64(input: F64) -> F64 {
63    f64::round_ties_even(input.0).into()
64}
65
66// Ties are broken by rounding away from zero
67#[function("round(decimal) -> decimal")]
68pub fn round_decimal(input: Decimal) -> Decimal {
69    input.round_dp_ties_away(0)
70}
71
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use crate::scalar::round::*;

    /// Parses `input` as a decimal, rounds it with [`round_digits`], and compares the
    /// textual result with `expected_output` (`None` means the call must return an error).
    fn do_test(input: &str, digits: i32, expected_output: Option<&str>) {
        let v = Decimal::from_str(input).unwrap();
        let rounded_value = round_digits(v, digits).ok();
        assert_eq!(
            expected_output,
            rounded_value.as_ref().map(ToString::to_string).as_deref()
        );
    }

    #[test]
    fn test_round_digits() {
        do_test("21.666666666666666666666666667", 4, Some("21.6667"));
        do_test("84818.33333333333333333333333", 4, Some("84818.3333"));
        do_test("84818.15", 1, Some("84818.2"));
        do_test("21.372736", -1, Some("20"));
        do_test("-79228162514264337593543950335", -30, Some("0"));
        do_test("-79228162514264337593543950335", -29, None);
        do_test("-79228162514264337593543950335", -28, None);
        do_test(
            "-79228162514264337593543950335",
            -27,
            Some("-79000000000000000000000000000"),
        );
        do_test("-792.28162514264337593543950335", -4, Some("0"));
        do_test("-792.28162514264337593543950335", -3, Some("-1000"));
        do_test("-792.28162514264337593543950335", -2, Some("-800"));
        do_test("-792.28162514264337593543950335", -1, Some("-790"));
        do_test("-50000000000000000000000000000", -29, None);
        do_test("-49999999999999999999999999999", -29, Some("0"));
        do_test("-500.00000000000000000000000000", -3, Some("-1000"));
        do_test("-499.99999999999999999999999999", -3, Some("0"));
        // When digit extends past original scale, it should just return original scale.
        // Intuitively, it does not make sense after rounding `0` it becomes `0.000`. Precision
        // should always be less or equal, not more.
        do_test("0", 340, Some("0"));
    }

    #[test]
    fn test_round_f64() {
        assert_eq!(ceil_f64(F64::from(42.2)), F64::from(43.0));
        assert_eq!(ceil_f64(F64::from(-42.8)), F64::from(-42.0));

        assert_eq!(floor_f64(F64::from(42.8)), F64::from(42.0));
        assert_eq!(floor_f64(F64::from(-42.8)), F64::from(-43.0));

        // `trunc` rounds toward zero for both signs.
        assert_eq!(trunc_f64(F64::from(42.8)), F64::from(42.0));
        assert_eq!(trunc_f64(F64::from(-42.8)), F64::from(-42.0));

        // `round(float8)` breaks ties to even.
        assert_eq!(round_f64(F64::from(42.4)), F64::from(42.0));
        assert_eq!(round_f64(F64::from(42.5)), F64::from(42.0));
        assert_eq!(round_f64(F64::from(-6.5)), F64::from(-6.0));
        assert_eq!(round_f64(F64::from(43.5)), F64::from(44.0));
        assert_eq!(round_f64(F64::from(-7.5)), F64::from(-8.0));
    }

    #[test]
    fn test_round_decimal() {
        assert_eq!(ceil_decimal(dec(42.2)), dec(43.0));
        assert_eq!(ceil_decimal(dec(-42.8)), dec(-42.0));

        assert_eq!(floor_decimal(dec(42.2)), dec(42.0));
        assert_eq!(floor_decimal(dec(-42.8)), dec(-43.0));

        // `trunc` rounds toward zero for both signs.
        assert_eq!(trunc_decimal(dec(42.8)), dec(42.0));
        assert_eq!(trunc_decimal(dec(-42.8)), dec(-42.0));

        // `round(decimal)` breaks ties away from zero.
        assert_eq!(round_decimal(dec(42.4)), dec(42.0));
        assert_eq!(round_decimal(dec(42.5)), dec(43.0));
        assert_eq!(round_decimal(dec(-6.5)), dec(-7.0));
    }

    /// Builds a test decimal from an `f64` literal.
    fn dec(f: f64) -> Decimal {
        Decimal::try_from(f).unwrap()
    }
}