1use std::fmt::Debug;
16
17use chrono::{Duration, NaiveDateTime};
18use num_traits::{CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Zero};
19use risingwave_common::types::{
20 CheckedAdd, Date, Decimal, F64, FloatExt, Interval, IsNegative, Time, Timestamp,
21};
22use risingwave_expr::{ExprError, Result, function};
23use rust_decimal::MathematicalOps;
24
25#[function("add(*int, *int) -> auto")]
26#[function("add(decimal, decimal) -> auto")]
27#[function("add(*float, *float) -> auto")]
28#[function("add(interval, interval) -> interval")]
29#[function("add(int256, int256) -> int256")]
30pub fn general_add<T1, T2, T3>(l: T1, r: T2) -> Result<T3>
31where
32 T1: Into<T3> + Debug,
33 T2: Into<T3> + Debug,
34 T3: CheckedAdd<Output = T3>,
35{
36 general_atm(l, r, |a, b| {
37 a.checked_add(b).ok_or(ExprError::NumericOutOfRange)
38 })
39}
40
41#[function("subtract(*int, *int) -> auto")]
42#[function("subtract(decimal, decimal) -> auto")]
43#[function("subtract(*float, *float) -> auto")]
44#[function("subtract(interval, interval) -> interval")]
45#[function("subtract(int256, int256) -> int256")]
46pub fn general_sub<T1, T2, T3>(l: T1, r: T2) -> Result<T3>
47where
48 T1: Into<T3> + Debug,
49 T2: Into<T3> + Debug,
50 T3: CheckedSub,
51{
52 general_atm(l, r, |a, b| {
53 a.checked_sub(&b).ok_or(ExprError::NumericOutOfRange)
54 })
55}
56
57#[function("multiply(*int, *int) -> auto")]
58#[function("multiply(decimal, decimal) -> auto")]
59#[function("multiply(*float, *float) -> auto")]
60#[function("multiply(int256, int256) -> int256")]
61pub fn general_mul<T1, T2, T3>(l: T1, r: T2) -> Result<T3>
62where
63 T1: Into<T3> + Debug,
64 T2: Into<T3> + Debug,
65 T3: CheckedMul,
66{
67 general_atm(l, r, |a, b| {
68 a.checked_mul(&b).ok_or(ExprError::NumericOutOfRange)
69 })
70}
71
72#[function("divide(*int, *int) -> auto")]
73#[function("divide(decimal, decimal) -> auto")]
74#[function("divide(*float, *float) -> auto")]
75#[function("divide(int256, int256) -> int256")]
76#[function("divide(int256, float8) -> float8")]
77#[function("divide(int256, *int) -> int256")]
78pub fn general_div<T1, T2, T3>(l: T1, r: T2) -> Result<T3>
79where
80 T1: Into<T3> + Debug,
81 T2: Into<T3> + Debug,
82 T3: CheckedDiv + Zero,
83{
84 general_atm(l, r, |a, b| {
85 a.checked_div(&b).ok_or_else(|| {
86 if b.is_zero() {
87 ExprError::DivisionByZero
88 } else {
89 ExprError::NumericOutOfRange
90 }
91 })
92 })
93}
94
95#[function("modulus(*int, *int) -> auto")]
96#[function("modulus(decimal, decimal) -> auto")]
97#[function("modulus(int256, int256) -> int256")]
98pub fn general_mod<T1, T2, T3>(l: T1, r: T2) -> Result<T3>
99where
100 T1: Into<T3> + Debug,
101 T2: Into<T3> + Debug,
102 T3: CheckedRem,
103{
104 general_atm(l, r, |a, b| {
105 a.checked_rem(&b).ok_or(ExprError::NumericOutOfRange)
106 })
107}
108
109#[function("neg(*int) -> auto")]
110#[function("neg(*float) -> auto")]
111#[function("neg(decimal) -> decimal")]
112pub fn general_neg<T1: CheckedNeg>(expr: T1) -> Result<T1> {
113 expr.checked_neg().ok_or(ExprError::NumericOutOfRange)
114}
115
116#[function("neg(int256) -> int256")]
117pub fn int256_neg<TRef, T>(expr: TRef) -> Result<T>
118where
119 TRef: Into<T> + Debug,
120 T: CheckedNeg + Debug,
121{
122 expr.into()
123 .checked_neg()
124 .ok_or(ExprError::NumericOutOfRange)
125}
126
127#[function("abs(*int) -> auto")]
128pub fn general_abs<T1: IsNegative + CheckedNeg>(expr: T1) -> Result<T1> {
129 if expr.is_negative() {
130 general_neg(expr)
131 } else {
132 Ok(expr)
133 }
134}
135
/// Returns the absolute value of a floating-point operand by delegating to
/// the operand's `abs` method. This function cannot fail.
#[function("abs(*float) -> auto")]
pub fn float_abs<F: num_traits::Float, T1: FloatExt<F>>(expr: T1) -> T1 {
    expr.abs()
}
140
141#[function("abs(int256) -> int256")]
142pub fn int256_abs<TRef, T>(expr: TRef) -> Result<T>
143where
144 TRef: Into<T> + Debug,
145 T: IsNegative + CheckedNeg + Debug,
146{
147 let expr = expr.into();
148 if expr.is_negative() {
149 int256_neg(expr)
150 } else {
151 Ok(expr)
152 }
153}
154
/// Returns the absolute value of a decimal operand by delegating to
/// `Decimal::abs`. This function cannot fail.
#[function("abs(decimal) -> decimal")]
pub fn decimal_abs(decimal: Decimal) -> Decimal {
    Decimal::abs(&decimal)
}
159
160fn err_pow_zero_negative() -> ExprError {
161 ExprError::InvalidParam {
162 name: "rhs",
163 reason: "zero raised to a negative power is undefined".into(),
164 }
165}
166fn err_pow_negative_fract() -> ExprError {
167 ExprError::InvalidParam {
168 name: "rhs",
169 reason: "a negative number raised to a non-integer power yields a complex result".into(),
170 }
171}
172
173#[function("pow(float8, float8) -> float8")]
174pub fn pow_f64(l: F64, r: F64) -> Result<F64> {
175 if l.is_zero() && r.0 < 0.0 {
176 return Err(err_pow_zero_negative());
177 }
178 if l.0 < 0.0 && (r.is_finite() && !r.fract().is_zero()) {
179 return Err(err_pow_negative_fract());
180 }
181 let res = l.powf(r);
182 if res.is_infinite() && l.is_finite() && r.is_finite() {
183 return Err(ExprError::NumericOverflow);
184 }
185 if res.is_zero() && l.is_finite() && r.is_finite() && !l.is_zero() {
186 return Err(ExprError::NumericUnderflow);
187 }
188
189 Ok(res)
190}
191
192#[function("pow(decimal, decimal) -> decimal")]
193pub fn pow_decimal(l: Decimal, r: Decimal) -> Result<Decimal> {
194 use risingwave_common::types::DecimalPowError as PowError;
195
196 l.checked_powd(&r).map_err(|e| match e {
197 PowError::ZeroNegative => err_pow_zero_negative(),
198 PowError::NegativeFract => err_pow_negative_fract(),
199 PowError::Overflow => ExprError::NumericOverflow,
200 })
201}
202
203#[inline(always)]
204fn general_atm<T1, T2, T3, F>(l: T1, r: T2, atm: F) -> Result<T3>
205where
206 T1: Into<T3> + Debug,
207 T2: Into<T3> + Debug,
208 F: FnOnce(T3, T3) -> Result<T3>,
209{
210 atm(l.into(), r.into())
211}
212
213#[function("subtract(timestamp, timestamp) -> interval")]
214pub fn timestamp_timestamp_sub(l: Timestamp, r: Timestamp) -> Result<Interval> {
215 let tmp = l.0 - r.0; let days = tmp.num_days();
217 let usecs = (tmp - Duration::days(tmp.num_days()))
218 .num_microseconds()
219 .ok_or_else(|| ExprError::NumericOutOfRange)?;
220 Ok(Interval::from_month_day_usec(0, days as i32, usecs))
221}
222
223#[function("subtract(date, date) -> int4")]
224pub fn date_date_sub(l: Date, r: Date) -> Result<i32> {
225 Ok((l.0 - r.0).num_days() as i32) }
227
228#[function("add(interval, timestamp) -> timestamp")]
229pub fn interval_timestamp_add(l: Interval, r: Timestamp) -> Result<Timestamp> {
230 r.checked_add(l).ok_or(ExprError::NumericOutOfRange)
231}
232
233#[function("add(interval, date) -> timestamp")]
234pub fn interval_date_add(l: Interval, r: Date) -> Result<Timestamp> {
235 interval_timestamp_add(l, r.into())
236}
237
/// Adds an interval to a time of day.
///
/// Commuted form of `time_interval_add`: the operands are swapped and the
/// call is forwarded, so results and errors are identical.
#[function("add(interval, time) -> time")]
pub fn interval_time_add(l: Interval, r: Time) -> Result<Time> {
    time_interval_add(r, l)
}
242
/// Adds an interval to a date, producing a timestamp.
///
/// Commuted form of `interval_date_add`: the operands are swapped and the
/// call is forwarded, so results and errors are identical.
#[function("add(date, interval) -> timestamp")]
pub fn date_interval_add(l: Date, r: Interval) -> Result<Timestamp> {
    interval_date_add(r, l)
}
247
248#[function("subtract(date, interval) -> timestamp")]
249pub fn date_interval_sub(l: Date, r: Interval) -> Result<Timestamp> {
250 interval_date_add(r.checked_neg().ok_or(ExprError::NumericOutOfRange)?, l)
253}
254
255#[function("add(date, int4) -> date")]
256pub fn date_int_add(l: Date, r: i32) -> Result<Date> {
257 let date = l.0;
258 let date_wrapper = date
259 .checked_add_signed(chrono::Duration::days(r as i64))
260 .map(Date::new);
261
262 date_wrapper.ok_or(ExprError::NumericOutOfRange)
263}
264
/// Adds `l` days to a date.
///
/// Commuted form of `date_int_add`: the operands are swapped and the call is
/// forwarded, so results and errors are identical.
#[function("add(int4, date) -> date")]
pub fn int_date_add(l: i32, r: Date) -> Result<Date> {
    date_int_add(r, l)
}
269
270#[function("subtract(date, int4) -> date")]
271pub fn date_int_sub(l: Date, r: i32) -> Result<Date> {
272 let date = l.0;
273 let date_wrapper = date
274 .checked_sub_signed(chrono::Duration::days(r as i64))
275 .map(Date::new);
276
277 date_wrapper.ok_or(ExprError::NumericOutOfRange)
278}
279
/// Adds an interval to a timestamp.
///
/// Commuted form of `interval_timestamp_add`: the operands are swapped and
/// the call is forwarded, so results and errors are identical.
#[function("add(timestamp, interval) -> timestamp")]
pub fn timestamp_interval_add(l: Timestamp, r: Interval) -> Result<Timestamp> {
    interval_timestamp_add(r, l)
}
284
285#[function("subtract(timestamp, interval) -> timestamp")]
286pub fn timestamp_interval_sub(l: Timestamp, r: Interval) -> Result<Timestamp> {
287 interval_timestamp_add(r.checked_neg().ok_or(ExprError::NumericOutOfRange)?, l)
288}
289
290#[function("multiply(interval, *int) -> interval")]
291pub fn interval_int_mul(l: Interval, r: impl TryInto<i32> + Debug) -> Result<Interval> {
292 l.checked_mul_int(r).ok_or(ExprError::NumericOutOfRange)
293}
294
/// Multiplies an interval by an integer factor.
///
/// Commuted form of `interval_int_mul`: the operands are swapped and the call
/// is forwarded, so results and errors are identical.
#[function("multiply(*int, interval) -> interval")]
pub fn int_interval_mul(l: impl TryInto<i32> + Debug, r: Interval) -> Result<Interval> {
    interval_int_mul(r, l)
}
299
300#[function("add(date, time) -> timestamp")]
301pub fn date_time_add(l: Date, r: Time) -> Result<Timestamp> {
302 Ok(Timestamp::new(NaiveDateTime::new(l.0, r.0)))
303}
304
/// Combines a time of day and a date into a timestamp.
///
/// Commuted form of `date_time_add`: the operands are swapped and the call is
/// forwarded.
#[function("add(time, date) -> timestamp")]
pub fn time_date_add(l: Time, r: Date) -> Result<Timestamp> {
    date_time_add(r, l)
}
309
310#[function("subtract(time, time) -> interval")]
311pub fn time_time_sub(l: Time, r: Time) -> Result<Interval> {
312 let tmp = l.0 - r.0; let usecs = tmp
314 .num_microseconds()
315 .ok_or_else(|| ExprError::NumericOutOfRange)?;
316 Ok(Interval::from_month_day_usec(0, 0, usecs))
317}
318
319#[function("subtract(time, interval) -> time")]
320pub fn time_interval_sub(l: Time, r: Interval) -> Result<Time> {
321 let time = l.0;
322 let (new_time, ignored) = time.overflowing_sub_signed(Duration::microseconds(r.usecs()));
323 if ignored == 0 {
324 Ok(Time::new(new_time))
325 } else {
326 Err(ExprError::NumericOutOfRange)
327 }
328}
329
330#[function("add(time, interval) -> time")]
331pub fn time_interval_add(l: Time, r: Interval) -> Result<Time> {
332 let time = l.0;
333 let (new_time, ignored) = time.overflowing_add_signed(Duration::microseconds(r.usecs()));
334 if ignored == 0 {
335 Ok(Time::new(new_time))
336 } else {
337 Err(ExprError::NumericOutOfRange)
338 }
339}
340
341#[function("divide(interval, *int) -> interval")]
342#[function("divide(interval, decimal) -> interval")]
343#[function("divide(interval, *float) -> interval")]
344pub fn interval_float_div<T2>(l: Interval, r: T2) -> Result<Interval>
345where
346 T2: TryInto<F64> + Debug,
347{
348 l.div_float(r).ok_or(ExprError::NumericOutOfRange)
349}
350
351#[function("multiply(interval, float4) -> interval")]
352#[function("multiply(interval, float8) -> interval")]
353#[function("multiply(interval, decimal) -> interval")]
354pub fn interval_float_mul<T2>(l: Interval, r: T2) -> Result<Interval>
355where
356 T2: TryInto<F64> + Debug,
357{
358 l.mul_float(r).ok_or(ExprError::NumericOutOfRange)
359}
360
361#[function("multiply(float4, interval) -> interval")]
362#[function("multiply(float8, interval) -> interval")]
363#[function("multiply(decimal, interval) -> interval")]
364pub fn float_interval_mul<T1>(l: T1, r: Interval) -> Result<Interval>
365where
366 T1: TryInto<F64> + Debug,
367{
368 r.mul_float(l).ok_or(ExprError::NumericOutOfRange)
369}
370
371#[function("sqrt(float8) -> float8")]
372pub fn sqrt_f64(expr: F64) -> Result<F64> {
373 if expr < F64::from(0.0) {
374 return Err(ExprError::InvalidParam {
375 name: "sqrt input",
376 reason: "input cannot be negative value".into(),
377 });
378 }
379 match expr.is_nan() || expr == f64::INFINITY || expr == -0.0 {
381 true => Ok(expr),
382 false => Ok(expr.sqrt()),
383 }
384}
385
386#[function("sqrt(decimal) -> decimal")]
387pub fn sqrt_decimal(expr: Decimal) -> Result<Decimal> {
388 match expr {
389 Decimal::NaN | Decimal::PositiveInf => Ok(expr),
390 Decimal::Normalized(value) => match value.sqrt() {
391 Some(res) => Ok(Decimal::from(res)),
392 None => Err(ExprError::InvalidParam {
393 name: "sqrt input",
394 reason: "input cannot be negative value".into(),
395 }),
396 },
397 Decimal::NegativeInf => Err(ExprError::InvalidParam {
398 name: "sqrt input",
399 reason: "input cannot be negative value".into(),
400 }),
401 }
402}
403
/// Computes the cube root of a `float8` value by delegating to the operand's
/// `cbrt` method. This function cannot fail.
#[function("cbrt(float8) -> float8")]
pub fn cbrt_f64(expr: F64) -> F64 {
    expr.cbrt()
}
408
409#[function("sign(float8) -> float8")]
410pub fn sign_f64(input: F64) -> F64 {
411 match input.0.partial_cmp(&0.) {
412 Some(std::cmp::Ordering::Less) => (-1).into(),
413 Some(std::cmp::Ordering::Equal) => 0.into(),
414 Some(std::cmp::Ordering::Greater) => 1.into(),
415 None => 0.into(),
416 }
417}
418
/// Returns the sign of a decimal value by delegating to `Decimal::sign`.
#[function("sign(decimal) -> decimal")]
pub fn sign_dec(input: Decimal) -> Decimal {
    input.sign()
}
423
/// Returns the scale of a decimal value as reported by `Decimal::scale`.
///
/// The result is `None` whenever `Decimal::scale` yields `None`.
// NOTE(review): presumably `None` covers the non-finite values (NaN, ±inf) — confirm.
#[function("scale(decimal) -> int4")]
pub fn decimal_scale(d: Decimal) -> Option<i32> {
    d.scale()
}
428
429#[function("min_scale(decimal) -> int4")]
430pub fn decimal_min_scale(d: Decimal) -> Option<i32> {
431 d.normalize().scale()
432}
433
/// Returns the decimal value normalized by `Decimal::normalize`.
// NOTE(review): presumably normalization strips trailing fractional zeros — confirm.
#[function("trim_scale(decimal) -> decimal")]
pub fn decimal_trim_scale(d: Decimal) -> Decimal {
    d.normalize()
}
438
/// Unit tests for the arithmetic expression functions in this module.
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use num_traits::Float;
    use risingwave_common::types::test_utils::IntervalTestExt;
    use risingwave_common::types::{F32, Int256, Int256Ref, Scalar};

    use super::*;

    /// Smoke test: adding an `i32` to a `Decimal` yields a `Decimal`.
    #[test]
    fn test() {
        assert_eq!(
            general_add::<_, _, Decimal>(Decimal::from_str("1").unwrap(), 1i32).unwrap(),
            Decimal::from_str("2").unwrap()
        );
    }

    /// Exercises the generic arithmetic functions across operand type
    /// combinations, plus date/interval arithmetic and both `sqrt` variants.
    #[test]
    fn test_arithmetic() {
        // Decimal combined with i32, result type Decimal.
        assert_eq!(
            general_add::<Decimal, i32, Decimal>(dec("1.0"), 1).unwrap(),
            dec("2.0")
        );
        assert_eq!(
            general_sub::<Decimal, i32, Decimal>(dec("1.0"), 2).unwrap(),
            dec("-1.0")
        );
        assert_eq!(
            general_mul::<Decimal, i32, Decimal>(dec("1.0"), 2).unwrap(),
            dec("2.0")
        );
        assert_eq!(
            general_div::<Decimal, i32, Decimal>(dec("2.0"), 2).unwrap(),
            dec("1.0")
        );
        assert_eq!(
            general_mod::<Decimal, i32, Decimal>(dec("2.0"), 2).unwrap(),
            dec("0")
        );
        assert_eq!(general_neg::<Decimal>(dec("1.0")).unwrap(), dec("-1.0"));
        // i16 combined with i32, result type i32.
        assert_eq!(general_add::<i16, i32, i32>(1i16, 1i32).unwrap(), 2i32);
        assert_eq!(general_sub::<i16, i32, i32>(1i16, 1i32).unwrap(), 0i32);
        assert_eq!(general_mul::<i16, i32, i32>(1i16, 1i32).unwrap(), 1i32);
        assert_eq!(general_div::<i16, i32, i32>(1i16, 1i32).unwrap(), 1i32);
        assert_eq!(general_mod::<i16, i32, i32>(1i16, 1i32).unwrap(), 0i32);
        assert_eq!(general_neg::<i16>(1i16).unwrap(), -1i16);

        // i32 combined with F32, result type F64.
        assert!(
            general_add::<i32, F32, F64>(-1i32, 1f32.into())
                .unwrap()
                .is_zero()
        );
        assert!(
            general_sub::<i32, F32, F64>(1i32, 1f32.into())
                .unwrap()
                .is_zero()
        );
        assert!(
            general_mul::<i32, F32, F64>(0i32, 1f32.into())
                .unwrap()
                .is_zero()
        );
        assert!(
            general_div::<i32, F32, F64>(0i32, 1f32.into())
                .unwrap()
                .is_zero()
        );
        assert_eq!(general_neg::<F32>(1f32.into()).unwrap(), F32::from(-1f32));
        // Date +/- a 12-month interval, in both operand orders for addition.
        assert_eq!(
            date_interval_add(Date::from_ymd_uncheck(1994, 1, 1), Interval::from_month(12))
                .unwrap(),
            Timestamp::new(
                NaiveDateTime::parse_from_str("1995-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap()
            )
        );
        assert_eq!(
            interval_date_add(Interval::from_month(12), Date::from_ymd_uncheck(1994, 1, 1))
                .unwrap(),
            Timestamp::new(
                NaiveDateTime::parse_from_str("1995-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap()
            )
        );
        assert_eq!(
            date_interval_sub(Date::from_ymd_uncheck(1994, 1, 1), Interval::from_month(12))
                .unwrap(),
            Timestamp::new(
                NaiveDateTime::parse_from_str("1993-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap()
            )
        );
        // sqrt on float8: regular values, then a negative input.
        assert_eq!(sqrt_f64(F64::from(25.00)).unwrap(), F64::from(5.0));
        assert_eq!(
            sqrt_f64(F64::from(107)).unwrap(),
            F64::from(10.344080432788601)
        );
        assert_eq!(
            sqrt_f64(F64::from(12.234567)).unwrap(),
            F64::from(3.4977945908815173)
        );
        assert!(sqrt_f64(F64::from(-25.00)).is_err());
        assert_eq!(sqrt_f64(F64::from(f64::NAN)).unwrap(), F64::from(f64::NAN));
        // sqrt edge cases: negative zero and the infinities.
        assert_eq!(
            sqrt_f64(F64::from(f64::neg_zero())).unwrap(),
            F64::from(f64::neg_zero())
        );
        assert_eq!(
            sqrt_f64(F64::from(f64::INFINITY)).unwrap(),
            F64::from(f64::INFINITY)
        );
        assert!(sqrt_f64(F64::from(f64::NEG_INFINITY)).is_err());
        // sqrt on decimal: same cases as the float8 variant.
        assert_eq!(sqrt_decimal(dec("25.0")).unwrap(), dec("5.0"));
        assert_eq!(
            sqrt_decimal(dec("107")).unwrap(),
            dec("10.344080432788600469738599442")
        );
        assert_eq!(
            sqrt_decimal(dec("12.234567")).unwrap(),
            dec("3.4977945908815171589625746860")
        );
        assert!(sqrt_decimal(dec("-25.0")).is_err());
        assert_eq!(sqrt_decimal(dec("nan")).unwrap(), dec("nan"));
        assert_eq!(sqrt_decimal(dec("inf")).unwrap(), dec("inf"));
        assert_eq!(sqrt_decimal(dec("-0")).unwrap(), dec("-0"));
        assert!(sqrt_decimal(dec("-inf")).is_err());
    }

    /// Checks `int256 / float8` division: each tuple is
    /// `(dividend, divisor, expected quotient as a string)`.
    #[test]
    fn test_arithmetic_int256() {
        let tuples = vec![
            (0, 1, "0"),
            (0, -1, "0"),
            (1, 1, "1"),
            (1, -1, "-1"),
            (1, 2, "0.5"),
            (1, -2, "-0.5"),
            // 2^53 - 1 as the dividend.
            (9007199254740991i64, 2, "4503599627370495.5"),
        ];

        for (i, j, k) in tuples {
            let lhs = Int256::from(i);
            let rhs = F64::from(j);
            let res = F64::from_str(k).unwrap();
            assert_eq!(
                general_div::<Int256Ref<'_>, F64, F64>(lhs.as_scalar_ref(), rhs).unwrap(),
                res,
            );
        }
    }

    /// Parses a decimal literal, panicking on malformed input (test helper).
    fn dec(s: &str) -> Decimal {
        Decimal::from_str(s).unwrap()
    }
}