1use std::fmt::Debug;
16use std::io::{Cursor, Read, Write};
17use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
18
19use byteorder::{BigEndian, ReadBytesExt};
20use bytes::{BufMut, BytesMut};
21use num_traits::{
22 CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Num, One, Zero,
23};
24use postgres_types::{FromSql, IsNull, ToSql, Type, accepts, to_sql_checked};
25use risingwave_common_estimate_size::ZeroHeapSize;
26use rust_decimal::prelude::FromStr;
27use rust_decimal::{Decimal as RustDecimal, Error, MathematicalOps as _, RoundingStrategy};
28
29use super::DataType;
30use super::to_text::ToText;
31use crate::array::ArrayResult;
32use crate::types::Decimal::Normalized;
33use crate::types::ordered_float::OrderedFloat;
34
/// A SQL `NUMERIC` value: a finite [`RustDecimal`] extended with the special values
/// `-Infinity`, `Infinity` and `NaN`.
///
/// The variant order is significant: the derived `PartialOrd`/`Ord` compare variants in
/// declaration order, so `-Infinity < every finite value < Infinity < NaN`.
#[derive(Debug, Copy, parse_display::Display, Clone, PartialEq, Hash, Eq, Ord, PartialOrd)]
pub enum Decimal {
    /// Negative infinity; displayed as `-Infinity`.
    #[display("-Infinity")]
    NegativeInf,
    /// A finite value, displayed with the inner `RustDecimal`'s `Display`.
    #[display("{0}")]
    Normalized(RustDecimal),
    /// Positive infinity; displayed as `Infinity`.
    #[display("Infinity")]
    PositiveInf,
    /// Not-a-number; displayed as `NaN`. Declared last, so it sorts after every other value.
    #[display("NaN")]
    NaN,
}
46
/// `Decimal` is a `Copy` enum with no heap-owning fields, so size estimation counts only its
/// inline size.
impl ZeroHeapSize for Decimal {}
48
49impl ToText for Decimal {
50 fn write<W: std::fmt::Write>(&self, f: &mut W) -> std::fmt::Result {
51 write!(f, "{self}")
52 }
53
54 fn write_with_type<W: std::fmt::Write>(&self, ty: &DataType, f: &mut W) -> std::fmt::Result {
55 match ty {
56 DataType::Decimal => self.write(f),
57 _ => unreachable!(),
58 }
59 }
60}
61
62impl Decimal {
63 pub fn to_protobuf(self, output: &mut impl Write) -> ArrayResult<usize> {
65 let buf = self.unordered_serialize();
66 output.write_all(&buf)?;
67 Ok(buf.len())
68 }
69
70 pub fn from_protobuf(input: &mut impl Read) -> ArrayResult<Self> {
72 let mut buf = [0u8; 16];
73 input.read_exact(&mut buf)?;
74 Ok(Self::unordered_deserialize(buf))
75 }
76
77 pub fn from_scientific(value: &str) -> Option<Self> {
78 let decimal = RustDecimal::from_scientific(value).ok()?;
79 Some(Normalized(decimal))
80 }
81
82 pub fn from_str_radix(s: &str, radix: u32) -> rust_decimal::Result<Self> {
83 match s.to_ascii_lowercase().as_str() {
84 "nan" => Ok(Decimal::NaN),
85 "inf" | "+inf" | "infinity" | "+infinity" => Ok(Decimal::PositiveInf),
86 "-inf" | "-infinity" => Ok(Decimal::NegativeInf),
87 s => RustDecimal::from_str_radix(s, radix).map(Decimal::Normalized),
88 }
89 }
90}
91
92impl ToSql for Decimal {
93 accepts!(NUMERIC);
94
95 to_sql_checked!();
96
97 fn to_sql(
98 &self,
99 ty: &Type,
100 out: &mut BytesMut,
101 ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
102 where
103 Self: Sized,
104 {
105 match self {
106 Decimal::Normalized(d) => {
107 return d.to_sql(ty, out);
108 }
109 Decimal::NaN => {
110 out.reserve(8);
111 out.put_u16(0);
112 out.put_i16(0);
113 out.put_u16(0xC000);
114 out.put_i16(0);
115 }
116 Decimal::PositiveInf => {
117 out.reserve(8);
118 out.put_u16(0);
119 out.put_i16(0);
120 out.put_u16(0xD000);
121 out.put_i16(0);
122 }
123 Decimal::NegativeInf => {
124 out.reserve(8);
125 out.put_u16(0);
126 out.put_i16(0);
127 out.put_u16(0xF000);
128 out.put_i16(0);
129 }
130 }
131 Ok(IsNull::No)
132 }
133}
134
135impl<'a> FromSql<'a> for Decimal {
136 fn from_sql(
137 ty: &Type,
138 raw: &'a [u8],
139 ) -> Result<Self, Box<dyn std::error::Error + 'static + Sync + Send>> {
140 let mut rdr = Cursor::new(raw);
141 let _n_digits = rdr.read_u16::<BigEndian>()?;
142 let _weight = rdr.read_i16::<BigEndian>()?;
143 let sign = rdr.read_u16::<BigEndian>()?;
144 match sign {
145 0xC000 => Ok(Self::NaN),
146 0xD000 => Ok(Self::PositiveInf),
147 0xF000 => Ok(Self::NegativeInf),
148 _ => RustDecimal::from_sql(ty, raw).map(Self::Normalized),
149 }
150 }
151
152 fn accepts(ty: &Type) -> bool {
153 matches!(*ty, Type::NUMERIC)
154 }
155}
156
// Generates `From<$T> for Decimal` (always finite) and `TryFrom<Decimal> for $T`. The reverse
// direction rounds to 0 decimal places with ties away from zero before narrowing, and fails
// for `NaN`/±Infinity or values outside the integer's range.
macro_rules! impl_convert_int {
    ($T:ty) => {
        impl core::convert::From<$T> for Decimal {
            #[inline]
            fn from(t: $T) -> Self {
                let inner: RustDecimal = t.into();
                Self::Normalized(inner)
            }
        }

        impl core::convert::TryFrom<Decimal> for $T {
            type Error = Error;

            #[inline]
            fn try_from(d: Decimal) -> Result<Self, Self::Error> {
                if let Decimal::Normalized(rounded) = d.round_dp_ties_away(0) {
                    rounded.try_into()
                } else {
                    Err(Error::ConversionTo(std::any::type_name::<$T>().into()))
                }
            }
        }
    };
}
179
// Generates fallible conversions between `Decimal` and a float type `$T` (plus its
// `OrderedFloat` wrapper). NaN and ±Infinity map to the matching special variants in both
// directions; finite values go through `rust_decimal`'s own float conversions.
macro_rules! impl_convert_float {
    ($T:ty) => {
        impl core::convert::TryFrom<$T> for Decimal {
            type Error = Error;

            fn try_from(num: $T) -> Result<Self, Self::Error> {
                if num.is_nan() {
                    Ok(Decimal::NaN)
                } else if num.is_infinite() {
                    if num.is_sign_positive() {
                        Ok(Decimal::PositiveInf)
                    } else {
                        Ok(Decimal::NegativeInf)
                    }
                } else {
                    num.try_into().map(Decimal::Normalized)
                }
            }
        }
        impl core::convert::TryFrom<OrderedFloat<$T>> for Decimal {
            type Error = Error;

            fn try_from(value: OrderedFloat<$T>) -> Result<Self, Self::Error> {
                Decimal::try_from(value.0)
            }
        }

        impl core::convert::TryFrom<Decimal> for $T {
            type Error = Error;

            fn try_from(d: Decimal) -> Result<Self, Self::Error> {
                let special = match d {
                    Decimal::Normalized(finite) => return finite.try_into(),
                    Decimal::NaN => <$T>::NAN,
                    Decimal::PositiveInf => <$T>::INFINITY,
                    Decimal::NegativeInf => <$T>::NEG_INFINITY,
                };
                Ok(special)
            }
        }
        impl core::convert::TryFrom<Decimal> for OrderedFloat<$T> {
            type Error = Error;

            fn try_from(d: Decimal) -> Result<Self, Self::Error> {
                <$T>::try_from(d).map(Self)
            }
        }
    };
}
223
// Implements a `num_traits` checked-operation trait for `Decimal`. When both operands are
// finite, the corresponding `rust_decimal` checked operation decides (its `None` propagates);
// if either operand is special, the unchecked operator `$op` is applied and wrapped in `Some`.
macro_rules! checked_proxy {
    ($trait:ty, $func:ident, $op: tt) => {
        impl $trait for Decimal {
            fn $func(&self, other: &Self) -> Option<Self> {
                if let (Self::Normalized(lhs), Self::Normalized(rhs)) = (self, other) {
                    return lhs.$func(rhs).map(Decimal::Normalized);
                }
                Some(*self $op *other)
            }
        }
    }
}
238
// Fallible float <-> `Decimal` conversions; NaN and the infinities map to the special variants.
impl_convert_float!(f32);
impl_convert_float!(f64);

// `From<int>` for `Decimal`, plus `TryFrom<Decimal>` for each integer type (rounding to an
// integer, ties away from zero, before narrowing).
impl_convert_int!(isize);
impl_convert_int!(i8);
impl_convert_int!(i16);
impl_convert_int!(i32);
impl_convert_int!(i64);
impl_convert_int!(usize);
impl_convert_int!(u8);
impl_convert_int!(u16);
impl_convert_int!(u32);
impl_convert_int!(u64);

// `num_traits` checked arithmetic: `rust_decimal`'s checked op for two finite operands,
// otherwise the unchecked operator below.
checked_proxy!(CheckedRem, checked_rem, %);
checked_proxy!(CheckedSub, checked_sub, -);
checked_proxy!(CheckedAdd, checked_add, +);
checked_proxy!(CheckedDiv, checked_div, /);
checked_proxy!(CheckedMul, checked_mul, *);
258
259impl Add for Decimal {
260 type Output = Self;
261
262 fn add(self, other: Self) -> Self {
263 match (self, other) {
264 (Self::Normalized(lhs), Self::Normalized(rhs)) => Self::Normalized(lhs + rhs),
265 (Self::NaN, _) => Self::NaN,
266 (_, Self::NaN) => Self::NaN,
267 (Self::PositiveInf, Self::NegativeInf) => Self::NaN,
268 (Self::NegativeInf, Self::PositiveInf) => Self::NaN,
269 (Self::PositiveInf, _) => Self::PositiveInf,
270 (_, Self::PositiveInf) => Self::PositiveInf,
271 (Self::NegativeInf, _) => Self::NegativeInf,
272 (_, Self::NegativeInf) => Self::NegativeInf,
273 }
274 }
275}
276
277impl Neg for Decimal {
278 type Output = Self;
279
280 fn neg(self) -> Self {
281 match self {
282 Self::Normalized(d) => Self::Normalized(-d),
283 Self::NaN => Self::NaN,
284 Self::PositiveInf => Self::NegativeInf,
285 Self::NegativeInf => Self::PositiveInf,
286 }
287 }
288}
289
290impl CheckedNeg for Decimal {
291 fn checked_neg(&self) -> Option<Self> {
292 match self {
293 Self::Normalized(d) => Some(Self::Normalized(-d)),
294 Self::NaN => Some(Self::NaN),
295 Self::PositiveInf => Some(Self::NegativeInf),
296 Self::NegativeInf => Some(Self::PositiveInf),
297 }
298 }
299}
300
301impl Rem for Decimal {
302 type Output = Self;
303
304 fn rem(self, other: Self) -> Self {
305 match (self, other) {
306 (Self::Normalized(lhs), Self::Normalized(rhs)) if !rhs.is_zero() => {
307 Self::Normalized(lhs % rhs)
308 }
309 (Self::Normalized(_), Self::Normalized(_)) => Self::NaN,
310 (Self::Normalized(lhs), Self::PositiveInf)
311 if lhs.is_sign_positive() || lhs.is_zero() =>
312 {
313 Self::Normalized(lhs)
314 }
315 (Self::Normalized(d), Self::PositiveInf) => Self::Normalized(d),
316 (Self::Normalized(lhs), Self::NegativeInf)
317 if lhs.is_sign_negative() || lhs.is_zero() =>
318 {
319 Self::Normalized(lhs)
320 }
321 (Self::Normalized(d), Self::NegativeInf) => Self::Normalized(d),
322 _ => Self::NaN,
323 }
324 }
325}
326
327impl Div for Decimal {
328 type Output = Self;
329
330 fn div(self, other: Self) -> Self {
331 match (self, other) {
332 (Self::NaN, _) => Self::NaN,
334 (_, Self::NaN) => Self::NaN,
335 (lhs, Self::Normalized(rhs)) if rhs.is_zero() => match lhs {
337 Self::Normalized(lhs) => {
338 if lhs.is_sign_positive() && !lhs.is_zero() {
339 Self::PositiveInf
340 } else if lhs.is_sign_negative() && !lhs.is_zero() {
341 Self::NegativeInf
342 } else {
343 Self::NaN
344 }
345 }
346 Self::PositiveInf => Self::PositiveInf,
347 Self::NegativeInf => Self::NegativeInf,
348 _ => unreachable!(),
349 },
350 (Self::Normalized(_), Self::PositiveInf) => Self::Normalized(RustDecimal::from(0)),
352 (_, Self::PositiveInf) => Self::NaN,
353 (Self::Normalized(_), Self::NegativeInf) => Self::Normalized(RustDecimal::from(0)),
354 (_, Self::NegativeInf) => Self::NaN,
355 (Self::PositiveInf, Self::Normalized(d)) if d.is_sign_positive() => Self::PositiveInf,
357 (Self::PositiveInf, Self::Normalized(d)) if d.is_sign_negative() => Self::NegativeInf,
358 (Self::NegativeInf, Self::Normalized(d)) if d.is_sign_positive() => Self::NegativeInf,
359 (Self::NegativeInf, Self::Normalized(d)) if d.is_sign_negative() => Self::PositiveInf,
360 (Self::Normalized(lhs), Self::Normalized(rhs)) => Self::Normalized(lhs / rhs),
362 _ => unreachable!(),
363 }
364 }
365}
366
367impl Mul for Decimal {
368 type Output = Self;
369
370 fn mul(self, other: Self) -> Self {
371 match (self, other) {
372 (Self::Normalized(lhs), Self::Normalized(rhs)) => Self::Normalized(lhs * rhs),
373 (Self::NaN, _) => Self::NaN,
374 (_, Self::NaN) => Self::NaN,
375 (Self::PositiveInf, Self::Normalized(rhs))
376 if !rhs.is_zero() && rhs.is_sign_negative() =>
377 {
378 Self::NegativeInf
379 }
380 (Self::PositiveInf, Self::Normalized(rhs))
381 if !rhs.is_zero() && rhs.is_sign_positive() =>
382 {
383 Self::PositiveInf
384 }
385 (Self::PositiveInf, Self::PositiveInf) => Self::PositiveInf,
386 (Self::PositiveInf, Self::NegativeInf) => Self::NegativeInf,
387 (Self::Normalized(lhs), Self::PositiveInf)
388 if !lhs.is_zero() && lhs.is_sign_negative() =>
389 {
390 Self::NegativeInf
391 }
392 (Self::Normalized(lhs), Self::PositiveInf)
393 if !lhs.is_zero() && lhs.is_sign_positive() =>
394 {
395 Self::PositiveInf
396 }
397 (Self::NegativeInf, Self::PositiveInf) => Self::NegativeInf,
398 (Self::NegativeInf, Self::Normalized(rhs))
399 if !rhs.is_zero() && rhs.is_sign_negative() =>
400 {
401 Self::PositiveInf
402 }
403 (Self::NegativeInf, Self::Normalized(rhs))
404 if !rhs.is_zero() && rhs.is_sign_positive() =>
405 {
406 Self::NegativeInf
407 }
408 (Self::NegativeInf, Self::NegativeInf) => Self::PositiveInf,
409 (Self::Normalized(lhs), Self::NegativeInf)
410 if !lhs.is_zero() && lhs.is_sign_negative() =>
411 {
412 Self::PositiveInf
413 }
414 (Self::Normalized(lhs), Self::NegativeInf)
415 if !lhs.is_zero() && lhs.is_sign_positive() =>
416 {
417 Self::NegativeInf
418 }
419 _ => Self::NaN,
421 }
422 }
423}
424
425impl Sub for Decimal {
426 type Output = Self;
427
428 fn sub(self, other: Self) -> Self {
429 match (self, other) {
430 (Self::Normalized(lhs), Self::Normalized(rhs)) => Self::Normalized(lhs - rhs),
431 (Self::NaN, _) => Self::NaN,
432 (_, Self::NaN) => Self::NaN,
433 (Self::PositiveInf, Self::PositiveInf) => Self::NaN,
434 (Self::NegativeInf, Self::NegativeInf) => Self::NaN,
435 (Self::PositiveInf, _) => Self::PositiveInf,
436 (_, Self::PositiveInf) => Self::NegativeInf,
437 (Self::NegativeInf, _) => Self::NegativeInf,
438 (_, Self::NegativeInf) => Self::PositiveInf,
439 }
440 }
441}
442
443impl Decimal {
444 const MAX_I128_REPR: i128 = 0x0000_0000_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF;
445 pub const MAX_PRECISION: u8 = 28;
446
447 pub fn scale(&self) -> Option<i32> {
448 let Decimal::Normalized(d) = self else {
449 return None;
450 };
451 Some(d.scale() as _)
452 }
453
454 pub fn rescale(&mut self, scale: u32) {
455 if let Normalized(a) = self {
456 a.rescale(scale);
457 }
458 }
459
460 #[must_use]
461 pub fn round_dp_ties_away(&self, dp: u32) -> Self {
462 match self {
463 Self::Normalized(d) => {
464 let new_d = d.round_dp_with_strategy(dp, RoundingStrategy::MidpointAwayFromZero);
465 Self::Normalized(new_d)
466 }
467 d => *d,
468 }
469 }
470
471 #[must_use]
473 pub fn round_left_ties_away(&self, left: u32) -> Option<Self> {
474 let &Self::Normalized(mut d) = self else {
475 return Some(*self);
476 };
477
478 let old_scale = d.scale();
481 let new_scale = old_scale.saturating_add(left);
482 const MANTISSA_UP: i128 = 5 * 10i128.pow(Decimal::MAX_PRECISION as _);
483 let d = match new_scale.cmp(&Self::MAX_PRECISION.add(1).into()) {
484 std::cmp::Ordering::Less => {
486 d.set_scale(new_scale).unwrap();
487 d.round_dp_with_strategy(0, RoundingStrategy::MidpointAwayFromZero)
488 }
489 std::cmp::Ordering::Equal => (d.mantissa() / MANTISSA_UP).signum().into(),
491 std::cmp::Ordering::Greater => 0.into(),
493 };
494
495 match left > Decimal::MAX_PRECISION.into() {
498 true => d.is_zero().then(|| 0.into()),
499 false => d
500 .checked_mul(RustDecimal::from_i128_with_scale(10i128.pow(left), 0))
501 .map(Self::Normalized),
502 }
503 }
504
505 #[must_use]
506 pub fn ceil(&self) -> Self {
507 match self {
508 Self::Normalized(d) => {
509 let mut d = d.ceil();
510 if d.is_zero() {
511 d.set_sign_positive(true);
512 }
513 Self::Normalized(d)
514 }
515 d => *d,
516 }
517 }
518
519 #[must_use]
520 pub fn floor(&self) -> Self {
521 match self {
522 Self::Normalized(d) => Self::Normalized(d.floor()),
523 d => *d,
524 }
525 }
526
527 #[must_use]
528 pub fn trunc(&self) -> Self {
529 match self {
530 Self::Normalized(d) => {
531 let mut d = d.trunc();
532 if d.is_zero() {
533 d.set_sign_positive(true);
534 }
535 Self::Normalized(d)
536 }
537 d => *d,
538 }
539 }
540
541 #[must_use]
542 pub fn round_ties_even(&self) -> Self {
543 match self {
544 Self::Normalized(d) => Self::Normalized(d.round()),
545 d => *d,
546 }
547 }
548
549 pub fn from_i128_with_scale(num: i128, scale: u32) -> Self {
550 Decimal::Normalized(RustDecimal::from_i128_with_scale(num, scale))
551 }
552
553 pub fn truncated_i128_and_scale(mut num: i128, mut scale: u32) -> Option<Self> {
556 if num.abs() > Self::MAX_I128_REPR {
557 let digits = num.abs().ilog10() + 1;
558 let diff_scale = digits.saturating_sub(Self::MAX_PRECISION as u32);
559 if scale < diff_scale {
560 return None;
561 }
562 num /= 10i128.pow(diff_scale);
563 scale -= diff_scale;
564 }
565 if scale > Self::MAX_PRECISION as u32 {
566 let diff_scale = scale - Self::MAX_PRECISION as u32;
567 num /= 10i128.pow(diff_scale);
568 scale = Self::MAX_PRECISION as u32;
569 }
570 Some(Decimal::Normalized(
571 RustDecimal::try_from_i128_with_scale(num, scale).ok()?,
572 ))
573 }
574
575 #[must_use]
576 pub fn normalize(&self) -> Self {
577 match self {
578 Self::Normalized(d) => Self::Normalized(d.normalize()),
579 d => *d,
580 }
581 }
582
583 pub fn unordered_serialize(&self) -> [u8; 16] {
584 match self {
587 Self::Normalized(d) => d.serialize(),
588 Self::NaN => [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
589 Self::PositiveInf => [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
590 Self::NegativeInf => [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
591 }
592 }
593
594 pub fn unordered_deserialize(bytes: [u8; 16]) -> Self {
595 match bytes[0] {
596 0u8 => Self::Normalized(RustDecimal::deserialize(bytes)),
597 1u8 => Self::NaN,
598 2u8 => Self::PositiveInf,
599 3u8 => Self::NegativeInf,
600 _ => unreachable!(),
601 }
602 }
603
604 pub fn abs(&self) -> Self {
605 match self {
606 Self::Normalized(d) => {
607 if d.is_sign_negative() {
608 Self::Normalized(-d)
609 } else {
610 Self::Normalized(*d)
611 }
612 }
613 Self::NaN => Self::NaN,
614 Self::PositiveInf => Self::PositiveInf,
615 Self::NegativeInf => Self::PositiveInf,
616 }
617 }
618
619 pub fn sign(&self) -> Self {
620 match self {
621 Self::NaN => Self::NaN,
622 _ => match self.cmp(&0.into()) {
623 std::cmp::Ordering::Less => (-1).into(),
624 std::cmp::Ordering::Equal => 0.into(),
625 std::cmp::Ordering::Greater => 1.into(),
626 },
627 }
628 }
629
630 pub fn checked_exp(&self) -> Option<Decimal> {
631 match self {
632 Self::Normalized(d) => d.checked_exp().map(Self::Normalized),
633 Self::NaN => Some(Self::NaN),
634 Self::PositiveInf => Some(Self::PositiveInf),
635 Self::NegativeInf => Some(Self::zero()),
636 }
637 }
638
639 pub fn checked_ln(&self) -> Option<Decimal> {
640 match self {
641 Self::Normalized(d) => d.checked_ln().map(Self::Normalized),
642 Self::NaN => Some(Self::NaN),
643 Self::PositiveInf => Some(Self::PositiveInf),
644 Self::NegativeInf => None,
645 }
646 }
647
648 pub fn checked_log10(&self) -> Option<Decimal> {
649 match self {
650 Self::Normalized(d) => d.checked_log10().map(Self::Normalized),
651 Self::NaN => Some(Self::NaN),
652 Self::PositiveInf => Some(Self::PositiveInf),
653 Self::NegativeInf => None,
654 }
655 }
656
657 pub fn checked_powd(&self, rhs: &Self) -> Result<Self, PowError> {
658 use std::cmp::Ordering;
659
660 match (self, rhs) {
661 (Decimal::NaN, Decimal::NaN)
663 | (Decimal::PositiveInf, Decimal::NaN)
664 | (Decimal::NegativeInf, Decimal::NaN)
665 | (Decimal::NaN, Decimal::PositiveInf)
666 | (Decimal::NaN, Decimal::NegativeInf) => Ok(Self::NaN),
667 (Normalized(lhs), Decimal::NaN) => match lhs.is_one() {
668 true => Ok(1.into()),
669 false => Ok(Self::NaN),
670 },
671 (Decimal::NaN, Normalized(rhs)) => match rhs.is_zero() {
672 true => Ok(1.into()),
673 false => Ok(Self::NaN),
674 },
675
676 (Normalized(lhs), Decimal::PositiveInf) => match lhs.abs().cmp(&1.into()) {
678 Ordering::Greater => Ok(Self::PositiveInf),
679 Ordering::Equal => Ok(1.into()),
680 Ordering::Less => Ok(0.into()),
681 },
682 (Decimal::PositiveInf, Decimal::PositiveInf)
685 | (Decimal::NegativeInf, Decimal::PositiveInf) => Ok(Self::PositiveInf),
686
687 (Normalized(lhs), Decimal::NegativeInf) => match lhs.abs().cmp(&1.into()) {
689 Ordering::Greater => Ok(0.into()),
690 Ordering::Equal => Ok(1.into()),
691 Ordering::Less => match lhs.is_zero() {
692 true => Err(PowError::ZeroNegative),
694 false => Ok(Self::PositiveInf),
695 },
696 },
697 (Decimal::PositiveInf, Decimal::NegativeInf)
698 | (Decimal::NegativeInf, Decimal::NegativeInf) => Ok(0.into()),
699
700 (Decimal::PositiveInf, Normalized(rhs)) => match rhs.cmp(&0.into()) {
702 Ordering::Greater => Ok(Self::PositiveInf),
703 Ordering::Equal => Ok(1.into()),
704 Ordering::Less => Ok(0.into()),
705 },
706
707 (Decimal::NegativeInf, Normalized(rhs)) => match !rhs.fract().is_zero() {
709 true => Err(PowError::NegativeFract),
711 false => match (rhs.cmp(&0.into()), rhs.rem(&2.into()).abs().is_one()) {
712 (Ordering::Greater, true) => Ok(Self::NegativeInf),
713 (Ordering::Greater, false) => Ok(Self::PositiveInf),
714 (Ordering::Equal, true) => unreachable!(),
715 (Ordering::Equal, false) => Ok(1.into()),
716 (Ordering::Less, true) => Ok(0.into()), (Ordering::Less, false) => Ok(0.into()),
718 },
719 },
720
721 (Normalized(lhs), Normalized(rhs)) => {
723 if lhs.is_zero() && rhs < &0.into() {
724 return Err(PowError::ZeroNegative);
725 }
726 if lhs < &0.into() && !rhs.fract().is_zero() {
727 return Err(PowError::NegativeFract);
728 }
729 match lhs.checked_powd(*rhs) {
730 Some(d) => Ok(Self::Normalized(d)),
731 None => Err(PowError::Overflow),
732 }
733 }
734 }
735 }
736}
737
/// Errors returned by [`Decimal::checked_powd`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PowError {
    /// Zero raised to a negative power (including `-Infinity`).
    ZeroNegative,
    /// A negative base (including `-Infinity`) raised to a non-integer power.
    NegativeFract,
    /// The finite computation failed (`rust_decimal`'s `checked_powd` returned `None`).
    Overflow,
}
743
744impl From<Decimal> for memcomparable::Decimal {
745 fn from(d: Decimal) -> Self {
746 match d {
747 Decimal::Normalized(d) => Self::Normalized(d),
748 Decimal::PositiveInf => Self::Inf,
749 Decimal::NegativeInf => Self::NegInf,
750 Decimal::NaN => Self::NaN,
751 }
752 }
753}
754
755impl From<memcomparable::Decimal> for Decimal {
756 fn from(d: memcomparable::Decimal) -> Self {
757 match d {
758 memcomparable::Decimal::Normalized(d) => Self::Normalized(d),
759 memcomparable::Decimal::Inf => Self::PositiveInf,
760 memcomparable::Decimal::NegInf => Self::NegativeInf,
761 memcomparable::Decimal::NaN => Self::NaN,
762 }
763 }
764}
765
766impl Default for Decimal {
767 fn default() -> Self {
768 Self::Normalized(RustDecimal::default())
769 }
770}
771
772impl FromStr for Decimal {
773 type Err = Error;
774
775 fn from_str(s: &str) -> Result<Self, Self::Err> {
776 match s.to_ascii_lowercase().as_str() {
777 "nan" => Ok(Decimal::NaN),
778 "inf" | "+inf" | "infinity" | "+infinity" => Ok(Decimal::PositiveInf),
779 "-inf" | "-infinity" => Ok(Decimal::NegativeInf),
780 s => RustDecimal::from_str(s)
781 .or_else(|_| RustDecimal::from_scientific(s))
782 .map(Decimal::Normalized),
783 }
784 }
785}
786
787impl Zero for Decimal {
788 fn zero() -> Self {
789 Self::Normalized(RustDecimal::zero())
790 }
791
792 fn is_zero(&self) -> bool {
793 if let Self::Normalized(d) = self {
794 d.is_zero()
795 } else {
796 false
797 }
798 }
799}
800
801impl One for Decimal {
802 fn one() -> Self {
803 Self::Normalized(RustDecimal::one())
804 }
805}
806
807impl Num for Decimal {
808 type FromStrRadixErr = Error;
809
810 fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
811 if str.eq_ignore_ascii_case("inf") || str.eq_ignore_ascii_case("infinity") {
812 Ok(Self::PositiveInf)
813 } else if str.eq_ignore_ascii_case("-inf") || str.eq_ignore_ascii_case("-infinity") {
814 Ok(Self::NegativeInf)
815 } else if str.eq_ignore_ascii_case("nan") {
816 Ok(Self::NaN)
817 } else {
818 RustDecimal::from_str_radix(str, radix).map(Decimal::Normalized)
819 }
820 }
821}
822
823impl From<RustDecimal> for Decimal {
824 fn from(d: RustDecimal) -> Self {
825 Self::Normalized(d)
826 }
827}
828
#[cfg(test)]
mod tests {
    use itertools::Itertools as _;
    use risingwave_common_estimate_size::EstimateSize;

    use super::*;
    use crate::util::iter_util::ZipEqFast;

    /// Float agreement where any two NaNs count as equal; everything else uses `==`
    /// (so infinities must share a sign).
    fn check(lhs: f32, rhs: f32) -> bool {
        (lhs.is_nan() && rhs.is_nan()) || lhs == rhs
    }

    /// Every arithmetic operator on `Decimal` must agree with the same operator on `f32`,
    /// special values included.
    #[test]
    fn check_op_with_float() {
        let decimals = [
            Decimal::NaN,
            Decimal::PositiveInf,
            Decimal::NegativeInf,
            Decimal::try_from(1.0).unwrap(),
            Decimal::try_from(-1.0).unwrap(),
            Decimal::try_from(0.0).unwrap(),
        ];
        let floats = [
            f32::NAN,
            f32::INFINITY,
            f32::NEG_INFINITY,
            1.0f32,
            -1.0f32,
            0.0f32,
        ];
        let pairs = decimals.iter().zip_eq_fast(floats.iter()).collect_vec();
        for &(&d_lhs, &f_lhs) in &pairs {
            for &(&d_rhs, &f_rhs) in &pairs {
                assert!(check((d_lhs + d_rhs).try_into().unwrap(), f_lhs + f_rhs));
                assert!(check((d_lhs - d_rhs).try_into().unwrap(), f_lhs - f_rhs));
                assert!(check((d_lhs * d_rhs).try_into().unwrap(), f_lhs * f_rhs));
                assert!(check((d_lhs / d_rhs).try_into().unwrap(), f_lhs / f_rhs));
                assert!(check((d_lhs % d_rhs).try_into().unwrap(), f_lhs % f_rhs));
            }
        }
    }

    #[test]
    fn basic_test() {
        // Special-value spellings are case-insensitive; `+` is optional for infinity.
        let nan_spellings = ["nan", "NaN", "NAN", "nAn", "nAN", "Nan", "NAn"];
        let pos_inf_spellings = [
            "inf", "INF", "iNF", "inF", "InF", "INf", "+inf", "+INF", "+Inf", "+iNF", "+inF",
            "+InF", "+INf", "inFINity", "+infiNIty",
        ];
        let neg_inf_spellings = [
            "-inf", "-INF", "-Inf", "-iNF", "-inF", "-InF", "-INf", "-INfinity",
        ];
        for (spellings, expected) in [
            (&nan_spellings[..], Decimal::NaN),
            (&pos_inf_spellings[..], Decimal::PositiveInf),
            (&neg_inf_spellings[..], Decimal::NegativeInf),
        ] {
            for s in spellings {
                assert_eq!(Decimal::from_str(s).unwrap(), expected, "parsing {s:?}");
            }
        }

        assert_eq!(
            Decimal::try_from(10.0).unwrap() / Decimal::PositiveInf,
            Decimal::try_from(0.0).unwrap(),
        );
        assert_eq!(
            Decimal::try_from(f32::INFINITY).unwrap(),
            Decimal::PositiveInf
        );
        assert_eq!(Decimal::try_from(f64::NAN).unwrap(), Decimal::NaN);
        assert_eq!(
            Decimal::try_from(f64::INFINITY).unwrap(),
            Decimal::PositiveInf
        );

        // The 16-byte unordered encoding must round-trip.
        let round_trip = |d: Decimal| Decimal::unordered_deserialize(d.unordered_serialize());
        for d in [
            Decimal::try_from(1.234).unwrap(),
            Decimal::from(1u8),
            Decimal::from(1i8),
            Decimal::from(1u16),
            Decimal::from(1i16),
            Decimal::from(1u32),
            Decimal::from(1i32),
            Decimal::try_from(f64::NAN).unwrap(),
            Decimal::try_from(f64::INFINITY).unwrap(),
        ] {
            assert_eq!(round_trip(d), d);
        }

        // Integer conversions back out of `Decimal`.
        assert_eq!(u8::try_from(Decimal::from(1u8)).unwrap(), 1);
        assert_eq!(i8::try_from(Decimal::from(1i8)).unwrap(), 1);
        assert_eq!(u16::try_from(Decimal::from(1u16)).unwrap(), 1);
        assert_eq!(i16::try_from(Decimal::from(1i16)).unwrap(), 1);
        assert_eq!(u32::try_from(Decimal::from(1u32)).unwrap(), 1);
        assert_eq!(i32::try_from(Decimal::from(1i32)).unwrap(), 1);
        assert_eq!(u64::try_from(Decimal::from(1u64)).unwrap(), 1);
        assert_eq!(i64::try_from(Decimal::from(1i64)).unwrap(), 1);
    }

    /// `Decimal` ordering and its memcomparable encoding must both be strictly increasing
    /// over this sorted list.
    #[test]
    fn test_order() {
        let ordered = ["-inf", "-1", "0.00", "0.5", "2", "10", "inf", "nan"]
            .into_iter()
            .map(|s| Decimal::from_str(s).unwrap())
            .collect_vec();
        for window in ordered.windows(2) {
            let (lower, higher) = (window[0], window[1]);
            assert!(lower < higher);
            assert!(memcomparable::Decimal::from(lower) < memcomparable::Decimal::from(higher));
        }
    }

    #[test]
    fn test_decimal_estimate_size() {
        for decimal in [
            Decimal::NegativeInf,
            Decimal::Normalized(RustDecimal::try_from(1.0).unwrap()),
        ] {
            assert_eq!(decimal.estimated_size(), 20);
        }
    }
}