1use std::fmt::Debug;
16use std::io::{Cursor, Read, Write};
17use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
18
19use byteorder::{BigEndian, ReadBytesExt};
20use bytes::{BufMut, BytesMut};
21use num_traits::{
22 CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Num, One, Zero,
23};
24use postgres_types::{FromSql, IsNull, ToSql, Type, accepts, to_sql_checked};
25use risingwave_common_estimate_size::ZeroHeapSize;
26use rust_decimal::prelude::FromStr;
27use rust_decimal::{Decimal as RustDecimal, Error, MathematicalOps as _, RoundingStrategy};
28
29use super::DataType;
30use super::to_text::ToText;
31use crate::array::ArrayResult;
32use crate::types::Decimal::Normalized;
33use crate::types::ordered_float::OrderedFloat;
34
/// A decimal number that is either finite or one of the special values
/// `-Infinity`, `Infinity` and `NaN`.
///
/// The declaration order of the variants is significant: the derived
/// `Ord`/`PartialOrd` compare variants in declaration order, which yields
/// `NegativeInf < Normalized(_) < PositiveInf < NaN`.
#[derive(Debug, Copy, parse_display::Display, Clone, PartialEq, Hash, Eq, Ord, PartialOrd)]
pub enum Decimal {
    /// Negative infinity; displayed as `-Infinity`.
    #[display("-Infinity")]
    NegativeInf,
    /// A finite value; displayed using the inner value's own formatting.
    #[display("{0}")]
    Normalized(RustDecimal),
    /// Positive infinity; displayed as `Infinity`.
    #[display("Infinity")]
    PositiveInf,
    /// Not-a-number; displayed as `NaN` and ordered after every other value.
    #[display("NaN")]
    NaN,
}
46
// Marker impl with no methods. NOTE(review): presumably this tells the size
// estimator that `Decimal` owns no heap memory — confirm against
// `risingwave_common_estimate_size`.
impl ZeroHeapSize for Decimal {}
48
49impl ToText for Decimal {
50 fn write<W: std::fmt::Write>(&self, f: &mut W) -> std::fmt::Result {
51 write!(f, "{self}")
52 }
53
54 fn write_with_type<W: std::fmt::Write>(&self, ty: &DataType, f: &mut W) -> std::fmt::Result {
55 match ty {
56 DataType::Decimal => self.write(f),
57 _ => unreachable!(),
58 }
59 }
60}
61
62impl Decimal {
63 pub fn to_protobuf(self, output: &mut impl Write) -> ArrayResult<usize> {
65 let buf = self.unordered_serialize();
66 output.write_all(&buf)?;
67 Ok(buf.len())
68 }
69
70 pub fn from_protobuf(input: &mut impl Read) -> ArrayResult<Self> {
72 let mut buf = [0u8; 16];
73 input.read_exact(&mut buf)?;
74 Ok(Self::unordered_deserialize(buf))
75 }
76
77 pub fn from_scientific(value: &str) -> Option<Self> {
78 let decimal = RustDecimal::from_scientific(value).ok()?;
79 Some(Normalized(decimal))
80 }
81
82 pub fn from_str_radix(s: &str, radix: u32) -> rust_decimal::Result<Self> {
83 match s.to_ascii_lowercase().as_str() {
84 "nan" => Ok(Decimal::NaN),
85 "inf" | "+inf" | "infinity" | "+infinity" => Ok(Decimal::PositiveInf),
86 "-inf" | "-infinity" => Ok(Decimal::NegativeInf),
87 s => RustDecimal::from_str_radix(s, radix).map(Decimal::Normalized),
88 }
89 }
90}
91
92impl ToSql for Decimal {
93 accepts!(NUMERIC);
94
95 to_sql_checked!();
96
97 fn to_sql(
98 &self,
99 ty: &Type,
100 out: &mut BytesMut,
101 ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
102 where
103 Self: Sized,
104 {
105 match self {
106 Decimal::Normalized(d) => {
107 return d.to_sql(ty, out);
108 }
109 Decimal::NaN => {
110 out.reserve(8);
111 out.put_u16(0);
112 out.put_i16(0);
113 out.put_u16(0xC000);
114 out.put_i16(0);
115 }
116 Decimal::PositiveInf => {
117 out.reserve(8);
118 out.put_u16(0);
119 out.put_i16(0);
120 out.put_u16(0xD000);
121 out.put_i16(0);
122 }
123 Decimal::NegativeInf => {
124 out.reserve(8);
125 out.put_u16(0);
126 out.put_i16(0);
127 out.put_u16(0xF000);
128 out.put_i16(0);
129 }
130 }
131 Ok(IsNull::No)
132 }
133}
134
135impl<'a> FromSql<'a> for Decimal {
136 fn from_sql(
137 ty: &Type,
138 raw: &'a [u8],
139 ) -> Result<Self, Box<dyn std::error::Error + 'static + Sync + Send>> {
140 let mut rdr = Cursor::new(raw);
141 let _n_digits = rdr.read_u16::<BigEndian>()?;
142 let _weight = rdr.read_i16::<BigEndian>()?;
143 let sign = rdr.read_u16::<BigEndian>()?;
144 match sign {
145 0xC000 => Ok(Self::NaN),
146 0xD000 => Ok(Self::PositiveInf),
147 0xF000 => Ok(Self::NegativeInf),
148 _ => RustDecimal::from_sql(ty, raw).map(Self::Normalized),
149 }
150 }
151
152 fn accepts(ty: &Type) -> bool {
153 matches!(*ty, Type::NUMERIC)
154 }
155}
156
/// Generates lossless `From<$T> for Decimal` and fallible
/// `TryFrom<Decimal> for $T` for an integer type `$T`.
///
/// Converting back to an integer first rounds to zero decimal places with
/// ties away from zero; NaN and the infinities are reported as a
/// `ConversionTo` error naming the target type.
macro_rules! impl_convert_int {
    ($T:ty) => {
        impl core::convert::From<$T> for Decimal {
            #[inline]
            fn from(t: $T) -> Self {
                Self::Normalized(RustDecimal::from(t))
            }
        }

        impl core::convert::TryFrom<Decimal> for $T {
            type Error = Error;

            #[inline]
            fn try_from(d: Decimal) -> Result<Self, Self::Error> {
                if let Decimal::Normalized(rounded) = d.round_dp_ties_away(0) {
                    rounded.try_into()
                } else {
                    Err(Error::ConversionTo(std::any::type_name::<$T>().into()))
                }
            }
        }
    };
}
179
/// Generates the conversions between [`Decimal`] and a float type `$T`
/// (and its `OrderedFloat<$T>` wrapper), in both directions.
///
/// NaN and the infinities map onto the matching special variants in either
/// direction; finite values go through the inner decimal's own conversions,
/// which is where a conversion error can originate.
macro_rules! impl_convert_float {
    ($T:ty) => {
        impl core::convert::TryFrom<$T> for Decimal {
            type Error = Error;

            fn try_from(num: $T) -> Result<Self, Self::Error> {
                match num {
                    num if num.is_nan() => Ok(Decimal::NaN),
                    num if num.is_infinite() && num.is_sign_positive() => Ok(Decimal::PositiveInf),
                    num if num.is_infinite() && num.is_sign_negative() => Ok(Decimal::NegativeInf),
                    // Finite: the `.map(Decimal::Normalized)` fixes the target of
                    // `try_into` to the inner decimal type.
                    num => num.try_into().map(Decimal::Normalized),
                }
            }
        }
        impl core::convert::TryFrom<OrderedFloat<$T>> for Decimal {
            type Error = Error;

            fn try_from(value: OrderedFloat<$T>) -> Result<Self, Self::Error> {
                // Unwrap the ordered wrapper and reuse the plain-float conversion.
                value.0.try_into()
            }
        }

        impl core::convert::TryFrom<Decimal> for $T {
            type Error = Error;

            fn try_from(d: Decimal) -> Result<Self, Self::Error> {
                match d {
                    Decimal::Normalized(d) => d.try_into(),
                    Decimal::NaN => Ok(<$T>::NAN),
                    Decimal::PositiveInf => Ok(<$T>::INFINITY),
                    Decimal::NegativeInf => Ok(<$T>::NEG_INFINITY),
                }
            }
        }
        impl core::convert::TryFrom<Decimal> for OrderedFloat<$T> {
            type Error = Error;

            fn try_from(d: Decimal) -> Result<Self, Self::Error> {
                // Convert to the plain float, then wrap it.
                d.try_into().map(Self)
            }
        }
    };
}
223
/// Implements a `num_traits` checked-arithmetic trait for [`Decimal`].
///
/// When both operands are finite the inner decimal's checked method is used
/// (so `None` is returned whenever that method returns `None`). Otherwise the
/// result comes from the plain operator `$op` on `Decimal`, wrapped in `Some`.
macro_rules! checked_proxy {
    ($trait:ty, $func:ident, $op: tt) => {
        impl $trait for Decimal {
            fn $func(&self, other: &Self) -> Option<Self> {
                if let (Self::Normalized(lhs), Self::Normalized(rhs)) = (self, other) {
                    return lhs.$func(rhs).map(Decimal::Normalized);
                }
                // At least one operand is NaN or infinite.
                Some(*self $op *other)
            }
        }
    }
}
238
// Conversions between `Decimal` and the floating-point types.
impl_convert_float!(f32);
impl_convert_float!(f64);

// Conversions between `Decimal` and the integer types listed below.
impl_convert_int!(isize);
impl_convert_int!(i8);
impl_convert_int!(i16);
impl_convert_int!(i32);
impl_convert_int!(i64);
impl_convert_int!(usize);
impl_convert_int!(u8);
impl_convert_int!(u16);
impl_convert_int!(u32);
impl_convert_int!(u64);

// Checked arithmetic traits, each paired with the plain operator used when an
// operand is not finite.
checked_proxy!(CheckedRem, checked_rem, %);
checked_proxy!(CheckedSub, checked_sub, -);
checked_proxy!(CheckedAdd, checked_add, +);
checked_proxy!(CheckedDiv, checked_div, /);
checked_proxy!(CheckedMul, checked_mul, *);
258
259impl Add for Decimal {
260 type Output = Self;
261
262 fn add(self, other: Self) -> Self {
263 match (self, other) {
264 (Self::Normalized(lhs), Self::Normalized(rhs)) => Self::Normalized(lhs + rhs),
265 (Self::NaN, _) => Self::NaN,
266 (_, Self::NaN) => Self::NaN,
267 (Self::PositiveInf, Self::NegativeInf) => Self::NaN,
268 (Self::NegativeInf, Self::PositiveInf) => Self::NaN,
269 (Self::PositiveInf, _) => Self::PositiveInf,
270 (_, Self::PositiveInf) => Self::PositiveInf,
271 (Self::NegativeInf, _) => Self::NegativeInf,
272 (_, Self::NegativeInf) => Self::NegativeInf,
273 }
274 }
275}
276
277impl Neg for Decimal {
278 type Output = Self;
279
280 fn neg(self) -> Self {
281 match self {
282 Self::Normalized(d) => Self::Normalized(-d),
283 Self::NaN => Self::NaN,
284 Self::PositiveInf => Self::NegativeInf,
285 Self::NegativeInf => Self::PositiveInf,
286 }
287 }
288}
289
290impl CheckedNeg for Decimal {
291 fn checked_neg(&self) -> Option<Self> {
292 match self {
293 Self::Normalized(d) => Some(Self::Normalized(-d)),
294 Self::NaN => Some(Self::NaN),
295 Self::PositiveInf => Some(Self::NegativeInf),
296 Self::NegativeInf => Some(Self::PositiveInf),
297 }
298 }
299}
300
301impl Rem for Decimal {
302 type Output = Self;
303
304 fn rem(self, other: Self) -> Self {
305 match (self, other) {
306 (Self::Normalized(lhs), Self::Normalized(rhs)) if !rhs.is_zero() => {
307 Self::Normalized(lhs % rhs)
308 }
309 (Self::Normalized(_), Self::Normalized(_)) => Self::NaN,
310 (Self::Normalized(lhs), Self::PositiveInf)
311 if lhs.is_sign_positive() || lhs.is_zero() =>
312 {
313 Self::Normalized(lhs)
314 }
315 (Self::Normalized(d), Self::PositiveInf) => Self::Normalized(d),
316 (Self::Normalized(lhs), Self::NegativeInf)
317 if lhs.is_sign_negative() || lhs.is_zero() =>
318 {
319 Self::Normalized(lhs)
320 }
321 (Self::Normalized(d), Self::NegativeInf) => Self::Normalized(d),
322 _ => Self::NaN,
323 }
324 }
325}
326
/// Division with float-like handling of zero divisors and special values.
impl Div for Decimal {
    type Output = Self;

    /// Computes `self / other`. The arms are order-dependent: earlier arms
    /// remove cases that later wildcard arms would otherwise catch.
    fn div(self, other: Self) -> Self {
        match (self, other) {
            // NaN on either side is contagious.
            (Self::NaN, _) => Self::NaN,
            (_, Self::NaN) => Self::NaN,
            // Division by a finite zero.
            (lhs, Self::Normalized(rhs)) if rhs.is_zero() => match lhs {
                Self::Normalized(lhs) => {
                    // Non-zero / 0 is an infinity carrying the sign of the
                    // dividend; 0 / 0 is NaN.
                    if lhs.is_sign_positive() && !lhs.is_zero() {
                        Self::PositiveInf
                    } else if lhs.is_sign_negative() && !lhs.is_zero() {
                        Self::NegativeInf
                    } else {
                        Self::NaN
                    }
                }
                Self::PositiveInf => Self::PositiveInf,
                Self::NegativeInf => Self::NegativeInf,
                // A NaN dividend was already handled by the first arm.
                _ => unreachable!(),
            },
            // Finite / ±Infinity is zero; infinity / infinity is NaN.
            (Self::Normalized(_), Self::PositiveInf) => Self::Normalized(RustDecimal::from(0)),
            (_, Self::PositiveInf) => Self::NaN,
            (Self::Normalized(_), Self::NegativeInf) => Self::Normalized(RustDecimal::from(0)),
            (_, Self::NegativeInf) => Self::NaN,
            // Infinity / non-zero finite: infinity with the product of the signs.
            (Self::PositiveInf, Self::Normalized(d)) if d.is_sign_positive() => Self::PositiveInf,
            (Self::PositiveInf, Self::Normalized(d)) if d.is_sign_negative() => Self::NegativeInf,
            (Self::NegativeInf, Self::Normalized(d)) if d.is_sign_positive() => Self::NegativeInf,
            (Self::NegativeInf, Self::Normalized(d)) if d.is_sign_negative() => Self::PositiveInf,
            // Ordinary finite division (the divisor is known to be non-zero here).
            (Self::Normalized(lhs), Self::Normalized(rhs)) => Self::Normalized(lhs / rhs),
            // Only reachable if a divisor were neither sign-positive nor
            // sign-negative, which the guards above are assumed to rule out.
            _ => unreachable!(),
        }
    }
}
366
367impl Mul for Decimal {
368 type Output = Self;
369
370 fn mul(self, other: Self) -> Self {
371 match (self, other) {
372 (Self::Normalized(lhs), Self::Normalized(rhs)) => Self::Normalized(lhs * rhs),
373 (Self::NaN, _) => Self::NaN,
374 (_, Self::NaN) => Self::NaN,
375 (Self::PositiveInf, Self::Normalized(rhs))
376 if !rhs.is_zero() && rhs.is_sign_negative() =>
377 {
378 Self::NegativeInf
379 }
380 (Self::PositiveInf, Self::Normalized(rhs))
381 if !rhs.is_zero() && rhs.is_sign_positive() =>
382 {
383 Self::PositiveInf
384 }
385 (Self::PositiveInf, Self::PositiveInf) => Self::PositiveInf,
386 (Self::PositiveInf, Self::NegativeInf) => Self::NegativeInf,
387 (Self::Normalized(lhs), Self::PositiveInf)
388 if !lhs.is_zero() && lhs.is_sign_negative() =>
389 {
390 Self::NegativeInf
391 }
392 (Self::Normalized(lhs), Self::PositiveInf)
393 if !lhs.is_zero() && lhs.is_sign_positive() =>
394 {
395 Self::PositiveInf
396 }
397 (Self::NegativeInf, Self::PositiveInf) => Self::NegativeInf,
398 (Self::NegativeInf, Self::Normalized(rhs))
399 if !rhs.is_zero() && rhs.is_sign_negative() =>
400 {
401 Self::PositiveInf
402 }
403 (Self::NegativeInf, Self::Normalized(rhs))
404 if !rhs.is_zero() && rhs.is_sign_positive() =>
405 {
406 Self::NegativeInf
407 }
408 (Self::NegativeInf, Self::NegativeInf) => Self::PositiveInf,
409 (Self::Normalized(lhs), Self::NegativeInf)
410 if !lhs.is_zero() && lhs.is_sign_negative() =>
411 {
412 Self::PositiveInf
413 }
414 (Self::Normalized(lhs), Self::NegativeInf)
415 if !lhs.is_zero() && lhs.is_sign_positive() =>
416 {
417 Self::NegativeInf
418 }
419 _ => Self::NaN,
421 }
422 }
423}
424
425impl Sub for Decimal {
426 type Output = Self;
427
428 fn sub(self, other: Self) -> Self {
429 match (self, other) {
430 (Self::Normalized(lhs), Self::Normalized(rhs)) => Self::Normalized(lhs - rhs),
431 (Self::NaN, _) => Self::NaN,
432 (_, Self::NaN) => Self::NaN,
433 (Self::PositiveInf, Self::PositiveInf) => Self::NaN,
434 (Self::NegativeInf, Self::NegativeInf) => Self::NaN,
435 (Self::PositiveInf, _) => Self::PositiveInf,
436 (_, Self::PositiveInf) => Self::NegativeInf,
437 (Self::NegativeInf, _) => Self::NegativeInf,
438 (_, Self::NegativeInf) => Self::PositiveInf,
439 }
440 }
441}
442
443impl Decimal {
    /// Upper bound used by this type for precision/scale computations (see
    /// `round_left_ties_away`).
    /// NOTE(review): presumably the maximum number of decimal digits the
    /// underlying decimal can hold — confirm against `rust_decimal`.
    pub const MAX_PRECISION: u8 = 28;
445
446 pub fn scale(&self) -> Option<i32> {
447 let Decimal::Normalized(d) = self else {
448 return None;
449 };
450 Some(d.scale() as _)
451 }
452
453 pub fn rescale(&mut self, scale: u32) {
454 if let Normalized(a) = self {
455 a.rescale(scale);
456 }
457 }
458
459 #[must_use]
460 pub fn round_dp_ties_away(&self, dp: u32) -> Self {
461 match self {
462 Self::Normalized(d) => {
463 let new_d = d.round_dp_with_strategy(dp, RoundingStrategy::MidpointAwayFromZero);
464 Self::Normalized(new_d)
465 }
466 d => *d,
467 }
468 }
469
    /// Rounds a finite value at `left` digits to the left of the decimal
    /// point, with midpoints rounded away from zero.
    ///
    /// The value is first scaled down by `10^left` (by increasing its scale),
    /// rounded to an integer, and then multiplied back by `10^left`.
    ///
    /// Returns `None` when the final multiplication overflows, or when `left`
    /// exceeds [`Decimal::MAX_PRECISION`] and the rounded value is non-zero.
    /// NaN and the infinities are returned unchanged in `Some`.
    #[must_use]
    pub fn round_left_ties_away(&self, left: u32) -> Option<Self> {
        let &Self::Normalized(mut d) = self else {
            return Some(*self);
        };

        let old_scale = d.scale();
        // Increasing the scale by `left` divides the value by `10^left`.
        let new_scale = old_scale.saturating_add(left);
        // Half of `10^(MAX_PRECISION + 1)`: the rounding threshold for the
        // mantissa when the new scale is exactly `MAX_PRECISION + 1`.
        const MANTISSA_UP: i128 = 5 * 10i128.pow(Decimal::MAX_PRECISION as _);
        let d = match new_scale.cmp(&Self::MAX_PRECISION.add(1).into()) {
            std::cmp::Ordering::Less => {
                // The new scale is within `MAX_PRECISION`: apply it and round
                // to an integer.
                d.set_scale(new_scale).unwrap();
                d.round_dp_with_strategy(0, RoundingStrategy::MidpointAwayFromZero)
            }
            // One digit beyond the limit: truncating division yields 0 below
            // the threshold; `signum` reduces anything else to +1 or -1.
            std::cmp::Ordering::Equal => (d.mantissa() / MANTISSA_UP).signum().into(),
            // Further beyond the limit the result is taken to be zero.
            std::cmp::Ordering::Greater => 0.into(),
        };

        match left > Decimal::MAX_PRECISION.into() {
            // `10^left` is not built in this case: only a zero result is
            // returned, anything else is reported as `None`.
            true => d.is_zero().then(|| 0.into()),
            false => d
                .checked_mul(RustDecimal::from_i128_with_scale(10i128.pow(left), 0))
                .map(Self::Normalized),
        }
    }
503
504 #[must_use]
505 pub fn ceil(&self) -> Self {
506 match self {
507 Self::Normalized(d) => {
508 let mut d = d.ceil();
509 if d.is_zero() {
510 d.set_sign_positive(true);
511 }
512 Self::Normalized(d)
513 }
514 d => *d,
515 }
516 }
517
518 #[must_use]
519 pub fn floor(&self) -> Self {
520 match self {
521 Self::Normalized(d) => Self::Normalized(d.floor()),
522 d => *d,
523 }
524 }
525
526 #[must_use]
527 pub fn trunc(&self) -> Self {
528 match self {
529 Self::Normalized(d) => {
530 let mut d = d.trunc();
531 if d.is_zero() {
532 d.set_sign_positive(true);
533 }
534 Self::Normalized(d)
535 }
536 d => *d,
537 }
538 }
539
540 #[must_use]
541 pub fn round_ties_even(&self) -> Self {
542 match self {
543 Self::Normalized(d) => Self::Normalized(d.round()),
544 d => *d,
545 }
546 }
547
548 pub fn from_i128_with_scale(num: i128, scale: u32) -> Self {
549 Decimal::Normalized(RustDecimal::from_i128_with_scale(num, scale))
550 }
551
552 #[must_use]
553 pub fn normalize(&self) -> Self {
554 match self {
555 Self::Normalized(d) => Self::Normalized(d.normalize()),
556 d => *d,
557 }
558 }
559
560 pub fn unordered_serialize(&self) -> [u8; 16] {
561 match self {
564 Self::Normalized(d) => d.serialize(),
565 Self::NaN => [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
566 Self::PositiveInf => [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
567 Self::NegativeInf => [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
568 }
569 }
570
571 pub fn unordered_deserialize(bytes: [u8; 16]) -> Self {
572 match bytes[0] {
573 0u8 => Self::Normalized(RustDecimal::deserialize(bytes)),
574 1u8 => Self::NaN,
575 2u8 => Self::PositiveInf,
576 3u8 => Self::NegativeInf,
577 _ => unreachable!(),
578 }
579 }
580
581 pub fn abs(&self) -> Self {
582 match self {
583 Self::Normalized(d) => {
584 if d.is_sign_negative() {
585 Self::Normalized(-d)
586 } else {
587 Self::Normalized(*d)
588 }
589 }
590 Self::NaN => Self::NaN,
591 Self::PositiveInf => Self::PositiveInf,
592 Self::NegativeInf => Self::PositiveInf,
593 }
594 }
595
596 pub fn sign(&self) -> Self {
597 match self {
598 Self::NaN => Self::NaN,
599 _ => match self.cmp(&0.into()) {
600 std::cmp::Ordering::Less => (-1).into(),
601 std::cmp::Ordering::Equal => 0.into(),
602 std::cmp::Ordering::Greater => 1.into(),
603 },
604 }
605 }
606
607 pub fn checked_exp(&self) -> Option<Decimal> {
608 match self {
609 Self::Normalized(d) => d.checked_exp().map(Self::Normalized),
610 Self::NaN => Some(Self::NaN),
611 Self::PositiveInf => Some(Self::PositiveInf),
612 Self::NegativeInf => Some(Self::zero()),
613 }
614 }
615
616 pub fn checked_ln(&self) -> Option<Decimal> {
617 match self {
618 Self::Normalized(d) => d.checked_ln().map(Self::Normalized),
619 Self::NaN => Some(Self::NaN),
620 Self::PositiveInf => Some(Self::PositiveInf),
621 Self::NegativeInf => None,
622 }
623 }
624
625 pub fn checked_log10(&self) -> Option<Decimal> {
626 match self {
627 Self::Normalized(d) => d.checked_log10().map(Self::Normalized),
628 Self::NaN => Some(Self::NaN),
629 Self::PositiveInf => Some(Self::PositiveInf),
630 Self::NegativeInf => None,
631 }
632 }
633
    /// Raises `self` to the power `rhs`, covering every combination of
    /// finite, infinite and NaN operands.
    ///
    /// # Errors
    ///
    /// * [`PowError::ZeroNegative`] — zero raised to a negative power
    ///   (including `-Infinity`);
    /// * [`PowError::NegativeFract`] — a negative base (including
    ///   `-Infinity`) raised to a non-integer power;
    /// * [`PowError::Overflow`] — the finite computation returned no result.
    pub fn checked_powd(&self, rhs: &Self) -> Result<Self, PowError> {
        use std::cmp::Ordering;

        match (self, rhs) {
            // NaN combined with NaN or with an infinity is NaN.
            (Decimal::NaN, Decimal::NaN)
            | (Decimal::PositiveInf, Decimal::NaN)
            | (Decimal::NegativeInf, Decimal::NaN)
            | (Decimal::NaN, Decimal::PositiveInf)
            | (Decimal::NaN, Decimal::NegativeInf) => Ok(Self::NaN),
            // 1 ^ NaN is 1; any other finite base gives NaN.
            (Normalized(lhs), Decimal::NaN) => match lhs.is_one() {
                true => Ok(1.into()),
                false => Ok(Self::NaN),
            },
            // NaN ^ 0 is 1; any other finite exponent gives NaN.
            (Decimal::NaN, Normalized(rhs)) => match rhs.is_zero() {
                true => Ok(1.into()),
                false => Ok(Self::NaN),
            },

            // finite ^ +Infinity depends on how |base| compares with 1.
            (Normalized(lhs), Decimal::PositiveInf) => match lhs.abs().cmp(&1.into()) {
                Ordering::Greater => Ok(Self::PositiveInf),
                Ordering::Equal => Ok(1.into()),
                Ordering::Less => Ok(0.into()),
            },
            // ±Infinity ^ +Infinity.
            (Decimal::PositiveInf, Decimal::PositiveInf)
            | (Decimal::NegativeInf, Decimal::PositiveInf) => Ok(Self::PositiveInf),

            // finite ^ -Infinity: the mirror image of the +Infinity case.
            (Normalized(lhs), Decimal::NegativeInf) => match lhs.abs().cmp(&1.into()) {
                Ordering::Greater => Ok(0.into()),
                Ordering::Equal => Ok(1.into()),
                Ordering::Less => match lhs.is_zero() {
                    // 0 ^ -Infinity is rejected like any zero-to-negative power.
                    true => Err(PowError::ZeroNegative),
                    false => Ok(Self::PositiveInf),
                },
            },
            // ±Infinity ^ -Infinity.
            (Decimal::PositiveInf, Decimal::NegativeInf)
            | (Decimal::NegativeInf, Decimal::NegativeInf) => Ok(0.into()),

            // +Infinity ^ finite depends only on the sign of the exponent.
            (Decimal::PositiveInf, Normalized(rhs)) => match rhs.cmp(&0.into()) {
                Ordering::Greater => Ok(Self::PositiveInf),
                Ordering::Equal => Ok(1.into()),
                Ordering::Less => Ok(0.into()),
            },

            // -Infinity ^ finite: only integer exponents are allowed, and the
            // parity of the exponent decides the sign of an infinite result.
            (Decimal::NegativeInf, Normalized(rhs)) => match !rhs.fract().is_zero() {
                true => Err(PowError::NegativeFract),
                // The tuple is (sign of exponent, exponent is odd).
                false => match (rhs.cmp(&0.into()), rhs.rem(&2.into()).abs().is_one()) {
                    (Ordering::Greater, true) => Ok(Self::NegativeInf),
                    (Ordering::Greater, false) => Ok(Self::PositiveInf),
                    // Zero is even, so it can never be classified as odd.
                    (Ordering::Equal, true) => unreachable!(),
                    (Ordering::Equal, false) => Ok(1.into()),
                    // Negative exponents give zero regardless of parity.
                    (Ordering::Less, true) => Ok(0.into()),
                    (Ordering::Less, false) => Ok(0.into()),
                },
            },

            // Both operands finite.
            (Normalized(lhs), Normalized(rhs)) => {
                if lhs.is_zero() && rhs < &0.into() {
                    return Err(PowError::ZeroNegative);
                }
                if lhs < &0.into() && !rhs.fract().is_zero() {
                    return Err(PowError::NegativeFract);
                }
                match lhs.checked_powd(*rhs) {
                    Some(d) => Ok(Self::Normalized(d)),
                    None => Err(PowError::Overflow),
                }
            }
        }
    }
713}
714
/// Reasons why [`Decimal::checked_powd`] can fail.
///
/// Derives `Debug` (so a `Result<_, PowError>` can be unwrapped or logged)
/// plus the cheap value-type traits `Clone`, `Copy`, `PartialEq` and `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PowError {
    /// Zero was raised to a negative power.
    ZeroNegative,
    /// A negative base was raised to a non-integer power.
    NegativeFract,
    /// The computation on finite operands produced no representable result.
    Overflow,
}
720
721impl From<Decimal> for memcomparable::Decimal {
722 fn from(d: Decimal) -> Self {
723 match d {
724 Decimal::Normalized(d) => Self::Normalized(d),
725 Decimal::PositiveInf => Self::Inf,
726 Decimal::NegativeInf => Self::NegInf,
727 Decimal::NaN => Self::NaN,
728 }
729 }
730}
731
732impl From<memcomparable::Decimal> for Decimal {
733 fn from(d: memcomparable::Decimal) -> Self {
734 match d {
735 memcomparable::Decimal::Normalized(d) => Self::Normalized(d),
736 memcomparable::Decimal::Inf => Self::PositiveInf,
737 memcomparable::Decimal::NegInf => Self::NegativeInf,
738 memcomparable::Decimal::NaN => Self::NaN,
739 }
740 }
741}
742
743impl Default for Decimal {
744 fn default() -> Self {
745 Self::Normalized(RustDecimal::default())
746 }
747}
748
749impl FromStr for Decimal {
750 type Err = Error;
751
752 fn from_str(s: &str) -> Result<Self, Self::Err> {
753 match s.to_ascii_lowercase().as_str() {
754 "nan" => Ok(Decimal::NaN),
755 "inf" | "+inf" | "infinity" | "+infinity" => Ok(Decimal::PositiveInf),
756 "-inf" | "-infinity" => Ok(Decimal::NegativeInf),
757 s => RustDecimal::from_str(s).map(Decimal::Normalized),
758 }
759 }
760}
761
762impl Zero for Decimal {
763 fn zero() -> Self {
764 Self::Normalized(RustDecimal::zero())
765 }
766
767 fn is_zero(&self) -> bool {
768 if let Self::Normalized(d) = self {
769 d.is_zero()
770 } else {
771 false
772 }
773 }
774}
775
776impl One for Decimal {
777 fn one() -> Self {
778 Self::Normalized(RustDecimal::one())
779 }
780}
781
782impl Num for Decimal {
783 type FromStrRadixErr = Error;
784
785 fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
786 if str.eq_ignore_ascii_case("inf") || str.eq_ignore_ascii_case("infinity") {
787 Ok(Self::PositiveInf)
788 } else if str.eq_ignore_ascii_case("-inf") || str.eq_ignore_ascii_case("-infinity") {
789 Ok(Self::NegativeInf)
790 } else if str.eq_ignore_ascii_case("nan") {
791 Ok(Self::NaN)
792 } else {
793 RustDecimal::from_str_radix(str, radix).map(Decimal::Normalized)
794 }
795 }
796}
797
798impl From<RustDecimal> for Decimal {
799 fn from(d: RustDecimal) -> Self {
800 Self::Normalized(d)
801 }
802}
803
804#[cfg(test)]
805mod tests {
806 use itertools::Itertools as _;
807 use risingwave_common_estimate_size::EstimateSize;
808
809 use super::*;
810 use crate::util::iter_util::ZipEqFast;
811
812 fn check(lhs: f32, rhs: f32) -> bool {
813 if lhs.is_nan() && rhs.is_nan() {
814 true
815 } else if lhs.is_infinite() && rhs.is_infinite() {
816 if lhs.is_sign_positive() && rhs.is_sign_positive() {
817 true
818 } else {
819 lhs.is_sign_negative() && rhs.is_sign_negative()
820 }
821 } else if lhs.is_finite() && rhs.is_finite() {
822 lhs == rhs
823 } else {
824 false
825 }
826 }
827
828 #[test]
829 fn check_op_with_float() {
830 let decimals = [
831 Decimal::NaN,
832 Decimal::PositiveInf,
833 Decimal::NegativeInf,
834 Decimal::try_from(1.0).unwrap(),
835 Decimal::try_from(-1.0).unwrap(),
836 Decimal::try_from(0.0).unwrap(),
837 ];
838 let floats = [
839 f32::NAN,
840 f32::INFINITY,
841 f32::NEG_INFINITY,
842 1.0f32,
843 -1.0f32,
844 0.0f32,
845 ];
846 for (d_lhs, f_lhs) in decimals.iter().zip_eq_fast(floats.iter()) {
847 for (d_rhs, f_rhs) in decimals.iter().zip_eq_fast(floats.iter()) {
848 assert!(check((*d_lhs + *d_rhs).try_into().unwrap(), f_lhs + f_rhs));
849 assert!(check((*d_lhs - *d_rhs).try_into().unwrap(), f_lhs - f_rhs));
850 assert!(check((*d_lhs * *d_rhs).try_into().unwrap(), f_lhs * f_rhs));
851 assert!(check((*d_lhs / *d_rhs).try_into().unwrap(), f_lhs / f_rhs));
852 assert!(check((*d_lhs % *d_rhs).try_into().unwrap(), f_lhs % f_rhs));
853 }
854 }
855 }
856
857 #[test]
858 fn basic_test() {
859 assert_eq!(Decimal::from_str("nan").unwrap(), Decimal::NaN,);
860 assert_eq!(Decimal::from_str("NaN").unwrap(), Decimal::NaN,);
861 assert_eq!(Decimal::from_str("NAN").unwrap(), Decimal::NaN,);
862 assert_eq!(Decimal::from_str("nAn").unwrap(), Decimal::NaN,);
863 assert_eq!(Decimal::from_str("nAN").unwrap(), Decimal::NaN,);
864 assert_eq!(Decimal::from_str("Nan").unwrap(), Decimal::NaN,);
865 assert_eq!(Decimal::from_str("NAn").unwrap(), Decimal::NaN,);
866
867 assert_eq!(Decimal::from_str("inf").unwrap(), Decimal::PositiveInf,);
868 assert_eq!(Decimal::from_str("INF").unwrap(), Decimal::PositiveInf,);
869 assert_eq!(Decimal::from_str("iNF").unwrap(), Decimal::PositiveInf,);
870 assert_eq!(Decimal::from_str("inF").unwrap(), Decimal::PositiveInf,);
871 assert_eq!(Decimal::from_str("InF").unwrap(), Decimal::PositiveInf,);
872 assert_eq!(Decimal::from_str("INf").unwrap(), Decimal::PositiveInf,);
873 assert_eq!(Decimal::from_str("+inf").unwrap(), Decimal::PositiveInf,);
874 assert_eq!(Decimal::from_str("+INF").unwrap(), Decimal::PositiveInf,);
875 assert_eq!(Decimal::from_str("+Inf").unwrap(), Decimal::PositiveInf,);
876 assert_eq!(Decimal::from_str("+iNF").unwrap(), Decimal::PositiveInf,);
877 assert_eq!(Decimal::from_str("+inF").unwrap(), Decimal::PositiveInf,);
878 assert_eq!(Decimal::from_str("+InF").unwrap(), Decimal::PositiveInf,);
879 assert_eq!(Decimal::from_str("+INf").unwrap(), Decimal::PositiveInf,);
880 assert_eq!(Decimal::from_str("inFINity").unwrap(), Decimal::PositiveInf,);
881 assert_eq!(
882 Decimal::from_str("+infiNIty").unwrap(),
883 Decimal::PositiveInf,
884 );
885
886 assert_eq!(Decimal::from_str("-inf").unwrap(), Decimal::NegativeInf,);
887 assert_eq!(Decimal::from_str("-INF").unwrap(), Decimal::NegativeInf,);
888 assert_eq!(Decimal::from_str("-Inf").unwrap(), Decimal::NegativeInf,);
889 assert_eq!(Decimal::from_str("-iNF").unwrap(), Decimal::NegativeInf,);
890 assert_eq!(Decimal::from_str("-inF").unwrap(), Decimal::NegativeInf,);
891 assert_eq!(Decimal::from_str("-InF").unwrap(), Decimal::NegativeInf,);
892 assert_eq!(Decimal::from_str("-INf").unwrap(), Decimal::NegativeInf,);
893 assert_eq!(
894 Decimal::from_str("-INfinity").unwrap(),
895 Decimal::NegativeInf,
896 );
897
898 assert_eq!(
899 Decimal::try_from(10.0).unwrap() / Decimal::PositiveInf,
900 Decimal::try_from(0.0).unwrap(),
901 );
902 assert_eq!(
903 Decimal::try_from(f32::INFINITY).unwrap(),
904 Decimal::PositiveInf
905 );
906 assert_eq!(Decimal::try_from(f64::NAN).unwrap(), Decimal::NaN);
907 assert_eq!(
908 Decimal::try_from(f64::INFINITY).unwrap(),
909 Decimal::PositiveInf
910 );
911 assert_eq!(
912 Decimal::unordered_deserialize(Decimal::try_from(1.234).unwrap().unordered_serialize()),
913 Decimal::try_from(1.234).unwrap(),
914 );
915 assert_eq!(
916 Decimal::unordered_deserialize(Decimal::from(1u8).unordered_serialize()),
917 Decimal::from(1u8),
918 );
919 assert_eq!(
920 Decimal::unordered_deserialize(Decimal::from(1i8).unordered_serialize()),
921 Decimal::from(1i8),
922 );
923 assert_eq!(
924 Decimal::unordered_deserialize(Decimal::from(1u16).unordered_serialize()),
925 Decimal::from(1u16),
926 );
927 assert_eq!(
928 Decimal::unordered_deserialize(Decimal::from(1i16).unordered_serialize()),
929 Decimal::from(1i16),
930 );
931 assert_eq!(
932 Decimal::unordered_deserialize(Decimal::from(1u32).unordered_serialize()),
933 Decimal::from(1u32),
934 );
935 assert_eq!(
936 Decimal::unordered_deserialize(Decimal::from(1i32).unordered_serialize()),
937 Decimal::from(1i32),
938 );
939 assert_eq!(
940 Decimal::unordered_deserialize(
941 Decimal::try_from(f64::NAN).unwrap().unordered_serialize()
942 ),
943 Decimal::try_from(f64::NAN).unwrap(),
944 );
945 assert_eq!(
946 Decimal::unordered_deserialize(
947 Decimal::try_from(f64::INFINITY)
948 .unwrap()
949 .unordered_serialize()
950 ),
951 Decimal::try_from(f64::INFINITY).unwrap(),
952 );
953 assert_eq!(u8::try_from(Decimal::from(1u8)).unwrap(), 1,);
954 assert_eq!(i8::try_from(Decimal::from(1i8)).unwrap(), 1,);
955 assert_eq!(u16::try_from(Decimal::from(1u16)).unwrap(), 1,);
956 assert_eq!(i16::try_from(Decimal::from(1i16)).unwrap(), 1,);
957 assert_eq!(u32::try_from(Decimal::from(1u32)).unwrap(), 1,);
958 assert_eq!(i32::try_from(Decimal::from(1i32)).unwrap(), 1,);
959 assert_eq!(u64::try_from(Decimal::from(1u64)).unwrap(), 1,);
960 assert_eq!(i64::try_from(Decimal::from(1i64)).unwrap(), 1,);
961 }
962
963 #[test]
964 fn test_order() {
965 let ordered = ["-inf", "-1", "0.00", "0.5", "2", "10", "inf", "nan"]
966 .iter()
967 .map(|s| Decimal::from_str(s).unwrap())
968 .collect_vec();
969 for i in 1..ordered.len() {
970 assert!(ordered[i - 1] < ordered[i]);
971 assert!(
972 memcomparable::Decimal::from(ordered[i - 1])
973 < memcomparable::Decimal::from(ordered[i])
974 );
975 }
976 }
977
978 #[test]
979 fn test_decimal_estimate_size() {
980 let decimal = Decimal::NegativeInf;
981 assert_eq!(decimal.estimated_size(), 20);
982
983 let decimal = Decimal::Normalized(RustDecimal::try_from(1.0).unwrap());
984 assert_eq!(decimal.estimated_size(), 20);
985 }
986}