risingwave_expr_impl/scalar/
exp.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use num_traits::Zero;
16use risingwave_common::types::{Decimal, F64, FloatExt};
17use risingwave_expr::{ExprError, Result, function};
18
19fn err_logarithm_input() -> ExprError {
20    ExprError::InvalidParam {
21        name: "input",
22        reason: "cannot take logarithm of zero or a negative number".into(),
23    }
24}
25
26#[function("exp(float8) -> float8")]
27pub fn exp_f64(input: F64) -> Result<F64> {
28    // The cases where the exponent value is Inf or NaN can be handled explicitly and without
29    // evaluating the `exp` operation.
30    if input.is_nan() {
31        Ok(input)
32    } else if input.is_infinite() {
33        if input.is_sign_negative() {
34            Ok(0.into())
35        } else {
36            Ok(input)
37        }
38    } else {
39        let res = input.exp();
40
41        // If the argument passed to `exp` is not `inf` or `-inf` then a result that is `inf` or `0`
42        // means that the operation had an overflow or an underflow, and the appropriate
43        // error should be returned.
44        if res.is_infinite() {
45            Err(ExprError::NumericOverflow)
46        } else if res.is_zero() {
47            Err(ExprError::NumericUnderflow)
48        } else {
49            Ok(res)
50        }
51    }
52}
53
54#[function("ln(float8) -> float8")]
55pub fn ln_f64(input: F64) -> Result<F64> {
56    if input.0 <= 0.0 {
57        return Err(err_logarithm_input());
58    }
59    Ok(input.ln())
60}
61
62#[function("log10(float8) -> float8")]
63pub fn log10_f64(input: F64) -> Result<F64> {
64    if input.0 <= 0.0 {
65        return Err(err_logarithm_input());
66    }
67    Ok(input.log10())
68}
69
70#[function("exp(decimal) -> decimal")]
71pub fn exp_decimal(input: Decimal) -> Result<Decimal> {
72    input.checked_exp().ok_or(ExprError::NumericOverflow)
73}
74
75#[function("ln(decimal) -> decimal")]
76pub fn ln_decimal(input: Decimal) -> Result<Decimal> {
77    input.checked_ln().ok_or_else(err_logarithm_input)
78}
79
80#[function("log10(decimal) -> decimal")]
81pub fn log10_decimal(input: Decimal) -> Result<Decimal> {
82    input.checked_log10().ok_or_else(err_logarithm_input)
83}
84
/// Unit tests for [`exp_f64`], covering ordinary input, overflow/underflow
/// reporting, and the NaN / infinity special cases.
#[cfg(test)]
mod tests {
    use risingwave_common::types::F64;
    use risingwave_expr::ExprError;

    use super::exp_f64;

    /// `exp(0)` is exactly `1`.
    #[test]
    fn legal_input() {
        let res = exp_f64(0.0.into()).unwrap();
        assert_eq!(res, F64::from(1.0));
    }

    /// A finite, very negative exponent must be reported as an underflow.
    #[test]
    fn underflow() {
        let err = exp_f64((-1000.0).into()).unwrap_err();
        assert!(
            matches!(err, ExprError::NumericUnderflow),
            "expected ExprError::NumericUnderflow, got {err:?}"
        );
    }

    /// A finite, very large exponent must be reported as an overflow.
    #[test]
    fn overflow() {
        let err = exp_f64(1000.0.into()).unwrap_err();
        assert!(
            matches!(err, ExprError::NumericOverflow),
            "expected ExprError::NumericOverflow, got {err:?}"
        );
    }

    /// NaN inputs are passed through unchanged, whatever their sign.
    #[test]
    fn nan() {
        let res = exp_f64(f64::NAN.into()).unwrap();
        assert_eq!(res, F64::from(f64::NAN));

        let res = exp_f64((-f64::NAN).into()).unwrap();
        assert_eq!(res, F64::from(-f64::NAN));
    }

    /// Infinite inputs map to their limits instead of raising errors.
    #[test]
    fn infinity() {
        let res = exp_f64(f64::INFINITY.into()).unwrap();
        assert_eq!(res, F64::from(f64::INFINITY));

        let res = exp_f64(f64::NEG_INFINITY.into()).unwrap();
        assert_eq!(res, F64::from(0.0));
    }
}