risingwave_common/types/
decimal.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::io::{Cursor, Read, Write};
17use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
18
19use byteorder::{BigEndian, ReadBytesExt};
20use bytes::{BufMut, BytesMut};
21use num_traits::{
22    CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Num, One, Zero,
23};
24use postgres_types::{FromSql, IsNull, ToSql, Type, accepts, to_sql_checked};
25use risingwave_common_estimate_size::ZeroHeapSize;
26use rust_decimal::prelude::FromStr;
27use rust_decimal::{Decimal as RustDecimal, Error, MathematicalOps as _, RoundingStrategy};
28
29use super::DataType;
30use super::to_text::ToText;
31use crate::array::ArrayResult;
32use crate::types::Decimal::Normalized;
33use crate::types::ordered_float::OrderedFloat;
34
35#[derive(Debug, Copy, parse_display::Display, Clone, PartialEq, Hash, Eq, Ord, PartialOrd)]
36pub enum Decimal {
37    #[display("-Infinity")]
38    NegativeInf,
39    #[display("{0}")]
40    Normalized(RustDecimal),
41    #[display("Infinity")]
42    PositiveInf,
43    #[display("NaN")]
44    NaN,
45}
46
47impl ZeroHeapSize for Decimal {}
48
49impl ToText for Decimal {
50    fn write<W: std::fmt::Write>(&self, f: &mut W) -> std::fmt::Result {
51        write!(f, "{self}")
52    }
53
54    fn write_with_type<W: std::fmt::Write>(&self, ty: &DataType, f: &mut W) -> std::fmt::Result {
55        match ty {
56            DataType::Decimal => self.write(f),
57            _ => unreachable!(),
58        }
59    }
60}
61
62impl Decimal {
63    /// Used by `PrimitiveArray` to serialize the array to protobuf.
64    pub fn to_protobuf(self, output: &mut impl Write) -> ArrayResult<usize> {
65        let buf = self.unordered_serialize();
66        output.write_all(&buf)?;
67        Ok(buf.len())
68    }
69
70    /// Used by `DecimalValueReader` to deserialize the array from protobuf.
71    pub fn from_protobuf(input: &mut impl Read) -> ArrayResult<Self> {
72        let mut buf = [0u8; 16];
73        input.read_exact(&mut buf)?;
74        Ok(Self::unordered_deserialize(buf))
75    }
76
77    pub fn from_scientific(value: &str) -> Option<Self> {
78        let decimal = RustDecimal::from_scientific(value).ok()?;
79        Some(Normalized(decimal))
80    }
81
82    pub fn from_str_radix(s: &str, radix: u32) -> rust_decimal::Result<Self> {
83        match s.to_ascii_lowercase().as_str() {
84            "nan" => Ok(Decimal::NaN),
85            "inf" | "+inf" | "infinity" | "+infinity" => Ok(Decimal::PositiveInf),
86            "-inf" | "-infinity" => Ok(Decimal::NegativeInf),
87            s => RustDecimal::from_str_radix(s, radix).map(Decimal::Normalized),
88        }
89    }
90}
91
92impl ToSql for Decimal {
93    accepts!(NUMERIC);
94
95    to_sql_checked!();
96
97    fn to_sql(
98        &self,
99        ty: &Type,
100        out: &mut BytesMut,
101    ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
102    where
103        Self: Sized,
104    {
105        match self {
106            Decimal::Normalized(d) => {
107                return d.to_sql(ty, out);
108            }
109            Decimal::NaN => {
110                out.reserve(8);
111                out.put_u16(0);
112                out.put_i16(0);
113                out.put_u16(0xC000);
114                out.put_i16(0);
115            }
116            Decimal::PositiveInf => {
117                out.reserve(8);
118                out.put_u16(0);
119                out.put_i16(0);
120                out.put_u16(0xD000);
121                out.put_i16(0);
122            }
123            Decimal::NegativeInf => {
124                out.reserve(8);
125                out.put_u16(0);
126                out.put_i16(0);
127                out.put_u16(0xF000);
128                out.put_i16(0);
129            }
130        }
131        Ok(IsNull::No)
132    }
133}
134
135impl<'a> FromSql<'a> for Decimal {
136    fn from_sql(
137        ty: &Type,
138        raw: &'a [u8],
139    ) -> Result<Self, Box<dyn std::error::Error + 'static + Sync + Send>> {
140        let mut rdr = Cursor::new(raw);
141        let _n_digits = rdr.read_u16::<BigEndian>()?;
142        let _weight = rdr.read_i16::<BigEndian>()?;
143        let sign = rdr.read_u16::<BigEndian>()?;
144        match sign {
145            0xC000 => Ok(Self::NaN),
146            0xD000 => Ok(Self::PositiveInf),
147            0xF000 => Ok(Self::NegativeInf),
148            _ => RustDecimal::from_sql(ty, raw).map(Self::Normalized),
149        }
150    }
151
152    fn accepts(ty: &Type) -> bool {
153        matches!(*ty, Type::NUMERIC)
154    }
155}
156
// Implements conversions between `Decimal` and the integer type `$T`:
// - `From<$T>` always yields a finite value.
// - `TryFrom<Decimal>` first rounds to zero fractional digits (ties away from zero). `NaN` and
//   the infinities are rejected with `Error::ConversionTo`, and out-of-range values are rejected
//   by `rust_decimal`.
macro_rules! impl_convert_int {
    ($T:ty) => {
        impl core::convert::From<$T> for Decimal {
            #[inline]
            fn from(t: $T) -> Self {
                Self::Normalized(t.into())
            }
        }

        impl core::convert::TryFrom<Decimal> for $T {
            type Error = Error;

            #[inline]
            fn try_from(d: Decimal) -> Result<Self, Self::Error> {
                if let Decimal::Normalized(rounded) = d.round_dp_ties_away(0) {
                    return rounded.try_into();
                }
                Err(Error::ConversionTo(std::any::type_name::<$T>().into()))
            }
        }
    };
}
179
// Implements conversions between `Decimal` and the float type `$T` (plus `OrderedFloat<$T>`).
// Float NaN and the infinities map to the matching special `Decimal` values and back. Finite
// values go through `rust_decimal`'s fallible conversions.
macro_rules! impl_convert_float {
    ($T:ty) => {
        impl core::convert::TryFrom<$T> for Decimal {
            type Error = Error;

            fn try_from(num: $T) -> Result<Self, Self::Error> {
                if num.is_nan() {
                    Ok(Decimal::NaN)
                } else if num.is_infinite() {
                    Ok(if num.is_sign_positive() {
                        Decimal::PositiveInf
                    } else {
                        Decimal::NegativeInf
                    })
                } else {
                    num.try_into().map(Decimal::Normalized)
                }
            }
        }
        impl core::convert::TryFrom<OrderedFloat<$T>> for Decimal {
            type Error = Error;

            fn try_from(value: OrderedFloat<$T>) -> Result<Self, Self::Error> {
                Decimal::try_from(value.0)
            }
        }

        impl core::convert::TryFrom<Decimal> for $T {
            type Error = Error;

            fn try_from(d: Decimal) -> Result<Self, Self::Error> {
                match d {
                    Decimal::NaN => Ok(<$T>::NAN),
                    Decimal::NegativeInf => Ok(<$T>::NEG_INFINITY),
                    Decimal::PositiveInf => Ok(<$T>::INFINITY),
                    Decimal::Normalized(finite) => finite.try_into(),
                }
            }
        }
        impl core::convert::TryFrom<Decimal> for OrderedFloat<$T> {
            type Error = Error;

            fn try_from(d: Decimal) -> Result<Self, Self::Error> {
                <$T as core::convert::TryFrom<Decimal>>::try_from(d).map(Self)
            }
        }
    };
}
223
// Implements a `num_traits` checked-arithmetic trait (`$trait::$func`) for `Decimal`.
// When both operands are finite, `rust_decimal`'s checked operation is used, so its failures
// become `None`. When either operand is NaN or infinite, the result comes from the infallible
// operator `$op`, which implements the special-value semantics.
macro_rules! checked_proxy {
    ($trait:ty, $func:ident, $op: tt) => {
        impl $trait for Decimal {
            fn $func(&self, other: &Self) -> Option<Self> {
                if let (Self::Normalized(lhs), Self::Normalized(rhs)) = (self, other) {
                    return lhs.$func(rhs).map(Decimal::Normalized);
                }
                Some(*self $op *other)
            }
        }
    }
}
238
239impl_convert_float!(f32);
240impl_convert_float!(f64);
241
242impl_convert_int!(isize);
243impl_convert_int!(i8);
244impl_convert_int!(i16);
245impl_convert_int!(i32);
246impl_convert_int!(i64);
247impl_convert_int!(usize);
248impl_convert_int!(u8);
249impl_convert_int!(u16);
250impl_convert_int!(u32);
251impl_convert_int!(u64);
252
253checked_proxy!(CheckedRem, checked_rem, %);
254checked_proxy!(CheckedSub, checked_sub, -);
255checked_proxy!(CheckedAdd, checked_add, +);
256checked_proxy!(CheckedDiv, checked_div, /);
257checked_proxy!(CheckedMul, checked_mul, *);
258
259impl Add for Decimal {
260    type Output = Self;
261
262    fn add(self, other: Self) -> Self {
263        match (self, other) {
264            (Self::Normalized(lhs), Self::Normalized(rhs)) => Self::Normalized(lhs + rhs),
265            (Self::NaN, _) => Self::NaN,
266            (_, Self::NaN) => Self::NaN,
267            (Self::PositiveInf, Self::NegativeInf) => Self::NaN,
268            (Self::NegativeInf, Self::PositiveInf) => Self::NaN,
269            (Self::PositiveInf, _) => Self::PositiveInf,
270            (_, Self::PositiveInf) => Self::PositiveInf,
271            (Self::NegativeInf, _) => Self::NegativeInf,
272            (_, Self::NegativeInf) => Self::NegativeInf,
273        }
274    }
275}
276
277impl Neg for Decimal {
278    type Output = Self;
279
280    fn neg(self) -> Self {
281        match self {
282            Self::Normalized(d) => Self::Normalized(-d),
283            Self::NaN => Self::NaN,
284            Self::PositiveInf => Self::NegativeInf,
285            Self::NegativeInf => Self::PositiveInf,
286        }
287    }
288}
289
290impl CheckedNeg for Decimal {
291    fn checked_neg(&self) -> Option<Self> {
292        match self {
293            Self::Normalized(d) => Some(Self::Normalized(-d)),
294            Self::NaN => Some(Self::NaN),
295            Self::PositiveInf => Some(Self::NegativeInf),
296            Self::NegativeInf => Some(Self::PositiveInf),
297        }
298    }
299}
300
301impl Rem for Decimal {
302    type Output = Self;
303
304    fn rem(self, other: Self) -> Self {
305        match (self, other) {
306            (Self::Normalized(lhs), Self::Normalized(rhs)) if !rhs.is_zero() => {
307                Self::Normalized(lhs % rhs)
308            }
309            (Self::Normalized(_), Self::Normalized(_)) => Self::NaN,
310            (Self::Normalized(lhs), Self::PositiveInf)
311                if lhs.is_sign_positive() || lhs.is_zero() =>
312            {
313                Self::Normalized(lhs)
314            }
315            (Self::Normalized(d), Self::PositiveInf) => Self::Normalized(d),
316            (Self::Normalized(lhs), Self::NegativeInf)
317                if lhs.is_sign_negative() || lhs.is_zero() =>
318            {
319                Self::Normalized(lhs)
320            }
321            (Self::Normalized(d), Self::NegativeInf) => Self::Normalized(d),
322            _ => Self::NaN,
323        }
324    }
325}
326
327impl Div for Decimal {
328    type Output = Self;
329
330    fn div(self, other: Self) -> Self {
331        match (self, other) {
332            // nan
333            (Self::NaN, _) => Self::NaN,
334            (_, Self::NaN) => Self::NaN,
335            // div by zero
336            (lhs, Self::Normalized(rhs)) if rhs.is_zero() => match lhs {
337                Self::Normalized(lhs) => {
338                    if lhs.is_sign_positive() && !lhs.is_zero() {
339                        Self::PositiveInf
340                    } else if lhs.is_sign_negative() && !lhs.is_zero() {
341                        Self::NegativeInf
342                    } else {
343                        Self::NaN
344                    }
345                }
346                Self::PositiveInf => Self::PositiveInf,
347                Self::NegativeInf => Self::NegativeInf,
348                _ => unreachable!(),
349            },
350            // div by +/-inf
351            (Self::Normalized(_), Self::PositiveInf) => Self::Normalized(RustDecimal::from(0)),
352            (_, Self::PositiveInf) => Self::NaN,
353            (Self::Normalized(_), Self::NegativeInf) => Self::Normalized(RustDecimal::from(0)),
354            (_, Self::NegativeInf) => Self::NaN,
355            // div inf
356            (Self::PositiveInf, Self::Normalized(d)) if d.is_sign_positive() => Self::PositiveInf,
357            (Self::PositiveInf, Self::Normalized(d)) if d.is_sign_negative() => Self::NegativeInf,
358            (Self::NegativeInf, Self::Normalized(d)) if d.is_sign_positive() => Self::NegativeInf,
359            (Self::NegativeInf, Self::Normalized(d)) if d.is_sign_negative() => Self::PositiveInf,
360            // normal case
361            (Self::Normalized(lhs), Self::Normalized(rhs)) => Self::Normalized(lhs / rhs),
362            _ => unreachable!(),
363        }
364    }
365}
366
367impl Mul for Decimal {
368    type Output = Self;
369
370    fn mul(self, other: Self) -> Self {
371        match (self, other) {
372            (Self::Normalized(lhs), Self::Normalized(rhs)) => Self::Normalized(lhs * rhs),
373            (Self::NaN, _) => Self::NaN,
374            (_, Self::NaN) => Self::NaN,
375            (Self::PositiveInf, Self::Normalized(rhs))
376                if !rhs.is_zero() && rhs.is_sign_negative() =>
377            {
378                Self::NegativeInf
379            }
380            (Self::PositiveInf, Self::Normalized(rhs))
381                if !rhs.is_zero() && rhs.is_sign_positive() =>
382            {
383                Self::PositiveInf
384            }
385            (Self::PositiveInf, Self::PositiveInf) => Self::PositiveInf,
386            (Self::PositiveInf, Self::NegativeInf) => Self::NegativeInf,
387            (Self::Normalized(lhs), Self::PositiveInf)
388                if !lhs.is_zero() && lhs.is_sign_negative() =>
389            {
390                Self::NegativeInf
391            }
392            (Self::Normalized(lhs), Self::PositiveInf)
393                if !lhs.is_zero() && lhs.is_sign_positive() =>
394            {
395                Self::PositiveInf
396            }
397            (Self::NegativeInf, Self::PositiveInf) => Self::NegativeInf,
398            (Self::NegativeInf, Self::Normalized(rhs))
399                if !rhs.is_zero() && rhs.is_sign_negative() =>
400            {
401                Self::PositiveInf
402            }
403            (Self::NegativeInf, Self::Normalized(rhs))
404                if !rhs.is_zero() && rhs.is_sign_positive() =>
405            {
406                Self::NegativeInf
407            }
408            (Self::NegativeInf, Self::NegativeInf) => Self::PositiveInf,
409            (Self::Normalized(lhs), Self::NegativeInf)
410                if !lhs.is_zero() && lhs.is_sign_negative() =>
411            {
412                Self::PositiveInf
413            }
414            (Self::Normalized(lhs), Self::NegativeInf)
415                if !lhs.is_zero() && lhs.is_sign_positive() =>
416            {
417                Self::NegativeInf
418            }
419            // 0 * {inf, nan} => nan
420            _ => Self::NaN,
421        }
422    }
423}
424
425impl Sub for Decimal {
426    type Output = Self;
427
428    fn sub(self, other: Self) -> Self {
429        match (self, other) {
430            (Self::Normalized(lhs), Self::Normalized(rhs)) => Self::Normalized(lhs - rhs),
431            (Self::NaN, _) => Self::NaN,
432            (_, Self::NaN) => Self::NaN,
433            (Self::PositiveInf, Self::PositiveInf) => Self::NaN,
434            (Self::NegativeInf, Self::NegativeInf) => Self::NaN,
435            (Self::PositiveInf, _) => Self::PositiveInf,
436            (_, Self::PositiveInf) => Self::NegativeInf,
437            (Self::NegativeInf, _) => Self::NegativeInf,
438            (_, Self::NegativeInf) => Self::PositiveInf,
439        }
440    }
441}
442
443impl Decimal {
444    const MAX_I128_REPR: i128 = 0x0000_0000_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF;
445    pub const MAX_PRECISION: u8 = 28;
446
447    pub fn scale(&self) -> Option<i32> {
448        let Decimal::Normalized(d) = self else {
449            return None;
450        };
451        Some(d.scale() as _)
452    }
453
454    pub fn rescale(&mut self, scale: u32) {
455        if let Normalized(a) = self {
456            a.rescale(scale);
457        }
458    }
459
460    #[must_use]
461    pub fn round_dp_ties_away(&self, dp: u32) -> Self {
462        match self {
463            Self::Normalized(d) => {
464                let new_d = d.round_dp_with_strategy(dp, RoundingStrategy::MidpointAwayFromZero);
465                Self::Normalized(new_d)
466            }
467            d => *d,
468        }
469    }
470
471    /// Round to the left of the decimal point, for example `31.5` -> `30`.
472    #[must_use]
473    pub fn round_left_ties_away(&self, left: u32) -> Option<Self> {
474        let &Self::Normalized(mut d) = self else {
475            return Some(*self);
476        };
477
478        // First, move the decimal point to the left so that we can reuse `round`. This is more
479        // efficient than division.
480        let old_scale = d.scale();
481        let new_scale = old_scale.saturating_add(left);
482        const MANTISSA_UP: i128 = 5 * 10i128.pow(Decimal::MAX_PRECISION as _);
483        let d = match new_scale.cmp(&Self::MAX_PRECISION.add(1).into()) {
484            // trivial within 28 digits
485            std::cmp::Ordering::Less => {
486                d.set_scale(new_scale).unwrap();
487                d.round_dp_with_strategy(0, RoundingStrategy::MidpointAwayFromZero)
488            }
489            // Special case: scale cannot be 29, but it may or may not be >= 0.5e+29
490            std::cmp::Ordering::Equal => (d.mantissa() / MANTISSA_UP).signum().into(),
491            // always 0 for >= 30 digits
492            std::cmp::Ordering::Greater => 0.into(),
493        };
494
495        // Then multiply back. Note that we cannot move decimal point to the right in order to get
496        // more zeros.
497        match left > Decimal::MAX_PRECISION.into() {
498            true => d.is_zero().then(|| 0.into()),
499            false => d
500                .checked_mul(RustDecimal::from_i128_with_scale(10i128.pow(left), 0))
501                .map(Self::Normalized),
502        }
503    }
504
505    #[must_use]
506    pub fn ceil(&self) -> Self {
507        match self {
508            Self::Normalized(d) => {
509                let mut d = d.ceil();
510                if d.is_zero() {
511                    d.set_sign_positive(true);
512                }
513                Self::Normalized(d)
514            }
515            d => *d,
516        }
517    }
518
519    #[must_use]
520    pub fn floor(&self) -> Self {
521        match self {
522            Self::Normalized(d) => Self::Normalized(d.floor()),
523            d => *d,
524        }
525    }
526
527    #[must_use]
528    pub fn trunc(&self) -> Self {
529        match self {
530            Self::Normalized(d) => {
531                let mut d = d.trunc();
532                if d.is_zero() {
533                    d.set_sign_positive(true);
534                }
535                Self::Normalized(d)
536            }
537            d => *d,
538        }
539    }
540
541    #[must_use]
542    pub fn round_ties_even(&self) -> Self {
543        match self {
544            Self::Normalized(d) => Self::Normalized(d.round()),
545            d => *d,
546        }
547    }
548
549    pub fn from_i128_with_scale(num: i128, scale: u32) -> Self {
550        Decimal::Normalized(RustDecimal::from_i128_with_scale(num, scale))
551    }
552
553    /// Truncate the given `num` and `scale` to fit into `Decimal`, return `None` if it cannot be
554    /// represented even after truncation.
555    pub fn truncated_i128_and_scale(mut num: i128, mut scale: u32) -> Option<Self> {
556        if num.abs() > Self::MAX_I128_REPR {
557            let digits = num.abs().ilog10() + 1;
558            let diff_scale = digits.saturating_sub(Self::MAX_PRECISION as u32);
559            if scale < diff_scale {
560                return None;
561            }
562            num /= 10i128.pow(diff_scale);
563            scale -= diff_scale;
564        }
565        if scale > Self::MAX_PRECISION as u32 {
566            let diff_scale = scale - Self::MAX_PRECISION as u32;
567            num /= 10i128.pow(diff_scale);
568            scale = Self::MAX_PRECISION as u32;
569        }
570        Some(Decimal::Normalized(
571            RustDecimal::try_from_i128_with_scale(num, scale).ok()?,
572        ))
573    }
574
575    #[must_use]
576    pub fn normalize(&self) -> Self {
577        match self {
578            Self::Normalized(d) => Self::Normalized(d.normalize()),
579            d => *d,
580        }
581    }
582
583    pub fn unordered_serialize(&self) -> [u8; 16] {
584        // according to https://docs.rs/rust_decimal/1.18.0/src/rust_decimal/decimal.rs.html#665-684
585        // the lower 15 bits is not used, so we can use first byte to distinguish nan and inf
586        match self {
587            Self::Normalized(d) => d.serialize(),
588            Self::NaN => [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
589            Self::PositiveInf => [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
590            Self::NegativeInf => [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
591        }
592    }
593
594    pub fn unordered_deserialize(bytes: [u8; 16]) -> Self {
595        match bytes[0] {
596            0u8 => Self::Normalized(RustDecimal::deserialize(bytes)),
597            1u8 => Self::NaN,
598            2u8 => Self::PositiveInf,
599            3u8 => Self::NegativeInf,
600            _ => unreachable!(),
601        }
602    }
603
604    pub fn abs(&self) -> Self {
605        match self {
606            Self::Normalized(d) => {
607                if d.is_sign_negative() {
608                    Self::Normalized(-d)
609                } else {
610                    Self::Normalized(*d)
611                }
612            }
613            Self::NaN => Self::NaN,
614            Self::PositiveInf => Self::PositiveInf,
615            Self::NegativeInf => Self::PositiveInf,
616        }
617    }
618
619    pub fn sign(&self) -> Self {
620        match self {
621            Self::NaN => Self::NaN,
622            _ => match self.cmp(&0.into()) {
623                std::cmp::Ordering::Less => (-1).into(),
624                std::cmp::Ordering::Equal => 0.into(),
625                std::cmp::Ordering::Greater => 1.into(),
626            },
627        }
628    }
629
630    pub fn checked_exp(&self) -> Option<Decimal> {
631        match self {
632            Self::Normalized(d) => d.checked_exp().map(Self::Normalized),
633            Self::NaN => Some(Self::NaN),
634            Self::PositiveInf => Some(Self::PositiveInf),
635            Self::NegativeInf => Some(Self::zero()),
636        }
637    }
638
639    pub fn checked_ln(&self) -> Option<Decimal> {
640        match self {
641            Self::Normalized(d) => d.checked_ln().map(Self::Normalized),
642            Self::NaN => Some(Self::NaN),
643            Self::PositiveInf => Some(Self::PositiveInf),
644            Self::NegativeInf => None,
645        }
646    }
647
648    pub fn checked_log10(&self) -> Option<Decimal> {
649        match self {
650            Self::Normalized(d) => d.checked_log10().map(Self::Normalized),
651            Self::NaN => Some(Self::NaN),
652            Self::PositiveInf => Some(Self::PositiveInf),
653            Self::NegativeInf => None,
654        }
655    }
656
657    pub fn checked_powd(&self, rhs: &Self) -> Result<Self, PowError> {
658        use std::cmp::Ordering;
659
660        match (self, rhs) {
661            // A. Handle `nan`, where `1 ^ nan == 1` and `nan ^ 0 == 1`
662            (Decimal::NaN, Decimal::NaN)
663            | (Decimal::PositiveInf, Decimal::NaN)
664            | (Decimal::NegativeInf, Decimal::NaN)
665            | (Decimal::NaN, Decimal::PositiveInf)
666            | (Decimal::NaN, Decimal::NegativeInf) => Ok(Self::NaN),
667            (Normalized(lhs), Decimal::NaN) => match lhs.is_one() {
668                true => Ok(1.into()),
669                false => Ok(Self::NaN),
670            },
671            (Decimal::NaN, Normalized(rhs)) => match rhs.is_zero() {
672                true => Ok(1.into()),
673                false => Ok(Self::NaN),
674            },
675
676            // B. Handle `b ^ inf`
677            (Normalized(lhs), Decimal::PositiveInf) => match lhs.abs().cmp(&1.into()) {
678                Ordering::Greater => Ok(Self::PositiveInf),
679                Ordering::Equal => Ok(1.into()),
680                Ordering::Less => Ok(0.into()),
681            },
682            // Simply special case of `abs(b) > 1`.
683            // Also consistent with `inf ^ p` and `-inf ^ p` below where p is not fractional or odd.
684            (Decimal::PositiveInf, Decimal::PositiveInf)
685            | (Decimal::NegativeInf, Decimal::PositiveInf) => Ok(Self::PositiveInf),
686
687            // C. Handle `b ^ -inf`, which is `(1/b) ^ inf`
688            (Normalized(lhs), Decimal::NegativeInf) => match lhs.abs().cmp(&1.into()) {
689                Ordering::Greater => Ok(0.into()),
690                Ordering::Equal => Ok(1.into()),
691                Ordering::Less => match lhs.is_zero() {
692                    // Fun fact: ISO 9899 is removing this error to follow IEEE 754 2008.
693                    true => Err(PowError::ZeroNegative),
694                    false => Ok(Self::PositiveInf),
695                },
696            },
697            (Decimal::PositiveInf, Decimal::NegativeInf)
698            | (Decimal::NegativeInf, Decimal::NegativeInf) => Ok(0.into()),
699
700            // D. Handle `inf ^ p`
701            (Decimal::PositiveInf, Normalized(rhs)) => match rhs.cmp(&0.into()) {
702                Ordering::Greater => Ok(Self::PositiveInf),
703                Ordering::Equal => Ok(1.into()),
704                Ordering::Less => Ok(0.into()),
705            },
706
707            // E. Handle `-inf ^ p`. Finite `p` can be fractional, odd, or even.
708            (Decimal::NegativeInf, Normalized(rhs)) => match !rhs.fract().is_zero() {
709                // Err in PostgreSQL. No err in ISO 9899 which treats fractional as non-odd below.
710                true => Err(PowError::NegativeFract),
711                false => match (rhs.cmp(&0.into()), rhs.rem(&2.into()).abs().is_one()) {
712                    (Ordering::Greater, true) => Ok(Self::NegativeInf),
713                    (Ordering::Greater, false) => Ok(Self::PositiveInf),
714                    (Ordering::Equal, true) => unreachable!(),
715                    (Ordering::Equal, false) => Ok(1.into()),
716                    (Ordering::Less, true) => Ok(0.into()), // no `-0` in PostgreSQL decimal
717                    (Ordering::Less, false) => Ok(0.into()),
718                },
719            },
720
721            // F. Finite numbers
722            (Normalized(lhs), Normalized(rhs)) => {
723                if lhs.is_zero() && rhs < &0.into() {
724                    return Err(PowError::ZeroNegative);
725                }
726                if lhs < &0.into() && !rhs.fract().is_zero() {
727                    return Err(PowError::NegativeFract);
728                }
729                match lhs.checked_powd(*rhs) {
730                    Some(d) => Ok(Self::Normalized(d)),
731                    None => Err(PowError::Overflow),
732                }
733            }
734        }
735    }
736}
737
/// Error cases reported by `Decimal::checked_powd`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PowError {
    /// Zero raised to a negative power (including `-Infinity`).
    ZeroNegative,
    /// A negative base (including `-Infinity`) raised to a non-integer power.
    NegativeFract,
    /// The finite computation failed in `rust_decimal` (reported as overflow).
    Overflow,
}
743
744impl From<Decimal> for memcomparable::Decimal {
745    fn from(d: Decimal) -> Self {
746        match d {
747            Decimal::Normalized(d) => Self::Normalized(d),
748            Decimal::PositiveInf => Self::Inf,
749            Decimal::NegativeInf => Self::NegInf,
750            Decimal::NaN => Self::NaN,
751        }
752    }
753}
754
755impl From<memcomparable::Decimal> for Decimal {
756    fn from(d: memcomparable::Decimal) -> Self {
757        match d {
758            memcomparable::Decimal::Normalized(d) => Self::Normalized(d),
759            memcomparable::Decimal::Inf => Self::PositiveInf,
760            memcomparable::Decimal::NegInf => Self::NegativeInf,
761            memcomparable::Decimal::NaN => Self::NaN,
762        }
763    }
764}
765
766impl Default for Decimal {
767    fn default() -> Self {
768        Self::Normalized(RustDecimal::default())
769    }
770}
771
772impl FromStr for Decimal {
773    type Err = Error;
774
775    fn from_str(s: &str) -> Result<Self, Self::Err> {
776        match s.to_ascii_lowercase().as_str() {
777            "nan" => Ok(Decimal::NaN),
778            "inf" | "+inf" | "infinity" | "+infinity" => Ok(Decimal::PositiveInf),
779            "-inf" | "-infinity" => Ok(Decimal::NegativeInf),
780            s => RustDecimal::from_str(s)
781                .or_else(|_| RustDecimal::from_scientific(s))
782                .map(Decimal::Normalized),
783        }
784    }
785}
786
787impl Zero for Decimal {
788    fn zero() -> Self {
789        Self::Normalized(RustDecimal::zero())
790    }
791
792    fn is_zero(&self) -> bool {
793        if let Self::Normalized(d) = self {
794            d.is_zero()
795        } else {
796            false
797        }
798    }
799}
800
801impl One for Decimal {
802    fn one() -> Self {
803        Self::Normalized(RustDecimal::one())
804    }
805}
806
807impl Num for Decimal {
808    type FromStrRadixErr = Error;
809
810    fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
811        if str.eq_ignore_ascii_case("inf") || str.eq_ignore_ascii_case("infinity") {
812            Ok(Self::PositiveInf)
813        } else if str.eq_ignore_ascii_case("-inf") || str.eq_ignore_ascii_case("-infinity") {
814            Ok(Self::NegativeInf)
815        } else if str.eq_ignore_ascii_case("nan") {
816            Ok(Self::NaN)
817        } else {
818            RustDecimal::from_str_radix(str, radix).map(Decimal::Normalized)
819        }
820    }
821}
822
823impl From<RustDecimal> for Decimal {
824    fn from(d: RustDecimal) -> Self {
825        Self::Normalized(d)
826    }
827}
828
#[cfg(test)]
mod tests {
    use itertools::Itertools as _;
    use risingwave_common_estimate_size::EstimateSize;

    use super::*;

    /// Returns `true` when two `f32` results should be considered the same:
    /// both NaN, or equal under IEEE comparison (which already treats
    /// same-signed infinities as equal and opposite-signed ones as different).
    fn check(lhs: f32, rhs: f32) -> bool {
        (lhs.is_nan() && rhs.is_nan()) || lhs == rhs
    }

    #[test]
    fn check_op_with_float() {
        // Each decimal is paired with the f32 whose arithmetic it must mirror.
        let cases = [
            (Decimal::NaN, f32::NAN),
            (Decimal::PositiveInf, f32::INFINITY),
            (Decimal::NegativeInf, f32::NEG_INFINITY),
            (Decimal::try_from(1.0).unwrap(), 1.0f32),
            (Decimal::try_from(-1.0).unwrap(), -1.0f32),
            (Decimal::try_from(0.0).unwrap(), 0.0f32),
        ];
        for &(d_lhs, f_lhs) in &cases {
            for &(d_rhs, f_rhs) in &cases {
                assert!(check((d_lhs + d_rhs).try_into().unwrap(), f_lhs + f_rhs));
                assert!(check((d_lhs - d_rhs).try_into().unwrap(), f_lhs - f_rhs));
                assert!(check((d_lhs * d_rhs).try_into().unwrap(), f_lhs * f_rhs));
                assert!(check((d_lhs / d_rhs).try_into().unwrap(), f_lhs / f_rhs));
                assert!(check((d_lhs % d_rhs).try_into().unwrap(), f_lhs % f_rhs));
            }
        }
    }

    #[test]
    fn basic_test() {
        // Special values are parsed case-insensitively.
        for text in ["nan", "NaN", "NAN", "nAn", "nAN", "Nan", "NAn"] {
            assert_eq!(Decimal::from_str(text).unwrap(), Decimal::NaN);
        }
        for text in [
            "inf", "INF", "iNF", "inF", "InF", "INf", "+inf", "+INF", "+Inf", "+iNF", "+inF",
            "+InF", "+INf", "inFINity", "+infiNIty",
        ] {
            assert_eq!(Decimal::from_str(text).unwrap(), Decimal::PositiveInf);
        }
        for text in [
            "-inf", "-INF", "-Inf", "-iNF", "-inF", "-InF", "-INf", "-INfinity",
        ] {
            assert_eq!(Decimal::from_str(text).unwrap(), Decimal::NegativeInf);
        }

        assert_eq!(
            Decimal::try_from(10.0).unwrap() / Decimal::PositiveInf,
            Decimal::try_from(0.0).unwrap()
        );
        assert_eq!(
            Decimal::try_from(f32::INFINITY).unwrap(),
            Decimal::PositiveInf
        );
        assert_eq!(Decimal::try_from(f64::NAN).unwrap(), Decimal::NaN);
        assert_eq!(
            Decimal::try_from(f64::INFINITY).unwrap(),
            Decimal::PositiveInf
        );

        // Every value must survive an unordered serialize/deserialize round trip.
        let round_trip_cases = [
            Decimal::try_from(1.234).unwrap(),
            Decimal::from(1u8),
            Decimal::from(1i8),
            Decimal::from(1u16),
            Decimal::from(1i16),
            Decimal::from(1u32),
            Decimal::from(1i32),
            Decimal::try_from(f64::NAN).unwrap(),
            Decimal::try_from(f64::INFINITY).unwrap(),
        ];
        for value in round_trip_cases {
            assert_eq!(
                Decimal::unordered_deserialize(value.unordered_serialize()),
                value
            );
        }

        // Integer conversions back out of `Decimal`.
        assert_eq!(u8::try_from(Decimal::from(1u8)).unwrap(), 1);
        assert_eq!(i8::try_from(Decimal::from(1i8)).unwrap(), 1);
        assert_eq!(u16::try_from(Decimal::from(1u16)).unwrap(), 1);
        assert_eq!(i16::try_from(Decimal::from(1i16)).unwrap(), 1);
        assert_eq!(u32::try_from(Decimal::from(1u32)).unwrap(), 1);
        assert_eq!(i32::try_from(Decimal::from(1i32)).unwrap(), 1);
        assert_eq!(u64::try_from(Decimal::from(1u64)).unwrap(), 1);
        assert_eq!(i64::try_from(Decimal::from(1i64)).unwrap(), 1);
    }

    #[test]
    fn test_order() {
        // Listed in strictly ascending order; both the native ordering and the
        // memcomparable encoding must agree on each adjacent pair.
        let ordered = ["-inf", "-1", "0.00", "0.5", "2", "10", "inf", "nan"]
            .iter()
            .map(|text| Decimal::from_str(text).unwrap())
            .collect_vec();
        for pair in ordered.windows(2) {
            let (smaller, larger) = (pair[0], pair[1]);
            assert!(smaller < larger);
            assert!(memcomparable::Decimal::from(smaller) < memcomparable::Decimal::from(larger));
        }
    }

    #[test]
    fn test_decimal_estimate_size() {
        for decimal in [
            Decimal::NegativeInf,
            Decimal::Normalized(RustDecimal::try_from(1.0).unwrap()),
        ] {
            assert_eq!(decimal.estimated_size(), 20);
        }
    }
}