risingwave_common/types/
decimal.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::io::{Cursor, Read, Write};
17use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
18
19use byteorder::{BigEndian, ReadBytesExt};
20use bytes::{BufMut, BytesMut};
21use num_traits::{
22    CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Num, One, Zero,
23};
24use postgres_types::{FromSql, IsNull, ToSql, Type, accepts, to_sql_checked};
25use risingwave_common_estimate_size::ZeroHeapSize;
26use rust_decimal::prelude::FromStr;
27use rust_decimal::{Decimal as RustDecimal, Error, MathematicalOps as _, RoundingStrategy};
28
29use super::DataType;
30use super::to_text::ToText;
31use crate::array::ArrayResult;
32use crate::types::Decimal::Normalized;
33use crate::types::ordered_float::OrderedFloat;
34
35#[derive(Debug, Copy, parse_display::Display, Clone, PartialEq, Hash, Eq, Ord, PartialOrd)]
36pub enum Decimal {
37    #[display("-Infinity")]
38    NegativeInf,
39    #[display("{0}")]
40    Normalized(RustDecimal),
41    #[display("Infinity")]
42    PositiveInf,
43    #[display("NaN")]
44    NaN,
45}
46
// Marker impl with no methods. `Decimal` derives `Copy`, so it owns no heap allocation;
// presumably that is what `ZeroHeapSize` conveys to the size estimator (confirm in
// `risingwave_common_estimate_size`).
impl ZeroHeapSize for Decimal {}
48
49impl ToText for Decimal {
50    fn write<W: std::fmt::Write>(&self, f: &mut W) -> std::fmt::Result {
51        write!(f, "{self}")
52    }
53
54    fn write_with_type<W: std::fmt::Write>(&self, ty: &DataType, f: &mut W) -> std::fmt::Result {
55        match ty {
56            DataType::Decimal => self.write(f),
57            _ => unreachable!(),
58        }
59    }
60}
61
62impl Decimal {
63    /// Used by `PrimitiveArray` to serialize the array to protobuf.
64    pub fn to_protobuf(self, output: &mut impl Write) -> ArrayResult<usize> {
65        let buf = self.unordered_serialize();
66        output.write_all(&buf)?;
67        Ok(buf.len())
68    }
69
70    /// Used by `DecimalValueReader` to deserialize the array from protobuf.
71    pub fn from_protobuf(input: &mut impl Read) -> ArrayResult<Self> {
72        let mut buf = [0u8; 16];
73        input.read_exact(&mut buf)?;
74        Ok(Self::unordered_deserialize(buf))
75    }
76
77    pub fn from_scientific(value: &str) -> Option<Self> {
78        let decimal = RustDecimal::from_scientific(value).ok()?;
79        Some(Normalized(decimal))
80    }
81
82    pub fn from_str_radix(s: &str, radix: u32) -> rust_decimal::Result<Self> {
83        match s.to_ascii_lowercase().as_str() {
84            "nan" => Ok(Decimal::NaN),
85            "inf" | "+inf" | "infinity" | "+infinity" => Ok(Decimal::PositiveInf),
86            "-inf" | "-infinity" => Ok(Decimal::NegativeInf),
87            s => RustDecimal::from_str_radix(s, radix).map(Decimal::Normalized),
88        }
89    }
90}
91
92impl ToSql for Decimal {
93    accepts!(NUMERIC);
94
95    to_sql_checked!();
96
97    fn to_sql(
98        &self,
99        ty: &Type,
100        out: &mut BytesMut,
101    ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
102    where
103        Self: Sized,
104    {
105        match self {
106            Decimal::Normalized(d) => {
107                return d.to_sql(ty, out);
108            }
109            Decimal::NaN => {
110                out.reserve(8);
111                out.put_u16(0);
112                out.put_i16(0);
113                out.put_u16(0xC000);
114                out.put_i16(0);
115            }
116            Decimal::PositiveInf => {
117                out.reserve(8);
118                out.put_u16(0);
119                out.put_i16(0);
120                out.put_u16(0xD000);
121                out.put_i16(0);
122            }
123            Decimal::NegativeInf => {
124                out.reserve(8);
125                out.put_u16(0);
126                out.put_i16(0);
127                out.put_u16(0xF000);
128                out.put_i16(0);
129            }
130        }
131        Ok(IsNull::No)
132    }
133}
134
135impl<'a> FromSql<'a> for Decimal {
136    fn from_sql(
137        ty: &Type,
138        raw: &'a [u8],
139    ) -> Result<Self, Box<dyn std::error::Error + 'static + Sync + Send>> {
140        let mut rdr = Cursor::new(raw);
141        let _n_digits = rdr.read_u16::<BigEndian>()?;
142        let _weight = rdr.read_i16::<BigEndian>()?;
143        let sign = rdr.read_u16::<BigEndian>()?;
144        match sign {
145            0xC000 => Ok(Self::NaN),
146            0xD000 => Ok(Self::PositiveInf),
147            0xF000 => Ok(Self::NegativeInf),
148            _ => RustDecimal::from_sql(ty, raw).map(Self::Normalized),
149        }
150    }
151
152    fn accepts(ty: &Type) -> bool {
153        matches!(*ty, Type::NUMERIC)
154    }
155}
156
/// Generates the conversions between [`Decimal`] and the integer type `$T`:
///
/// * `From<$T> for Decimal`, which always yields a `Normalized` value;
/// * `TryFrom<Decimal> for $T`, which first rounds to zero decimal places (ties away from
///   zero) and fails for NaN / infinities or when the rounded value does not convert to `$T`.
macro_rules! impl_convert_int {
    ($T:ty) => {
        impl core::convert::From<$T> for Decimal {
            #[inline]
            fn from(t: $T) -> Self {
                Self::Normalized(RustDecimal::from(t))
            }
        }

        impl core::convert::TryFrom<Decimal> for $T {
            type Error = Error;

            #[inline]
            fn try_from(d: Decimal) -> Result<Self, Self::Error> {
                // Special values stay untouched by rounding and have no integer counterpart.
                let Decimal::Normalized(rounded) = d.round_dp_ties_away(0) else {
                    return Err(Error::ConversionTo(std::any::type_name::<$T>().into()));
                };
                rounded.try_into()
            }
        }
    };
}
179
/// Generates the fallible conversions between [`Decimal`] and the float type `$T`, as well as
/// the same conversions for `OrderedFloat<$T>`.
///
/// NaN and the two infinities map onto the matching special variants in both directions;
/// finite values go through the wrapped decimal's own `TryFrom` impls.
macro_rules! impl_convert_float {
    ($T:ty) => {
        impl core::convert::TryFrom<$T> for Decimal {
            type Error = Error;

            fn try_from(num: $T) -> Result<Self, Self::Error> {
                if num.is_nan() {
                    Ok(Decimal::NaN)
                } else if num.is_infinite() {
                    Ok(if num.is_sign_positive() {
                        Decimal::PositiveInf
                    } else {
                        Decimal::NegativeInf
                    })
                } else {
                    RustDecimal::try_from(num).map(Decimal::Normalized)
                }
            }
        }
        impl core::convert::TryFrom<OrderedFloat<$T>> for Decimal {
            type Error = Error;

            fn try_from(value: OrderedFloat<$T>) -> Result<Self, Self::Error> {
                let OrderedFloat(raw) = value;
                raw.try_into()
            }
        }

        impl core::convert::TryFrom<Decimal> for $T {
            type Error = Error;

            fn try_from(d: Decimal) -> Result<Self, Self::Error> {
                let special = match d {
                    Decimal::Normalized(finite) => return finite.try_into(),
                    Decimal::NaN => <$T>::NAN,
                    Decimal::PositiveInf => <$T>::INFINITY,
                    Decimal::NegativeInf => <$T>::NEG_INFINITY,
                };
                Ok(special)
            }
        }
        impl core::convert::TryFrom<Decimal> for OrderedFloat<$T> {
            type Error = Error;

            fn try_from(d: Decimal) -> Result<Self, Self::Error> {
                let raw: $T = d.try_into()?;
                Ok(Self(raw))
            }
        }
    };
}
223
/// Implements a `num_traits` checked binary-operator trait for [`Decimal`].
///
/// When both operands are finite the call is forwarded to the wrapped decimal's checked
/// method, so its `None` is propagated. As soon as a special value is involved, the plain
/// operator `$op` is applied and its result is always wrapped in `Some`.
macro_rules! checked_proxy {
    ($trait:ty, $func:ident, $op: tt) => {
        impl $trait for Decimal {
            fn $func(&self, other: &Self) -> Option<Self> {
                if let (Self::Normalized(lhs), Self::Normalized(rhs)) = (self, other) {
                    return lhs.$func(rhs).map(Decimal::Normalized);
                }
                Some(*self $op *other)
            }
        }
    }
}
238
239impl_convert_float!(f32);
240impl_convert_float!(f64);
241
242impl_convert_int!(isize);
243impl_convert_int!(i8);
244impl_convert_int!(i16);
245impl_convert_int!(i32);
246impl_convert_int!(i64);
247impl_convert_int!(usize);
248impl_convert_int!(u8);
249impl_convert_int!(u16);
250impl_convert_int!(u32);
251impl_convert_int!(u64);
252
253checked_proxy!(CheckedRem, checked_rem, %);
254checked_proxy!(CheckedSub, checked_sub, -);
255checked_proxy!(CheckedAdd, checked_add, +);
256checked_proxy!(CheckedDiv, checked_div, /);
257checked_proxy!(CheckedMul, checked_mul, *);
258
259impl Add for Decimal {
260    type Output = Self;
261
262    fn add(self, other: Self) -> Self {
263        match (self, other) {
264            (Self::Normalized(lhs), Self::Normalized(rhs)) => Self::Normalized(lhs + rhs),
265            (Self::NaN, _) => Self::NaN,
266            (_, Self::NaN) => Self::NaN,
267            (Self::PositiveInf, Self::NegativeInf) => Self::NaN,
268            (Self::NegativeInf, Self::PositiveInf) => Self::NaN,
269            (Self::PositiveInf, _) => Self::PositiveInf,
270            (_, Self::PositiveInf) => Self::PositiveInf,
271            (Self::NegativeInf, _) => Self::NegativeInf,
272            (_, Self::NegativeInf) => Self::NegativeInf,
273        }
274    }
275}
276
277impl Neg for Decimal {
278    type Output = Self;
279
280    fn neg(self) -> Self {
281        match self {
282            Self::Normalized(d) => Self::Normalized(-d),
283            Self::NaN => Self::NaN,
284            Self::PositiveInf => Self::NegativeInf,
285            Self::NegativeInf => Self::PositiveInf,
286        }
287    }
288}
289
290impl CheckedNeg for Decimal {
291    fn checked_neg(&self) -> Option<Self> {
292        match self {
293            Self::Normalized(d) => Some(Self::Normalized(-d)),
294            Self::NaN => Some(Self::NaN),
295            Self::PositiveInf => Some(Self::NegativeInf),
296            Self::NegativeInf => Some(Self::PositiveInf),
297        }
298    }
299}
300
301impl Rem for Decimal {
302    type Output = Self;
303
304    fn rem(self, other: Self) -> Self {
305        match (self, other) {
306            (Self::Normalized(lhs), Self::Normalized(rhs)) if !rhs.is_zero() => {
307                Self::Normalized(lhs % rhs)
308            }
309            (Self::Normalized(_), Self::Normalized(_)) => Self::NaN,
310            (Self::Normalized(lhs), Self::PositiveInf)
311                if lhs.is_sign_positive() || lhs.is_zero() =>
312            {
313                Self::Normalized(lhs)
314            }
315            (Self::Normalized(d), Self::PositiveInf) => Self::Normalized(d),
316            (Self::Normalized(lhs), Self::NegativeInf)
317                if lhs.is_sign_negative() || lhs.is_zero() =>
318            {
319                Self::Normalized(lhs)
320            }
321            (Self::Normalized(d), Self::NegativeInf) => Self::Normalized(d),
322            _ => Self::NaN,
323        }
324    }
325}
326
327impl Div for Decimal {
328    type Output = Self;
329
330    fn div(self, other: Self) -> Self {
331        match (self, other) {
332            // nan
333            (Self::NaN, _) => Self::NaN,
334            (_, Self::NaN) => Self::NaN,
335            // div by zero
336            (lhs, Self::Normalized(rhs)) if rhs.is_zero() => match lhs {
337                Self::Normalized(lhs) => {
338                    if lhs.is_sign_positive() && !lhs.is_zero() {
339                        Self::PositiveInf
340                    } else if lhs.is_sign_negative() && !lhs.is_zero() {
341                        Self::NegativeInf
342                    } else {
343                        Self::NaN
344                    }
345                }
346                Self::PositiveInf => Self::PositiveInf,
347                Self::NegativeInf => Self::NegativeInf,
348                _ => unreachable!(),
349            },
350            // div by +/-inf
351            (Self::Normalized(_), Self::PositiveInf) => Self::Normalized(RustDecimal::from(0)),
352            (_, Self::PositiveInf) => Self::NaN,
353            (Self::Normalized(_), Self::NegativeInf) => Self::Normalized(RustDecimal::from(0)),
354            (_, Self::NegativeInf) => Self::NaN,
355            // div inf
356            (Self::PositiveInf, Self::Normalized(d)) if d.is_sign_positive() => Self::PositiveInf,
357            (Self::PositiveInf, Self::Normalized(d)) if d.is_sign_negative() => Self::NegativeInf,
358            (Self::NegativeInf, Self::Normalized(d)) if d.is_sign_positive() => Self::NegativeInf,
359            (Self::NegativeInf, Self::Normalized(d)) if d.is_sign_negative() => Self::PositiveInf,
360            // normal case
361            (Self::Normalized(lhs), Self::Normalized(rhs)) => Self::Normalized(lhs / rhs),
362            _ => unreachable!(),
363        }
364    }
365}
366
367impl Mul for Decimal {
368    type Output = Self;
369
370    fn mul(self, other: Self) -> Self {
371        match (self, other) {
372            (Self::Normalized(lhs), Self::Normalized(rhs)) => Self::Normalized(lhs * rhs),
373            (Self::NaN, _) => Self::NaN,
374            (_, Self::NaN) => Self::NaN,
375            (Self::PositiveInf, Self::Normalized(rhs))
376                if !rhs.is_zero() && rhs.is_sign_negative() =>
377            {
378                Self::NegativeInf
379            }
380            (Self::PositiveInf, Self::Normalized(rhs))
381                if !rhs.is_zero() && rhs.is_sign_positive() =>
382            {
383                Self::PositiveInf
384            }
385            (Self::PositiveInf, Self::PositiveInf) => Self::PositiveInf,
386            (Self::PositiveInf, Self::NegativeInf) => Self::NegativeInf,
387            (Self::Normalized(lhs), Self::PositiveInf)
388                if !lhs.is_zero() && lhs.is_sign_negative() =>
389            {
390                Self::NegativeInf
391            }
392            (Self::Normalized(lhs), Self::PositiveInf)
393                if !lhs.is_zero() && lhs.is_sign_positive() =>
394            {
395                Self::PositiveInf
396            }
397            (Self::NegativeInf, Self::PositiveInf) => Self::NegativeInf,
398            (Self::NegativeInf, Self::Normalized(rhs))
399                if !rhs.is_zero() && rhs.is_sign_negative() =>
400            {
401                Self::PositiveInf
402            }
403            (Self::NegativeInf, Self::Normalized(rhs))
404                if !rhs.is_zero() && rhs.is_sign_positive() =>
405            {
406                Self::NegativeInf
407            }
408            (Self::NegativeInf, Self::NegativeInf) => Self::PositiveInf,
409            (Self::Normalized(lhs), Self::NegativeInf)
410                if !lhs.is_zero() && lhs.is_sign_negative() =>
411            {
412                Self::PositiveInf
413            }
414            (Self::Normalized(lhs), Self::NegativeInf)
415                if !lhs.is_zero() && lhs.is_sign_positive() =>
416            {
417                Self::NegativeInf
418            }
419            // 0 * {inf, nan} => nan
420            _ => Self::NaN,
421        }
422    }
423}
424
425impl Sub for Decimal {
426    type Output = Self;
427
428    fn sub(self, other: Self) -> Self {
429        match (self, other) {
430            (Self::Normalized(lhs), Self::Normalized(rhs)) => Self::Normalized(lhs - rhs),
431            (Self::NaN, _) => Self::NaN,
432            (_, Self::NaN) => Self::NaN,
433            (Self::PositiveInf, Self::PositiveInf) => Self::NaN,
434            (Self::NegativeInf, Self::NegativeInf) => Self::NaN,
435            (Self::PositiveInf, _) => Self::PositiveInf,
436            (_, Self::PositiveInf) => Self::NegativeInf,
437            (Self::NegativeInf, _) => Self::NegativeInf,
438            (_, Self::NegativeInf) => Self::PositiveInf,
439        }
440    }
441}
442
443impl Decimal {
444    pub const MAX_PRECISION: u8 = 28;
445
446    pub fn scale(&self) -> Option<i32> {
447        let Decimal::Normalized(d) = self else {
448            return None;
449        };
450        Some(d.scale() as _)
451    }
452
453    pub fn rescale(&mut self, scale: u32) {
454        if let Normalized(a) = self {
455            a.rescale(scale);
456        }
457    }
458
459    #[must_use]
460    pub fn round_dp_ties_away(&self, dp: u32) -> Self {
461        match self {
462            Self::Normalized(d) => {
463                let new_d = d.round_dp_with_strategy(dp, RoundingStrategy::MidpointAwayFromZero);
464                Self::Normalized(new_d)
465            }
466            d => *d,
467        }
468    }
469
470    /// Round to the left of the decimal point, for example `31.5` -> `30`.
471    #[must_use]
472    pub fn round_left_ties_away(&self, left: u32) -> Option<Self> {
473        let &Self::Normalized(mut d) = self else {
474            return Some(*self);
475        };
476
477        // First, move the decimal point to the left so that we can reuse `round`. This is more
478        // efficient than division.
479        let old_scale = d.scale();
480        let new_scale = old_scale.saturating_add(left);
481        const MANTISSA_UP: i128 = 5 * 10i128.pow(Decimal::MAX_PRECISION as _);
482        let d = match new_scale.cmp(&Self::MAX_PRECISION.add(1).into()) {
483            // trivial within 28 digits
484            std::cmp::Ordering::Less => {
485                d.set_scale(new_scale).unwrap();
486                d.round_dp_with_strategy(0, RoundingStrategy::MidpointAwayFromZero)
487            }
488            // Special case: scale cannot be 29, but it may or may not be >= 0.5e+29
489            std::cmp::Ordering::Equal => (d.mantissa() / MANTISSA_UP).signum().into(),
490            // always 0 for >= 30 digits
491            std::cmp::Ordering::Greater => 0.into(),
492        };
493
494        // Then multiply back. Note that we cannot move decimal point to the right in order to get
495        // more zeros.
496        match left > Decimal::MAX_PRECISION.into() {
497            true => d.is_zero().then(|| 0.into()),
498            false => d
499                .checked_mul(RustDecimal::from_i128_with_scale(10i128.pow(left), 0))
500                .map(Self::Normalized),
501        }
502    }
503
504    #[must_use]
505    pub fn ceil(&self) -> Self {
506        match self {
507            Self::Normalized(d) => {
508                let mut d = d.ceil();
509                if d.is_zero() {
510                    d.set_sign_positive(true);
511                }
512                Self::Normalized(d)
513            }
514            d => *d,
515        }
516    }
517
518    #[must_use]
519    pub fn floor(&self) -> Self {
520        match self {
521            Self::Normalized(d) => Self::Normalized(d.floor()),
522            d => *d,
523        }
524    }
525
526    #[must_use]
527    pub fn trunc(&self) -> Self {
528        match self {
529            Self::Normalized(d) => {
530                let mut d = d.trunc();
531                if d.is_zero() {
532                    d.set_sign_positive(true);
533                }
534                Self::Normalized(d)
535            }
536            d => *d,
537        }
538    }
539
540    #[must_use]
541    pub fn round_ties_even(&self) -> Self {
542        match self {
543            Self::Normalized(d) => Self::Normalized(d.round()),
544            d => *d,
545        }
546    }
547
548    pub fn from_i128_with_scale(num: i128, scale: u32) -> Self {
549        Decimal::Normalized(RustDecimal::from_i128_with_scale(num, scale))
550    }
551
552    #[must_use]
553    pub fn normalize(&self) -> Self {
554        match self {
555            Self::Normalized(d) => Self::Normalized(d.normalize()),
556            d => *d,
557        }
558    }
559
560    pub fn unordered_serialize(&self) -> [u8; 16] {
561        // according to https://docs.rs/rust_decimal/1.18.0/src/rust_decimal/decimal.rs.html#665-684
562        // the lower 15 bits is not used, so we can use first byte to distinguish nan and inf
563        match self {
564            Self::Normalized(d) => d.serialize(),
565            Self::NaN => [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
566            Self::PositiveInf => [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
567            Self::NegativeInf => [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
568        }
569    }
570
571    pub fn unordered_deserialize(bytes: [u8; 16]) -> Self {
572        match bytes[0] {
573            0u8 => Self::Normalized(RustDecimal::deserialize(bytes)),
574            1u8 => Self::NaN,
575            2u8 => Self::PositiveInf,
576            3u8 => Self::NegativeInf,
577            _ => unreachable!(),
578        }
579    }
580
581    pub fn abs(&self) -> Self {
582        match self {
583            Self::Normalized(d) => {
584                if d.is_sign_negative() {
585                    Self::Normalized(-d)
586                } else {
587                    Self::Normalized(*d)
588                }
589            }
590            Self::NaN => Self::NaN,
591            Self::PositiveInf => Self::PositiveInf,
592            Self::NegativeInf => Self::PositiveInf,
593        }
594    }
595
596    pub fn sign(&self) -> Self {
597        match self {
598            Self::NaN => Self::NaN,
599            _ => match self.cmp(&0.into()) {
600                std::cmp::Ordering::Less => (-1).into(),
601                std::cmp::Ordering::Equal => 0.into(),
602                std::cmp::Ordering::Greater => 1.into(),
603            },
604        }
605    }
606
607    pub fn checked_exp(&self) -> Option<Decimal> {
608        match self {
609            Self::Normalized(d) => d.checked_exp().map(Self::Normalized),
610            Self::NaN => Some(Self::NaN),
611            Self::PositiveInf => Some(Self::PositiveInf),
612            Self::NegativeInf => Some(Self::zero()),
613        }
614    }
615
616    pub fn checked_ln(&self) -> Option<Decimal> {
617        match self {
618            Self::Normalized(d) => d.checked_ln().map(Self::Normalized),
619            Self::NaN => Some(Self::NaN),
620            Self::PositiveInf => Some(Self::PositiveInf),
621            Self::NegativeInf => None,
622        }
623    }
624
625    pub fn checked_log10(&self) -> Option<Decimal> {
626        match self {
627            Self::Normalized(d) => d.checked_log10().map(Self::Normalized),
628            Self::NaN => Some(Self::NaN),
629            Self::PositiveInf => Some(Self::PositiveInf),
630            Self::NegativeInf => None,
631        }
632    }
633
634    pub fn checked_powd(&self, rhs: &Self) -> Result<Self, PowError> {
635        use std::cmp::Ordering;
636
637        match (self, rhs) {
638            // A. Handle `nan`, where `1 ^ nan == 1` and `nan ^ 0 == 1`
639            (Decimal::NaN, Decimal::NaN)
640            | (Decimal::PositiveInf, Decimal::NaN)
641            | (Decimal::NegativeInf, Decimal::NaN)
642            | (Decimal::NaN, Decimal::PositiveInf)
643            | (Decimal::NaN, Decimal::NegativeInf) => Ok(Self::NaN),
644            (Normalized(lhs), Decimal::NaN) => match lhs.is_one() {
645                true => Ok(1.into()),
646                false => Ok(Self::NaN),
647            },
648            (Decimal::NaN, Normalized(rhs)) => match rhs.is_zero() {
649                true => Ok(1.into()),
650                false => Ok(Self::NaN),
651            },
652
653            // B. Handle `b ^ inf`
654            (Normalized(lhs), Decimal::PositiveInf) => match lhs.abs().cmp(&1.into()) {
655                Ordering::Greater => Ok(Self::PositiveInf),
656                Ordering::Equal => Ok(1.into()),
657                Ordering::Less => Ok(0.into()),
658            },
659            // Simply special case of `abs(b) > 1`.
660            // Also consistent with `inf ^ p` and `-inf ^ p` below where p is not fractional or odd.
661            (Decimal::PositiveInf, Decimal::PositiveInf)
662            | (Decimal::NegativeInf, Decimal::PositiveInf) => Ok(Self::PositiveInf),
663
664            // C. Handle `b ^ -inf`, which is `(1/b) ^ inf`
665            (Normalized(lhs), Decimal::NegativeInf) => match lhs.abs().cmp(&1.into()) {
666                Ordering::Greater => Ok(0.into()),
667                Ordering::Equal => Ok(1.into()),
668                Ordering::Less => match lhs.is_zero() {
669                    // Fun fact: ISO 9899 is removing this error to follow IEEE 754 2008.
670                    true => Err(PowError::ZeroNegative),
671                    false => Ok(Self::PositiveInf),
672                },
673            },
674            (Decimal::PositiveInf, Decimal::NegativeInf)
675            | (Decimal::NegativeInf, Decimal::NegativeInf) => Ok(0.into()),
676
677            // D. Handle `inf ^ p`
678            (Decimal::PositiveInf, Normalized(rhs)) => match rhs.cmp(&0.into()) {
679                Ordering::Greater => Ok(Self::PositiveInf),
680                Ordering::Equal => Ok(1.into()),
681                Ordering::Less => Ok(0.into()),
682            },
683
684            // E. Handle `-inf ^ p`. Finite `p` can be fractional, odd, or even.
685            (Decimal::NegativeInf, Normalized(rhs)) => match !rhs.fract().is_zero() {
686                // Err in PostgreSQL. No err in ISO 9899 which treats fractional as non-odd below.
687                true => Err(PowError::NegativeFract),
688                false => match (rhs.cmp(&0.into()), rhs.rem(&2.into()).abs().is_one()) {
689                    (Ordering::Greater, true) => Ok(Self::NegativeInf),
690                    (Ordering::Greater, false) => Ok(Self::PositiveInf),
691                    (Ordering::Equal, true) => unreachable!(),
692                    (Ordering::Equal, false) => Ok(1.into()),
693                    (Ordering::Less, true) => Ok(0.into()), // no `-0` in PostgreSQL decimal
694                    (Ordering::Less, false) => Ok(0.into()),
695                },
696            },
697
698            // F. Finite numbers
699            (Normalized(lhs), Normalized(rhs)) => {
700                if lhs.is_zero() && rhs < &0.into() {
701                    return Err(PowError::ZeroNegative);
702                }
703                if lhs < &0.into() && !rhs.fract().is_zero() {
704                    return Err(PowError::NegativeFract);
705                }
706                match lhs.checked_powd(*rhs) {
707                    Some(d) => Ok(Self::Normalized(d)),
708                    None => Err(PowError::Overflow),
709                }
710            }
711        }
712    }
713}
714
/// Reasons why [`Decimal::checked_powd`] can fail.
///
/// The standard traits are derived so that callers can `unwrap`/`expect` a
/// `Result<_, PowError>`, log the error, and compare it in assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PowError {
    /// Zero was raised to a negative power.
    ZeroNegative,
    /// A negative base was raised to a fractional power.
    NegativeFract,
    /// The finite power computation did not produce a representable result.
    Overflow,
}
720
721impl From<Decimal> for memcomparable::Decimal {
722    fn from(d: Decimal) -> Self {
723        match d {
724            Decimal::Normalized(d) => Self::Normalized(d),
725            Decimal::PositiveInf => Self::Inf,
726            Decimal::NegativeInf => Self::NegInf,
727            Decimal::NaN => Self::NaN,
728        }
729    }
730}
731
732impl From<memcomparable::Decimal> for Decimal {
733    fn from(d: memcomparable::Decimal) -> Self {
734        match d {
735            memcomparable::Decimal::Normalized(d) => Self::Normalized(d),
736            memcomparable::Decimal::Inf => Self::PositiveInf,
737            memcomparable::Decimal::NegInf => Self::NegativeInf,
738            memcomparable::Decimal::NaN => Self::NaN,
739        }
740    }
741}
742
743impl Default for Decimal {
744    fn default() -> Self {
745        Self::Normalized(RustDecimal::default())
746    }
747}
748
749impl FromStr for Decimal {
750    type Err = Error;
751
752    fn from_str(s: &str) -> Result<Self, Self::Err> {
753        match s.to_ascii_lowercase().as_str() {
754            "nan" => Ok(Decimal::NaN),
755            "inf" | "+inf" | "infinity" | "+infinity" => Ok(Decimal::PositiveInf),
756            "-inf" | "-infinity" => Ok(Decimal::NegativeInf),
757            s => RustDecimal::from_str(s).map(Decimal::Normalized),
758        }
759    }
760}
761
762impl Zero for Decimal {
763    fn zero() -> Self {
764        Self::Normalized(RustDecimal::zero())
765    }
766
767    fn is_zero(&self) -> bool {
768        if let Self::Normalized(d) = self {
769            d.is_zero()
770        } else {
771            false
772        }
773    }
774}
775
776impl One for Decimal {
777    fn one() -> Self {
778        Self::Normalized(RustDecimal::one())
779    }
780}
781
782impl Num for Decimal {
783    type FromStrRadixErr = Error;
784
785    fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
786        if str.eq_ignore_ascii_case("inf") || str.eq_ignore_ascii_case("infinity") {
787            Ok(Self::PositiveInf)
788        } else if str.eq_ignore_ascii_case("-inf") || str.eq_ignore_ascii_case("-infinity") {
789            Ok(Self::NegativeInf)
790        } else if str.eq_ignore_ascii_case("nan") {
791            Ok(Self::NaN)
792        } else {
793            RustDecimal::from_str_radix(str, radix).map(Decimal::Normalized)
794        }
795    }
796}
797
798impl From<RustDecimal> for Decimal {
799    fn from(d: RustDecimal) -> Self {
800        Self::Normalized(d)
801    }
802}
803
804#[cfg(test)]
805mod tests {
806    use itertools::Itertools as _;
807    use risingwave_common_estimate_size::EstimateSize;
808
809    use super::*;
810    use crate::util::iter_util::ZipEqFast;
811
812    fn check(lhs: f32, rhs: f32) -> bool {
813        if lhs.is_nan() && rhs.is_nan() {
814            true
815        } else if lhs.is_infinite() && rhs.is_infinite() {
816            if lhs.is_sign_positive() && rhs.is_sign_positive() {
817                true
818            } else {
819                lhs.is_sign_negative() && rhs.is_sign_negative()
820            }
821        } else if lhs.is_finite() && rhs.is_finite() {
822            lhs == rhs
823        } else {
824            false
825        }
826    }
827
828    #[test]
829    fn check_op_with_float() {
830        let decimals = [
831            Decimal::NaN,
832            Decimal::PositiveInf,
833            Decimal::NegativeInf,
834            Decimal::try_from(1.0).unwrap(),
835            Decimal::try_from(-1.0).unwrap(),
836            Decimal::try_from(0.0).unwrap(),
837        ];
838        let floats = [
839            f32::NAN,
840            f32::INFINITY,
841            f32::NEG_INFINITY,
842            1.0f32,
843            -1.0f32,
844            0.0f32,
845        ];
846        for (d_lhs, f_lhs) in decimals.iter().zip_eq_fast(floats.iter()) {
847            for (d_rhs, f_rhs) in decimals.iter().zip_eq_fast(floats.iter()) {
848                assert!(check((*d_lhs + *d_rhs).try_into().unwrap(), f_lhs + f_rhs));
849                assert!(check((*d_lhs - *d_rhs).try_into().unwrap(), f_lhs - f_rhs));
850                assert!(check((*d_lhs * *d_rhs).try_into().unwrap(), f_lhs * f_rhs));
851                assert!(check((*d_lhs / *d_rhs).try_into().unwrap(), f_lhs / f_rhs));
852                assert!(check((*d_lhs % *d_rhs).try_into().unwrap(), f_lhs % f_rhs));
853            }
854        }
855    }
856
#[test]
fn basic_test() {
    // Parsing: every spelling below, regardless of letter case, must map to
    // the corresponding special value.
    for text in ["nan", "NaN", "NAN", "nAn", "nAN", "Nan", "NAn"] {
        assert_eq!(Decimal::from_str(text).unwrap(), Decimal::NaN);
    }

    let positive_inf_spellings = [
        "inf",
        "INF",
        "iNF",
        "inF",
        "InF",
        "INf",
        "+inf",
        "+INF",
        "+Inf",
        "+iNF",
        "+inF",
        "+InF",
        "+INf",
        "inFINity",
        "+infiNIty",
    ];
    for text in positive_inf_spellings {
        assert_eq!(Decimal::from_str(text).unwrap(), Decimal::PositiveInf);
    }

    let negative_inf_spellings = [
        "-inf",
        "-INF",
        "-Inf",
        "-iNF",
        "-inF",
        "-InF",
        "-INf",
        "-INfinity",
    ];
    for text in negative_inf_spellings {
        assert_eq!(Decimal::from_str(text).unwrap(), Decimal::NegativeInf);
    }

    // Arithmetic: a finite value divided by +Infinity is zero.
    let ten = Decimal::try_from(10.0).unwrap();
    let zero = Decimal::try_from(0.0).unwrap();
    assert_eq!(ten / Decimal::PositiveInf, zero);

    // Conversion of special float values.
    assert_eq!(
        Decimal::try_from(f32::INFINITY).unwrap(),
        Decimal::PositiveInf
    );
    assert_eq!(Decimal::try_from(f64::NAN).unwrap(), Decimal::NaN);
    assert_eq!(
        Decimal::try_from(f64::INFINITY).unwrap(),
        Decimal::PositiveInf
    );

    // Unordered serialization must round-trip each of these values unchanged.
    let round_trip_values = [
        Decimal::try_from(1.234).unwrap(),
        Decimal::from(1u8),
        Decimal::from(1i8),
        Decimal::from(1u16),
        Decimal::from(1i16),
        Decimal::from(1u32),
        Decimal::from(1i32),
        Decimal::try_from(f64::NAN).unwrap(),
        Decimal::try_from(f64::INFINITY).unwrap(),
    ];
    for value in round_trip_values {
        assert_eq!(
            Decimal::unordered_deserialize(value.unordered_serialize()),
            value
        );
    }

    // Integer -> Decimal -> integer must give back the starting value for
    // each listed primitive type.
    macro_rules! assert_one_survives {
        ($($int:ty),* $(,)?) => {
            $(
                let one: $int = 1;
                assert_eq!(<$int>::try_from(Decimal::from(one)).unwrap(), 1);
            )*
        };
    }
    assert_one_survives!(u8, i8, u16, i16, u32, i32, u64, i64);
}
962
#[test]
fn test_order() {
    // Values listed in the order they are expected to sort, smallest first.
    let ordered = ["-inf", "-1", "0.00", "0.5", "2", "10", "inf", "nan"]
        .iter()
        .map(|text| Decimal::from_str(text).unwrap())
        .collect::<Vec<_>>();

    // Each adjacent pair must be strictly increasing, both under `Decimal`'s
    // own ordering and after conversion to `memcomparable::Decimal`.
    for pair in ordered.windows(2) {
        let (smaller, larger) = (pair[0], pair[1]);
        assert!(smaller < larger);
        assert!(memcomparable::Decimal::from(smaller) < memcomparable::Decimal::from(larger));
    }
}
977
#[test]
fn test_decimal_estimate_size() {
    // A special value and a normalized value are both expected to report the
    // same estimated size of 20.
    let samples = [
        Decimal::NegativeInf,
        Decimal::Normalized(RustDecimal::try_from(1.0).unwrap()),
    ];
    for decimal in samples {
        assert_eq!(decimal.estimated_size(), 20);
    }
}
986}