risingwave_expr_impl/scalar/bitwise_op.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::type_name;
16use std::fmt::Debug;
17use std::ops::{BitAnd, BitOr, BitXor, Not};
18
19use num_traits::{CheckedShl, CheckedShr};
20use risingwave_expr::{ExprError, Result, function};
21
22// Conscious decision for shl and shr is made here to diverge from PostgreSQL.
23// If overflow happens, instead of truncated to zero, we return overflow error as this is
24// undefined behaviour. If the RHS is negative, instead of having an unexpected answer, we return an
25// error. If PG had clearly defined behavior rather than relying on UB of C, we would follow it even
26// when it is different from rust std.
27#[function("bitwise_shift_left(int2, int2) -> int2")]
28#[function("bitwise_shift_left(int2, int4) -> int2")]
29#[function("bitwise_shift_left(int4, int2) -> int4")]
30#[function("bitwise_shift_left(int4, int4) -> int4")]
31#[function("bitwise_shift_left(int8, int2) -> int8")]
32#[function("bitwise_shift_left(int8, int4) -> int8")]
33pub fn general_shl<T1, T2>(l: T1, r: T2) -> Result<T1>
34where
35    T1: CheckedShl + Debug,
36    T2: TryInto<u32> + Debug,
37{
38    general_shift(l, r, |a, b| {
39        a.checked_shl(b).ok_or(ExprError::NumericOutOfRange)
40    })
41}
42
43#[function("bitwise_shift_right(int2, int2) -> int2")]
44#[function("bitwise_shift_right(int2, int4) -> int2")]
45#[function("bitwise_shift_right(int4, int2) -> int4")]
46#[function("bitwise_shift_right(int4, int4) -> int4")]
47#[function("bitwise_shift_right(int8, int2) -> int8")]
48#[function("bitwise_shift_right(int8, int4) -> int8")]
49pub fn general_shr<T1, T2>(l: T1, r: T2) -> Result<T1>
50where
51    T1: CheckedShr + Debug,
52    T2: TryInto<u32> + Debug,
53{
54    general_shift(l, r, |a, b| {
55        a.checked_shr(b).ok_or(ExprError::NumericOutOfRange)
56    })
57}
58
59#[inline(always)]
60fn general_shift<T1, T2, F>(l: T1, r: T2, atm: F) -> Result<T1>
61where
62    T1: Debug,
63    T2: TryInto<u32> + Debug,
64    F: FnOnce(T1, u32) -> Result<T1>,
65{
66    // TODO: We need to improve the error message
67    let r: u32 = r
68        .try_into()
69        .map_err(|_| ExprError::CastOutOfRange(type_name::<u32>()))?;
70    atm(l, r)
71}
72
73#[function("bitwise_and(*int, *int) -> auto")]
74pub fn general_bitand<T1, T2, T3>(l: T1, r: T2) -> T3
75where
76    T1: Into<T3> + Debug,
77    T2: Into<T3> + Debug,
78    T3: BitAnd<Output = T3>,
79{
80    l.into() & r.into()
81}
82
83#[function("bitwise_or(*int, *int) -> auto")]
84pub fn general_bitor<T1, T2, T3>(l: T1, r: T2) -> T3
85where
86    T1: Into<T3> + Debug,
87    T2: Into<T3> + Debug,
88    T3: BitOr<Output = T3>,
89{
90    l.into() | r.into()
91}
92
93#[function("bitwise_xor(*int, *int) -> auto")]
94pub fn general_bitxor<T1, T2, T3>(l: T1, r: T2) -> T3
95where
96    T1: Into<T3> + Debug,
97    T2: Into<T3> + Debug,
98    T3: BitXor<Output = T3>,
99{
100    l.into() ^ r.into()
101}
102
103#[function("bitwise_not(*int) -> auto")]
104pub fn general_bitnot<T1: Not<Output = T1>>(expr: T1) -> T1 {
105    !expr
106}
107
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use super::*;

    #[test]
    fn test_bitwise() {
        // check the boundary
        assert_eq!(general_shl::<i32, i32>(1i32, 0i32).unwrap(), 1i32);
        assert_eq!(general_shl::<i64, i32>(1i64, 31i32).unwrap(), 2147483648i64);
        assert_matches!(
            general_shl::<i32, i32>(1i32, 32i32).unwrap_err(),
            ExprError::NumericOutOfRange,
        );
        assert_eq!(
            general_shr::<i64, i32>(-2147483648i64, 31i32).unwrap(),
            -1i64
        );
        assert_eq!(general_shr::<i64, i32>(1i64, 0i32).unwrap(), 1i64);
        // truth table
        assert_eq!(
            general_bitand::<u32, u32, u64>(0b0011u32, 0b0101u32),
            0b1u64
        );
        assert_eq!(
            general_bitor::<u32, u32, u64>(0b0011u32, 0b0101u32),
            0b0111u64
        );
        assert_eq!(
            general_bitxor::<u32, u32, u64>(0b0011u32, 0b0101u32),
            0b0110u64
        );
        assert_eq!(general_bitnot::<i32>(0b01i32), -2i32);
    }

    /// The documented divergences from PostgreSQL: invalid shift amounts are errors.
    #[test]
    fn test_shift_rejects_invalid_amount() {
        // A negative shift amount cannot be converted to `u32` and is rejected up front.
        assert_matches!(
            general_shl::<i32, i32>(1i32, -1i32).unwrap_err(),
            ExprError::CastOutOfRange(_),
        );
        assert_matches!(
            general_shr::<i64, i16>(1i64, -1i16).unwrap_err(),
            ExprError::CastOutOfRange(_),
        );
        // Shifting by at least the bit width of the left operand is an error, for both directions.
        assert_matches!(
            general_shl::<i16, i16>(1i16, 16i16).unwrap_err(),
            ExprError::NumericOutOfRange,
        );
        assert_matches!(
            general_shr::<i32, i32>(1i32, 32i32).unwrap_err(),
            ExprError::NumericOutOfRange,
        );
        assert_matches!(
            general_shr::<i64, i32>(1i64, 64i32).unwrap_err(),
            ExprError::NumericOutOfRange,
        );
    }

    #[test]
    fn test_shift_int2() {
        // Only the amount is checked: shifting into the sign bit is allowed.
        assert_eq!(general_shl::<i16, i16>(1i16, 15i16).unwrap(), i16::MIN);
        // Right shift of a signed value is arithmetic.
        assert_eq!(general_shr::<i16, i32>(-8i16, 2i32).unwrap(), -2i16);
    }

    #[test]
    fn test_bitwise_mixed_width() {
        // Narrower signed operands are sign-extended into the wider result type.
        assert_eq!(general_bitand::<i16, i32, i32>(-1i16, 0x0F0Fi32), 0x0F0Fi32);
        assert_eq!(general_bitor::<i16, i64, i64>(-1i16, 0i64), -1i64);
        assert_eq!(general_bitxor::<i32, i64, i64>(-1i32, 1i64), -2i64);
        assert_eq!(general_bitnot::<i64>(0i64), -1i64);
        assert_eq!(general_bitnot::<i16>(i16::MAX), i16::MIN);
    }
}