1use std::fmt::Debug;
16
17use chrono::{Duration, NaiveDateTime};
18use num_traits::{CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Zero};
19use risingwave_common::types::{
20 CheckedAdd, Date, Decimal, F64, FloatExt, Interval, IsNegative, Time, Timestamp,
21};
22use risingwave_expr::{ExprError, Result, function};
23use rust_decimal::MathematicalOps;
24
25#[function("add(*int, *int) -> auto")]
26#[function("add(decimal, decimal) -> auto")]
27#[function("add(*float, *float) -> auto")]
28#[function("add(interval, interval) -> interval")]
29#[function("add(int256, int256) -> int256")]
30pub fn general_add<T1, T2, T3>(l: T1, r: T2) -> Result<T3>
31where
32 T1: Into<T3> + Debug,
33 T2: Into<T3> + Debug,
34 T3: CheckedAdd<Output = T3>,
35{
36 general_atm(l, r, |a, b| {
37 a.checked_add(b).ok_or(ExprError::NumericOutOfRange)
38 })
39}
40
41#[function("subtract(*int, *int) -> auto")]
42#[function("subtract(decimal, decimal) -> auto")]
43#[function("subtract(*float, *float) -> auto")]
44#[function("subtract(interval, interval) -> interval")]
45#[function("subtract(int256, int256) -> int256")]
46pub fn general_sub<T1, T2, T3>(l: T1, r: T2) -> Result<T3>
47where
48 T1: Into<T3> + Debug,
49 T2: Into<T3> + Debug,
50 T3: CheckedSub,
51{
52 general_atm(l, r, |a, b| {
53 a.checked_sub(&b).ok_or(ExprError::NumericOutOfRange)
54 })
55}
56
57#[function("multiply(*int, *int) -> auto")]
58#[function("multiply(decimal, decimal) -> auto")]
59#[function("multiply(*float, *float) -> auto")]
60#[function("multiply(int256, int256) -> int256")]
61pub fn general_mul<T1, T2, T3>(l: T1, r: T2) -> Result<T3>
62where
63 T1: Into<T3> + Debug,
64 T2: Into<T3> + Debug,
65 T3: CheckedMul,
66{
67 general_atm(l, r, |a, b| {
68 a.checked_mul(&b).ok_or(ExprError::NumericOutOfRange)
69 })
70}
71
72#[function("divide(*int, *int) -> auto")]
73#[function("divide(decimal, decimal) -> auto")]
74#[function("divide(*float, *float) -> auto")]
75#[function("divide(int256, int256) -> int256")]
76#[function("divide(int256, float8) -> float8")]
77#[function("divide(int256, *int) -> int256")]
78pub fn general_div<T1, T2, T3>(l: T1, r: T2) -> Result<T3>
79where
80 T1: Into<T3> + Debug,
81 T2: Into<T3> + Debug,
82 T3: CheckedDiv + Zero,
83{
84 general_atm(l, r, |a, b| {
85 a.checked_div(&b).ok_or_else(|| {
86 if b.is_zero() {
87 ExprError::DivisionByZero
88 } else {
89 ExprError::NumericOutOfRange
90 }
91 })
92 })
93}
94
95#[function("modulus(*int, *int) -> auto")]
96#[function("modulus(decimal, decimal) -> auto")]
97#[function("modulus(int256, int256) -> int256")]
98pub fn general_mod<T1, T2, T3>(l: T1, r: T2) -> Result<T3>
99where
100 T1: Into<T3> + Debug,
101 T2: Into<T3> + Debug,
102 T3: CheckedRem,
103{
104 general_atm(l, r, |a, b| {
105 a.checked_rem(&b).ok_or(ExprError::NumericOutOfRange)
106 })
107}
108
109#[function("neg(*int) -> auto")]
110#[function("neg(*float) -> auto")]
111#[function("neg(decimal) -> decimal")]
112pub fn general_neg<T1: CheckedNeg>(expr: T1) -> Result<T1> {
113 expr.checked_neg().ok_or(ExprError::NumericOutOfRange)
114}
115
116#[function("neg(int256) -> int256")]
117pub fn int256_neg<TRef, T>(expr: TRef) -> Result<T>
118where
119 TRef: Into<T> + Debug,
120 T: CheckedNeg + Debug,
121{
122 expr.into()
123 .checked_neg()
124 .ok_or(ExprError::NumericOutOfRange)
125}
126
127#[function("abs(*int) -> auto")]
128pub fn general_abs<T1: IsNegative + CheckedNeg>(expr: T1) -> Result<T1> {
129 if expr.is_negative() {
130 general_neg(expr)
131 } else {
132 Ok(expr)
133 }
134}
135
136#[function("abs(*float) -> auto")]
137pub fn float_abs<F: num_traits::Float, T1: FloatExt<F>>(expr: T1) -> T1 {
138 expr.abs()
139}
140
141#[function("abs(int256) -> int256")]
142pub fn int256_abs<TRef, T>(expr: TRef) -> Result<T>
143where
144 TRef: Into<T> + Debug,
145 T: IsNegative + CheckedNeg + Debug,
146{
147 let expr = expr.into();
148 if expr.is_negative() {
149 int256_neg(expr)
150 } else {
151 Ok(expr)
152 }
153}
154
155#[function("abs(decimal) -> decimal")]
156pub fn decimal_abs(decimal: Decimal) -> Decimal {
157 Decimal::abs(&decimal)
158}
159
160fn err_pow_zero_negative() -> ExprError {
161 ExprError::InvalidParam {
162 name: "rhs",
163 reason: "zero raised to a negative power is undefined".into(),
164 }
165}
166fn err_pow_negative_fract() -> ExprError {
167 ExprError::InvalidParam {
168 name: "rhs",
169 reason: "a negative number raised to a non-integer power yields a complex result".into(),
170 }
171}
172
173#[function("pow(float8, float8) -> float8")]
174pub fn pow_f64(l: F64, r: F64) -> Result<F64> {
175 if l.is_zero() && r.0 < 0.0 {
176 return Err(err_pow_zero_negative());
177 }
178 if l.0 < 0.0 && (r.is_finite() && !r.fract().is_zero()) {
179 return Err(err_pow_negative_fract());
180 }
181 let res = l.powf(r);
182 if res.is_infinite() && l.is_finite() && r.is_finite() {
183 return Err(ExprError::NumericOverflow);
184 }
185 if res.is_zero() && l.is_finite() && r.is_finite() && !l.is_zero() {
186 return Err(ExprError::NumericUnderflow);
187 }
188
189 Ok(res)
190}
191
192#[function("pow(decimal, decimal) -> decimal")]
193pub fn pow_decimal(l: Decimal, r: Decimal) -> Result<Decimal> {
194 use risingwave_common::types::DecimalPowError as PowError;
195
196 l.checked_powd(&r).map_err(|e| match e {
197 PowError::ZeroNegative => err_pow_zero_negative(),
198 PowError::NegativeFract => err_pow_negative_fract(),
199 PowError::Overflow => ExprError::NumericOverflow,
200 })
201}
202
203#[inline(always)]
204fn general_atm<T1, T2, T3, F>(l: T1, r: T2, atm: F) -> Result<T3>
205where
206 T1: Into<T3> + Debug,
207 T2: Into<T3> + Debug,
208 F: FnOnce(T3, T3) -> Result<T3>,
209{
210 atm(l.into(), r.into())
211}
212
213#[function("subtract(timestamp, timestamp) -> interval")]
214pub fn timestamp_timestamp_sub(l: Timestamp, r: Timestamp) -> Result<Interval> {
215 let tmp = l.0 - r.0; let days = tmp.num_days();
217 let usecs = (tmp - Duration::days(tmp.num_days()))
218 .num_microseconds()
219 .ok_or_else(|| ExprError::NumericOutOfRange)?;
220 Ok(Interval::from_month_day_usec(0, days as i32, usecs))
221}
222
223#[function("subtract(date, date) -> int4")]
224pub fn date_date_sub(l: Date, r: Date) -> Result<i32> {
225 Ok((l.0 - r.0).num_days() as i32) }
227
228#[function("add(interval, timestamp) -> timestamp")]
229pub fn interval_timestamp_add(l: Interval, r: Timestamp) -> Result<Timestamp> {
230 r.checked_add(l).ok_or(ExprError::NumericOutOfRange)
231}
232
233#[function("add(interval, date) -> timestamp")]
234pub fn interval_date_add(l: Interval, r: Date) -> Result<Timestamp> {
235 interval_timestamp_add(l, r.into())
236}
237
238#[function("add(interval, time) -> time")]
239pub fn interval_time_add(l: Interval, r: Time) -> Result<Time> {
240 time_interval_add(r, l)
241}
242
243#[function("add(date, interval) -> timestamp")]
244pub fn date_interval_add(l: Date, r: Interval) -> Result<Timestamp> {
245 interval_date_add(r, l)
246}
247
248#[function("subtract(date, interval) -> timestamp")]
249pub fn date_interval_sub(l: Date, r: Interval) -> Result<Timestamp> {
250 interval_date_add(r.checked_neg().ok_or(ExprError::NumericOutOfRange)?, l)
253}
254
255#[function("add(date, int4) -> date")]
256pub fn date_int_add(l: Date, r: i32) -> Result<Date> {
257 let date = l.0;
258 let date_wrapper = date
259 .checked_add_signed(chrono::Duration::days(r as i64))
260 .map(Date::new);
261
262 date_wrapper.ok_or(ExprError::NumericOutOfRange)
263}
264
265#[function("add(int4, date) -> date")]
266pub fn int_date_add(l: i32, r: Date) -> Result<Date> {
267 date_int_add(r, l)
268}
269
270#[function("subtract(date, int4) -> date")]
271pub fn date_int_sub(l: Date, r: i32) -> Result<Date> {
272 let date = l.0;
273 let date_wrapper = date
274 .checked_sub_signed(chrono::Duration::days(r as i64))
275 .map(Date::new);
276
277 date_wrapper.ok_or(ExprError::NumericOutOfRange)
278}
279
280#[function("add(timestamp, interval) -> timestamp")]
281pub fn timestamp_interval_add(l: Timestamp, r: Interval) -> Result<Timestamp> {
282 interval_timestamp_add(r, l)
283}
284
285#[function("subtract(timestamp, interval) -> timestamp")]
286pub fn timestamp_interval_sub(l: Timestamp, r: Interval) -> Result<Timestamp> {
287 interval_timestamp_add(r.checked_neg().ok_or(ExprError::NumericOutOfRange)?, l)
288}
289
290#[function("multiply(interval, *int) -> interval")]
291pub fn interval_int_mul(l: Interval, r: impl TryInto<i32> + Debug) -> Result<Interval> {
292 l.checked_mul_int(r).ok_or(ExprError::NumericOutOfRange)
293}
294
295#[function("multiply(*int, interval) -> interval")]
296pub fn int_interval_mul(l: impl TryInto<i32> + Debug, r: Interval) -> Result<Interval> {
297 interval_int_mul(r, l)
298}
299
300#[function("add(date, time) -> timestamp")]
301pub fn date_time_add(l: Date, r: Time) -> Result<Timestamp> {
302 Ok(Timestamp::new(NaiveDateTime::new(l.0, r.0)))
303}
304
305#[function("add(time, date) -> timestamp")]
306pub fn time_date_add(l: Time, r: Date) -> Result<Timestamp> {
307 date_time_add(r, l)
308}
309
310#[function("subtract(time, time) -> interval")]
311pub fn time_time_sub(l: Time, r: Time) -> Result<Interval> {
312 let tmp = l.0 - r.0; let usecs = tmp
314 .num_microseconds()
315 .ok_or_else(|| ExprError::NumericOutOfRange)?;
316 Ok(Interval::from_month_day_usec(0, 0, usecs))
317}
318
319#[function("subtract(time, interval) -> time")]
320pub fn time_interval_sub(l: Time, r: Interval) -> Result<Time> {
321 let time = l.0;
322 let (new_time, ignored) = time.overflowing_sub_signed(Duration::microseconds(r.usecs()));
323 if ignored == 0 {
324 Ok(Time::new(new_time))
325 } else {
326 Err(ExprError::NumericOutOfRange)
327 }
328}
329
330#[function("add(time, interval) -> time")]
331pub fn time_interval_add(l: Time, r: Interval) -> Result<Time> {
332 let time = l.0;
333 let (new_time, ignored) = time.overflowing_add_signed(Duration::microseconds(r.usecs()));
334 if ignored == 0 {
335 Ok(Time::new(new_time))
336 } else {
337 Err(ExprError::NumericOutOfRange)
338 }
339}
340
341#[function("divide(interval, *int) -> interval")]
342#[function("divide(interval, decimal) -> interval")]
343#[function("divide(interval, *float) -> interval")]
344pub fn interval_float_div<T2>(l: Interval, r: T2) -> Result<Interval>
345where
346 T2: TryInto<F64> + Debug,
347{
348 l.div_float(r).ok_or(ExprError::NumericOutOfRange)
349}
350
351#[function("multiply(interval, float4) -> interval")]
352#[function("multiply(interval, float8) -> interval")]
353#[function("multiply(interval, decimal) -> interval")]
354pub fn interval_float_mul<T2>(l: Interval, r: T2) -> Result<Interval>
355where
356 T2: TryInto<F64> + Debug,
357{
358 l.mul_float(r).ok_or(ExprError::NumericOutOfRange)
359}
360
361#[function("multiply(float4, interval) -> interval")]
362#[function("multiply(float8, interval) -> interval")]
363#[function("multiply(decimal, interval) -> interval")]
364pub fn float_interval_mul<T1>(l: T1, r: Interval) -> Result<Interval>
365where
366 T1: TryInto<F64> + Debug,
367{
368 r.mul_float(l).ok_or(ExprError::NumericOutOfRange)
369}
370
371#[function("sqrt(float8) -> float8")]
372pub fn sqrt_f64(expr: F64) -> Result<F64> {
373 if expr < F64::from(0.0) {
374 return Err(ExprError::InvalidParam {
375 name: "sqrt input",
376 reason: "input cannot be negative value".into(),
377 });
378 }
379 match expr.is_nan() || expr == f64::INFINITY || expr == -0.0 {
381 true => Ok(expr),
382 false => Ok(expr.sqrt()),
383 }
384}
385
386#[function("sqrt(decimal) -> decimal")]
387pub fn sqrt_decimal(expr: Decimal) -> Result<Decimal> {
388 match expr {
389 Decimal::NaN | Decimal::PositiveInf => Ok(expr),
390 Decimal::Normalized(value) => match value.sqrt() {
391 Some(res) => Ok(Decimal::from(res)),
392 None => Err(ExprError::InvalidParam {
393 name: "sqrt input",
394 reason: "input cannot be negative value".into(),
395 }),
396 },
397 Decimal::NegativeInf => Err(ExprError::InvalidParam {
398 name: "sqrt input",
399 reason: "input cannot be negative value".into(),
400 }),
401 }
402}
403
404#[function("cbrt(float8) -> float8")]
405pub fn cbrt_f64(expr: F64) -> F64 {
406 expr.cbrt()
407}
408
409#[function("sign(float8) -> float8")]
410pub fn sign_f64(input: F64) -> F64 {
411 match input.0.partial_cmp(&0.) {
412 Some(std::cmp::Ordering::Less) => (-1).into(),
413 Some(std::cmp::Ordering::Equal) => 0.into(),
414 Some(std::cmp::Ordering::Greater) => 1.into(),
415 None => 0.into(),
416 }
417}
418
419#[function("sign(decimal) -> decimal")]
420pub fn sign_dec(input: Decimal) -> Decimal {
421 input.sign()
422}
423
424#[function("scale(decimal) -> int4")]
425pub fn decimal_scale(d: Decimal) -> Option<i32> {
426 d.scale()
427}
428
429#[function("min_scale(decimal) -> int4")]
430pub fn decimal_min_scale(d: Decimal) -> Option<i32> {
431 d.normalize().scale()
432}
433
434#[function("trim_scale(decimal) -> decimal")]
435pub fn decimal_trim_scale(d: Decimal) -> Decimal {
436 d.normalize()
437}
438
439#[function("gamma(float8) -> float8")]
440pub fn gamma_f64(input: F64) -> Result<F64> {
441 let mut result = input;
442 if input.is_nan() {
443 return Ok(result);
444 } else if input.is_infinite() {
445 if input.is_negative() {
446 return Err(ExprError::NumericOverflow);
447 }
448 } else {
449 result = input.gamma();
450 if result.is_nan() || result.is_infinite() {
451 return Err(ExprError::NumericOverflow);
452 } else if result.is_zero() {
453 return Err(ExprError::NumericUnderflow);
454 }
455 }
456 Ok(result)
457}
458
459#[function("lgamma(float8) -> float8")]
460pub fn lgamma_f64(input: F64) -> Result<F64> {
461 let (result, _sign) = input.ln_gamma();
462 if result.is_infinite() && input.is_finite() {
463 return Err(ExprError::NumericOverflow);
464 }
465 Ok(F64::from(result))
466}
467
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use num_traits::Float;
    use risingwave_common::types::test_utils::IntervalTestExt;
    use risingwave_common::types::{F32, Int256, Int256Ref, Scalar};

    use super::*;

    /// Parses a decimal literal, panicking on malformed test input.
    fn dec(s: &str) -> Decimal {
        Decimal::from_str(s).unwrap()
    }

    /// Parses a `%Y-%m-%d %H:%M:%S` string into a `Timestamp`.
    fn ts(s: &str) -> Timestamp {
        Timestamp::new(NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S").unwrap())
    }

    #[test]
    fn test() {
        let sum = general_add::<_, _, Decimal>(Decimal::from_str("1").unwrap(), 1i32).unwrap();
        assert_eq!(sum, Decimal::from_str("2").unwrap());
    }

    #[test]
    fn test_arithmetic() {
        // Decimal combined with i32, widened to Decimal.
        assert_eq!(general_add::<Decimal, i32, Decimal>(dec("1.0"), 1).unwrap(), dec("2.0"));
        assert_eq!(general_sub::<Decimal, i32, Decimal>(dec("1.0"), 2).unwrap(), dec("-1.0"));
        assert_eq!(general_mul::<Decimal, i32, Decimal>(dec("1.0"), 2).unwrap(), dec("2.0"));
        assert_eq!(general_div::<Decimal, i32, Decimal>(dec("2.0"), 2).unwrap(), dec("1.0"));
        assert_eq!(general_mod::<Decimal, i32, Decimal>(dec("2.0"), 2).unwrap(), dec("0"));
        assert_eq!(general_neg::<Decimal>(dec("1.0")).unwrap(), dec("-1.0"));

        // i16 combined with i32, widened to i32.
        assert_eq!(general_add::<i16, i32, i32>(1i16, 1i32).unwrap(), 2i32);
        assert_eq!(general_sub::<i16, i32, i32>(1i16, 1i32).unwrap(), 0i32);
        assert_eq!(general_mul::<i16, i32, i32>(1i16, 1i32).unwrap(), 1i32);
        assert_eq!(general_div::<i16, i32, i32>(1i16, 1i32).unwrap(), 1i32);
        assert_eq!(general_mod::<i16, i32, i32>(1i16, 1i32).unwrap(), 0i32);
        assert_eq!(general_neg::<i16>(1i16).unwrap(), -1i16);

        // i32 combined with F32, widened to F64; every case evaluates to zero.
        for result in [
            general_add::<i32, F32, F64>(-1i32, 1f32.into()),
            general_sub::<i32, F32, F64>(1i32, 1f32.into()),
            general_mul::<i32, F32, F64>(0i32, 1f32.into()),
            general_div::<i32, F32, F64>(0i32, 1f32.into()),
        ] {
            assert!(result.unwrap().is_zero());
        }
        assert_eq!(general_neg::<F32>(1f32.into()).unwrap(), F32::from(-1f32));

        // Date and interval arithmetic around 1994-01-01.
        assert_eq!(
            date_interval_add(Date::from_ymd_uncheck(1994, 1, 1), Interval::from_month(12))
                .unwrap(),
            ts("1995-1-1 0:0:0")
        );
        assert_eq!(
            interval_date_add(Interval::from_month(12), Date::from_ymd_uncheck(1994, 1, 1))
                .unwrap(),
            ts("1995-1-1 0:0:0")
        );
        assert_eq!(
            date_interval_sub(Date::from_ymd_uncheck(1994, 1, 1), Interval::from_month(12))
                .unwrap(),
            ts("1993-1-1 0:0:0")
        );

        // float8 square roots.
        for (input, expected) in [
            (25.00, 5.0),
            (107.0, 10.344080432788601),
            (12.234567, 3.4977945908815173),
        ] {
            assert_eq!(sqrt_f64(F64::from(input)).unwrap(), F64::from(expected));
        }
        for special in [f64::NAN, f64::neg_zero(), f64::INFINITY] {
            assert_eq!(sqrt_f64(F64::from(special)).unwrap(), F64::from(special));
        }
        for negative in [-25.00, f64::NEG_INFINITY] {
            assert!(sqrt_f64(F64::from(negative)).is_err());
        }

        // decimal square roots.
        for (input, expected) in [
            ("25.0", "5.0"),
            ("107", "10.344080432788600469738599442"),
            ("12.234567", "3.4977945908815171589625746860"),
            ("nan", "nan"),
            ("inf", "inf"),
            ("-0", "-0"),
        ] {
            assert_eq!(sqrt_decimal(dec(input)).unwrap(), dec(expected));
        }
        for negative in ["-25.0", "-inf"] {
            assert!(sqrt_decimal(dec(negative)).is_err());
        }
    }

    #[test]
    fn test_arithmetic_int256() {
        // (int256 dividend, float8 divisor, expected float8 quotient)
        let cases = [
            (0, 1, "0"),
            (0, -1, "0"),
            (1, 1, "1"),
            (1, -1, "-1"),
            (1, 2, "0.5"),
            (1, -2, "-0.5"),
            (9007199254740991i64, 2, "4503599627370495.5"),
        ];

        for (dividend, divisor, expected) in cases {
            let dividend = Int256::from(dividend);
            let quotient = general_div::<Int256Ref<'_>, F64, F64>(
                dividend.as_scalar_ref(),
                F64::from(divisor),
            )
            .unwrap();
            assert_eq!(quotient, F64::from_str(expected).unwrap());
        }
    }
}