risingwave_expr_impl/scalar/
arithmetic_op.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16
17use chrono::{Duration, NaiveDateTime};
18use num_traits::{CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Zero};
19use risingwave_common::types::{
20    CheckedAdd, Date, Decimal, F64, FloatExt, Interval, IsNegative, Time, Timestamp,
21};
22use risingwave_expr::{ExprError, Result, function};
23use rust_decimal::MathematicalOps;
24
25#[function("add(*int, *int) -> auto")]
26#[function("add(decimal, decimal) -> auto")]
27#[function("add(*float, *float) -> auto")]
28#[function("add(interval, interval) -> interval")]
29#[function("add(int256, int256) -> int256")]
30pub fn general_add<T1, T2, T3>(l: T1, r: T2) -> Result<T3>
31where
32    T1: Into<T3> + Debug,
33    T2: Into<T3> + Debug,
34    T3: CheckedAdd<Output = T3>,
35{
36    general_atm(l, r, |a, b| {
37        a.checked_add(b).ok_or(ExprError::NumericOutOfRange)
38    })
39}
40
41#[function("subtract(*int, *int) -> auto")]
42#[function("subtract(decimal, decimal) -> auto")]
43#[function("subtract(*float, *float) -> auto")]
44#[function("subtract(interval, interval) -> interval")]
45#[function("subtract(int256, int256) -> int256")]
46pub fn general_sub<T1, T2, T3>(l: T1, r: T2) -> Result<T3>
47where
48    T1: Into<T3> + Debug,
49    T2: Into<T3> + Debug,
50    T3: CheckedSub,
51{
52    general_atm(l, r, |a, b| {
53        a.checked_sub(&b).ok_or(ExprError::NumericOutOfRange)
54    })
55}
56
57#[function("multiply(*int, *int) -> auto")]
58#[function("multiply(decimal, decimal) -> auto")]
59#[function("multiply(*float, *float) -> auto")]
60#[function("multiply(int256, int256) -> int256")]
61pub fn general_mul<T1, T2, T3>(l: T1, r: T2) -> Result<T3>
62where
63    T1: Into<T3> + Debug,
64    T2: Into<T3> + Debug,
65    T3: CheckedMul,
66{
67    general_atm(l, r, |a, b| {
68        a.checked_mul(&b).ok_or(ExprError::NumericOutOfRange)
69    })
70}
71
72#[function("divide(*int, *int) -> auto")]
73#[function("divide(decimal, decimal) -> auto")]
74#[function("divide(*float, *float) -> auto")]
75#[function("divide(int256, int256) -> int256")]
76#[function("divide(int256, float8) -> float8")]
77#[function("divide(int256, *int) -> int256")]
78pub fn general_div<T1, T2, T3>(l: T1, r: T2) -> Result<T3>
79where
80    T1: Into<T3> + Debug,
81    T2: Into<T3> + Debug,
82    T3: CheckedDiv + Zero,
83{
84    general_atm(l, r, |a, b| {
85        a.checked_div(&b).ok_or_else(|| {
86            if b.is_zero() {
87                ExprError::DivisionByZero
88            } else {
89                ExprError::NumericOutOfRange
90            }
91        })
92    })
93}
94
95#[function("modulus(*int, *int) -> auto")]
96#[function("modulus(decimal, decimal) -> auto")]
97#[function("modulus(int256, int256) -> int256")]
98pub fn general_mod<T1, T2, T3>(l: T1, r: T2) -> Result<T3>
99where
100    T1: Into<T3> + Debug,
101    T2: Into<T3> + Debug,
102    T3: CheckedRem,
103{
104    general_atm(l, r, |a, b| {
105        a.checked_rem(&b).ok_or(ExprError::NumericOutOfRange)
106    })
107}
108
109#[function("neg(*int) -> auto")]
110#[function("neg(*float) -> auto")]
111#[function("neg(decimal) -> decimal")]
112pub fn general_neg<T1: CheckedNeg>(expr: T1) -> Result<T1> {
113    expr.checked_neg().ok_or(ExprError::NumericOutOfRange)
114}
115
116#[function("neg(int256) -> int256")]
117pub fn int256_neg<TRef, T>(expr: TRef) -> Result<T>
118where
119    TRef: Into<T> + Debug,
120    T: CheckedNeg + Debug,
121{
122    expr.into()
123        .checked_neg()
124        .ok_or(ExprError::NumericOutOfRange)
125}
126
127#[function("abs(*int) -> auto")]
128pub fn general_abs<T1: IsNegative + CheckedNeg>(expr: T1) -> Result<T1> {
129    if expr.is_negative() {
130        general_neg(expr)
131    } else {
132        Ok(expr)
133    }
134}
135
/// `abs` for `float4` / `float8`.
///
/// This is infallible, so it returns `T1` directly rather than a `Result`.
#[function("abs(*float) -> auto")]
pub fn float_abs<F: num_traits::Float, T1: FloatExt<F>>(expr: T1) -> T1 {
    // NOTE(review): `abs` is resolved through the `FloatExt<F>` bound (or one of
    // its supertraits); its definition is not visible in this file.
    expr.abs()
}
140
141#[function("abs(int256) -> int256")]
142pub fn int256_abs<TRef, T>(expr: TRef) -> Result<T>
143where
144    TRef: Into<T> + Debug,
145    T: IsNegative + CheckedNeg + Debug,
146{
147    let expr = expr.into();
148    if expr.is_negative() {
149        int256_neg(expr)
150    } else {
151        Ok(expr)
152    }
153}
154
/// `abs` for `decimal`, delegating to `Decimal::abs`.
///
/// This is infallible, so no `Result` is returned.
#[function("abs(decimal) -> decimal")]
pub fn decimal_abs(decimal: Decimal) -> Decimal {
    Decimal::abs(&decimal)
}
159
160fn err_pow_zero_negative() -> ExprError {
161    ExprError::InvalidParam {
162        name: "rhs",
163        reason: "zero raised to a negative power is undefined".into(),
164    }
165}
166fn err_pow_negative_fract() -> ExprError {
167    ExprError::InvalidParam {
168        name: "rhs",
169        reason: "a negative number raised to a non-integer power yields a complex result".into(),
170    }
171}
172
173#[function("pow(float8, float8) -> float8")]
174pub fn pow_f64(l: F64, r: F64) -> Result<F64> {
175    if l.is_zero() && r.0 < 0.0 {
176        return Err(err_pow_zero_negative());
177    }
178    if l.0 < 0.0 && (r.is_finite() && !r.fract().is_zero()) {
179        return Err(err_pow_negative_fract());
180    }
181    let res = l.powf(r);
182    if res.is_infinite() && l.is_finite() && r.is_finite() {
183        return Err(ExprError::NumericOverflow);
184    }
185    if res.is_zero() && l.is_finite() && r.is_finite() && !l.is_zero() {
186        return Err(ExprError::NumericUnderflow);
187    }
188
189    Ok(res)
190}
191
192#[function("pow(decimal, decimal) -> decimal")]
193pub fn pow_decimal(l: Decimal, r: Decimal) -> Result<Decimal> {
194    use risingwave_common::types::DecimalPowError as PowError;
195
196    l.checked_powd(&r).map_err(|e| match e {
197        PowError::ZeroNegative => err_pow_zero_negative(),
198        PowError::NegativeFract => err_pow_negative_fract(),
199        PowError::Overflow => ExprError::NumericOverflow,
200    })
201}
202
203#[inline(always)]
204fn general_atm<T1, T2, T3, F>(l: T1, r: T2, atm: F) -> Result<T3>
205where
206    T1: Into<T3> + Debug,
207    T2: Into<T3> + Debug,
208    F: FnOnce(T3, T3) -> Result<T3>,
209{
210    atm(l.into(), r.into())
211}
212
213#[function("subtract(timestamp, timestamp) -> interval")]
214pub fn timestamp_timestamp_sub(l: Timestamp, r: Timestamp) -> Result<Interval> {
215    let tmp = l.0 - r.0; // this does not overflow or underflow
216    let days = tmp.num_days();
217    let usecs = (tmp - Duration::days(tmp.num_days()))
218        .num_microseconds()
219        .ok_or_else(|| ExprError::NumericOutOfRange)?;
220    Ok(Interval::from_month_day_usec(0, days as i32, usecs))
221}
222
223#[function("subtract(date, date) -> int4")]
224pub fn date_date_sub(l: Date, r: Date) -> Result<i32> {
225    Ok((l.0 - r.0).num_days() as i32) // this does not overflow or underflow
226}
227
228#[function("add(interval, timestamp) -> timestamp")]
229pub fn interval_timestamp_add(l: Interval, r: Timestamp) -> Result<Timestamp> {
230    r.checked_add(l).ok_or(ExprError::NumericOutOfRange)
231}
232
233#[function("add(interval, date) -> timestamp")]
234pub fn interval_date_add(l: Interval, r: Date) -> Result<Timestamp> {
235    interval_timestamp_add(l, r.into())
236}
237
238#[function("add(interval, time) -> time")]
239pub fn interval_time_add(l: Interval, r: Time) -> Result<Time> {
240    time_interval_add(r, l)
241}
242
243#[function("add(date, interval) -> timestamp")]
244pub fn date_interval_add(l: Date, r: Interval) -> Result<Timestamp> {
245    interval_date_add(r, l)
246}
247
248#[function("subtract(date, interval) -> timestamp")]
249pub fn date_interval_sub(l: Date, r: Interval) -> Result<Timestamp> {
250    // TODO: implement `checked_sub` for `Timestamp` to handle the edge case of negation
251    // overflowing.
252    interval_date_add(r.checked_neg().ok_or(ExprError::NumericOutOfRange)?, l)
253}
254
255#[function("add(date, int4) -> date")]
256pub fn date_int_add(l: Date, r: i32) -> Result<Date> {
257    let date = l.0;
258    let date_wrapper = date
259        .checked_add_signed(chrono::Duration::days(r as i64))
260        .map(Date::new);
261
262    date_wrapper.ok_or(ExprError::NumericOutOfRange)
263}
264
265#[function("add(int4, date) -> date")]
266pub fn int_date_add(l: i32, r: Date) -> Result<Date> {
267    date_int_add(r, l)
268}
269
270#[function("subtract(date, int4) -> date")]
271pub fn date_int_sub(l: Date, r: i32) -> Result<Date> {
272    let date = l.0;
273    let date_wrapper = date
274        .checked_sub_signed(chrono::Duration::days(r as i64))
275        .map(Date::new);
276
277    date_wrapper.ok_or(ExprError::NumericOutOfRange)
278}
279
280#[function("add(timestamp, interval) -> timestamp")]
281pub fn timestamp_interval_add(l: Timestamp, r: Interval) -> Result<Timestamp> {
282    interval_timestamp_add(r, l)
283}
284
285#[function("subtract(timestamp, interval) -> timestamp")]
286pub fn timestamp_interval_sub(l: Timestamp, r: Interval) -> Result<Timestamp> {
287    interval_timestamp_add(r.checked_neg().ok_or(ExprError::NumericOutOfRange)?, l)
288}
289
290#[function("multiply(interval, *int) -> interval")]
291pub fn interval_int_mul(l: Interval, r: impl TryInto<i32> + Debug) -> Result<Interval> {
292    l.checked_mul_int(r).ok_or(ExprError::NumericOutOfRange)
293}
294
295#[function("multiply(*int, interval) -> interval")]
296pub fn int_interval_mul(l: impl TryInto<i32> + Debug, r: Interval) -> Result<Interval> {
297    interval_int_mul(r, l)
298}
299
300#[function("add(date, time) -> timestamp")]
301pub fn date_time_add(l: Date, r: Time) -> Result<Timestamp> {
302    Ok(Timestamp::new(NaiveDateTime::new(l.0, r.0)))
303}
304
305#[function("add(time, date) -> timestamp")]
306pub fn time_date_add(l: Time, r: Date) -> Result<Timestamp> {
307    date_time_add(r, l)
308}
309
310#[function("subtract(time, time) -> interval")]
311pub fn time_time_sub(l: Time, r: Time) -> Result<Interval> {
312    let tmp = l.0 - r.0; // this does not overflow or underflow
313    let usecs = tmp
314        .num_microseconds()
315        .ok_or_else(|| ExprError::NumericOutOfRange)?;
316    Ok(Interval::from_month_day_usec(0, 0, usecs))
317}
318
319#[function("subtract(time, interval) -> time")]
320pub fn time_interval_sub(l: Time, r: Interval) -> Result<Time> {
321    let time = l.0;
322    let (new_time, ignored) = time.overflowing_sub_signed(Duration::microseconds(r.usecs()));
323    if ignored == 0 {
324        Ok(Time::new(new_time))
325    } else {
326        Err(ExprError::NumericOutOfRange)
327    }
328}
329
330#[function("add(time, interval) -> time")]
331pub fn time_interval_add(l: Time, r: Interval) -> Result<Time> {
332    let time = l.0;
333    let (new_time, ignored) = time.overflowing_add_signed(Duration::microseconds(r.usecs()));
334    if ignored == 0 {
335        Ok(Time::new(new_time))
336    } else {
337        Err(ExprError::NumericOutOfRange)
338    }
339}
340
341#[function("divide(interval, *int) -> interval")]
342#[function("divide(interval, decimal) -> interval")]
343#[function("divide(interval, *float) -> interval")]
344pub fn interval_float_div<T2>(l: Interval, r: T2) -> Result<Interval>
345where
346    T2: TryInto<F64> + Debug,
347{
348    l.div_float(r).ok_or(ExprError::NumericOutOfRange)
349}
350
351#[function("multiply(interval, float4) -> interval")]
352#[function("multiply(interval, float8) -> interval")]
353#[function("multiply(interval, decimal) -> interval")]
354pub fn interval_float_mul<T2>(l: Interval, r: T2) -> Result<Interval>
355where
356    T2: TryInto<F64> + Debug,
357{
358    l.mul_float(r).ok_or(ExprError::NumericOutOfRange)
359}
360
361#[function("multiply(float4, interval) -> interval")]
362#[function("multiply(float8, interval) -> interval")]
363#[function("multiply(decimal, interval) -> interval")]
364pub fn float_interval_mul<T1>(l: T1, r: Interval) -> Result<Interval>
365where
366    T1: TryInto<F64> + Debug,
367{
368    r.mul_float(l).ok_or(ExprError::NumericOutOfRange)
369}
370
371#[function("sqrt(float8) -> float8")]
372pub fn sqrt_f64(expr: F64) -> Result<F64> {
373    if expr < F64::from(0.0) {
374        return Err(ExprError::InvalidParam {
375            name: "sqrt input",
376            reason: "input cannot be negative value".into(),
377        });
378    }
379    // Edge cases: nan, inf, negative zero should return itself.
380    match expr.is_nan() || expr == f64::INFINITY || expr == -0.0 {
381        true => Ok(expr),
382        false => Ok(expr.sqrt()),
383    }
384}
385
386#[function("sqrt(decimal) -> decimal")]
387pub fn sqrt_decimal(expr: Decimal) -> Result<Decimal> {
388    match expr {
389        Decimal::NaN | Decimal::PositiveInf => Ok(expr),
390        Decimal::Normalized(value) => match value.sqrt() {
391            Some(res) => Ok(Decimal::from(res)),
392            None => Err(ExprError::InvalidParam {
393                name: "sqrt input",
394                reason: "input cannot be negative value".into(),
395            }),
396        },
397        Decimal::NegativeInf => Err(ExprError::InvalidParam {
398            name: "sqrt input",
399            reason: "input cannot be negative value".into(),
400        }),
401    }
402}
403
/// `cbrt` (cube root) for `float8`, delegating to `F64::cbrt`.
///
/// This is infallible, so no `Result` is returned.
#[function("cbrt(float8) -> float8")]
pub fn cbrt_f64(expr: F64) -> F64 {
    expr.cbrt()
}
408
409#[function("sign(float8) -> float8")]
410pub fn sign_f64(input: F64) -> F64 {
411    match input.0.partial_cmp(&0.) {
412        Some(std::cmp::Ordering::Less) => (-1).into(),
413        Some(std::cmp::Ordering::Equal) => 0.into(),
414        Some(std::cmp::Ordering::Greater) => 1.into(),
415        None => 0.into(),
416    }
417}
418
/// `sign` for `decimal`, delegating to `Decimal::sign`.
// NOTE(review): the handling of NaN / ±Infinity is defined by `Decimal::sign`,
// which is not visible here; confirm it matches the float8 variant.
#[function("sign(decimal) -> decimal")]
pub fn sign_dec(input: Decimal) -> Decimal {
    input.sign()
}
423
/// `scale` for `decimal`: the value returned by `Decimal::scale`.
///
/// `None` becomes SQL `NULL`. That is presumably the case for NaN / ±Infinity,
/// which have no scale (TODO confirm against `Decimal::scale`).
#[function("scale(decimal) -> int4")]
pub fn decimal_scale(d: Decimal) -> Option<i32> {
    d.scale()
}
428
/// `min_scale` for `decimal`: the scale of `d` after `Decimal::normalize`.
///
/// `None` becomes SQL `NULL`, whenever `scale` reports none for the
/// normalized value.
#[function("min_scale(decimal) -> int4")]
pub fn decimal_min_scale(d: Decimal) -> Option<i32> {
    d.normalize().scale()
}
433
/// `trim_scale` for `decimal`, implemented as `Decimal::normalize`.
// NOTE(review): assumes `normalize` strips trailing fractional zeros, as
// PostgreSQL's `trim_scale` does; confirm against `Decimal::normalize`.
#[function("trim_scale(decimal) -> decimal")]
pub fn decimal_trim_scale(d: Decimal) -> Decimal {
    d.normalize()
}
438
// Unit tests for the arithmetic kernels in this module.
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    // Provides `f64::neg_zero()` used in the sqrt edge cases.
    use num_traits::Float;
    // Presumably provides the `Interval::from_month` test constructor.
    use risingwave_common::types::test_utils::IntervalTestExt;
    use risingwave_common::types::{F32, Int256, Int256Ref, Scalar};

    use super::*;

    // Mixed-type addition: the `i32` operand is promoted to `Decimal`.
    #[test]
    fn test() {
        assert_eq!(
            general_add::<_, _, Decimal>(Decimal::from_str("1").unwrap(), 1i32).unwrap(),
            Decimal::from_str("2").unwrap()
        );
    }

    #[test]
    fn test_arithmetic() {
        // Decimal op i32, promoted to Decimal.
        assert_eq!(
            general_add::<Decimal, i32, Decimal>(dec("1.0"), 1).unwrap(),
            dec("2.0")
        );
        assert_eq!(
            general_sub::<Decimal, i32, Decimal>(dec("1.0"), 2).unwrap(),
            dec("-1.0")
        );
        assert_eq!(
            general_mul::<Decimal, i32, Decimal>(dec("1.0"), 2).unwrap(),
            dec("2.0")
        );
        assert_eq!(
            general_div::<Decimal, i32, Decimal>(dec("2.0"), 2).unwrap(),
            dec("1.0")
        );
        assert_eq!(
            general_mod::<Decimal, i32, Decimal>(dec("2.0"), 2).unwrap(),
            dec("0")
        );
        assert_eq!(general_neg::<Decimal>(dec("1.0")).unwrap(), dec("-1.0"));
        // i16 op i32, promoted to i32.
        assert_eq!(general_add::<i16, i32, i32>(1i16, 1i32).unwrap(), 2i32);
        assert_eq!(general_sub::<i16, i32, i32>(1i16, 1i32).unwrap(), 0i32);
        assert_eq!(general_mul::<i16, i32, i32>(1i16, 1i32).unwrap(), 1i32);
        assert_eq!(general_div::<i16, i32, i32>(1i16, 1i32).unwrap(), 1i32);
        assert_eq!(general_mod::<i16, i32, i32>(1i16, 1i32).unwrap(), 0i32);
        assert_eq!(general_neg::<i16>(1i16).unwrap(), -1i16);

        // i32 op F32, promoted to F64.
        assert!(
            general_add::<i32, F32, F64>(-1i32, 1f32.into())
                .unwrap()
                .is_zero()
        );
        assert!(
            general_sub::<i32, F32, F64>(1i32, 1f32.into())
                .unwrap()
                .is_zero()
        );
        assert!(
            general_mul::<i32, F32, F64>(0i32, 1f32.into())
                .unwrap()
                .is_zero()
        );
        assert!(
            general_div::<i32, F32, F64>(0i32, 1f32.into())
                .unwrap()
                .is_zero()
        );
        assert_eq!(general_neg::<F32>(1f32.into()).unwrap(), F32::from(-1f32));
        // Date +/- interval yields a timestamp at midnight.
        assert_eq!(
            date_interval_add(Date::from_ymd_uncheck(1994, 1, 1), Interval::from_month(12))
                .unwrap(),
            Timestamp::new(
                NaiveDateTime::parse_from_str("1995-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap()
            )
        );
        assert_eq!(
            interval_date_add(Interval::from_month(12), Date::from_ymd_uncheck(1994, 1, 1))
                .unwrap(),
            Timestamp::new(
                NaiveDateTime::parse_from_str("1995-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap()
            )
        );
        assert_eq!(
            date_interval_sub(Date::from_ymd_uncheck(1994, 1, 1), Interval::from_month(12))
                .unwrap(),
            Timestamp::new(
                NaiveDateTime::parse_from_str("1993-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap()
            )
        );
        // sqrt on float8: ordinary values and a negative input.
        assert_eq!(sqrt_f64(F64::from(25.00)).unwrap(), F64::from(5.0));
        assert_eq!(
            sqrt_f64(F64::from(107)).unwrap(),
            F64::from(10.344080432788601)
        );
        assert_eq!(
            sqrt_f64(F64::from(12.234567)).unwrap(),
            F64::from(3.4977945908815173)
        );
        assert!(sqrt_f64(F64::from(-25.00)).is_err());
        // sqrt edge cases.
        assert_eq!(sqrt_f64(F64::from(f64::NAN)).unwrap(), F64::from(f64::NAN));
        assert_eq!(
            sqrt_f64(F64::from(f64::neg_zero())).unwrap(),
            F64::from(f64::neg_zero())
        );
        assert_eq!(
            sqrt_f64(F64::from(f64::INFINITY)).unwrap(),
            F64::from(f64::INFINITY)
        );
        assert!(sqrt_f64(F64::from(f64::NEG_INFINITY)).is_err());
        // sqrt on decimal, including NaN / ±inf / -0.
        assert_eq!(sqrt_decimal(dec("25.0")).unwrap(), dec("5.0"));
        assert_eq!(
            sqrt_decimal(dec("107")).unwrap(),
            dec("10.344080432788600469738599442")
        );
        assert_eq!(
            sqrt_decimal(dec("12.234567")).unwrap(),
            dec("3.4977945908815171589625746860")
        );
        assert!(sqrt_decimal(dec("-25.0")).is_err());
        assert_eq!(sqrt_decimal(dec("nan")).unwrap(), dec("nan"));
        assert_eq!(sqrt_decimal(dec("inf")).unwrap(), dec("inf"));
        assert_eq!(sqrt_decimal(dec("-0")).unwrap(), dec("-0"));
        assert!(sqrt_decimal(dec("-inf")).is_err());
    }

    // Despite the name, this only exercises `divide(int256, float8) -> float8`.
    #[test]
    fn test_arithmetic_int256() {
        // (int256 dividend, float8 divisor, expected float8 quotient as text)
        let tuples = vec![
            (0, 1, "0"),
            (0, -1, "0"),
            (1, 1, "1"),
            (1, -1, "-1"),
            (1, 2, "0.5"),
            (1, -2, "-0.5"),
            (9007199254740991i64, 2, "4503599627370495.5"),
        ];

        for (i, j, k) in tuples {
            let lhs = Int256::from(i);
            let rhs = F64::from(j);
            let res = F64::from_str(k).unwrap();
            assert_eq!(
                general_div::<Int256Ref<'_>, F64, F64>(lhs.as_scalar_ref(), rhs).unwrap(),
                res,
            );
        }
    }

    /// Test helper: parses a decimal literal, panicking on invalid input.
    fn dec(s: &str) -> Decimal {
        Decimal::from_str(s).unwrap()
    }
}
590        Decimal::from_str(s).unwrap()
591    }
592}