risingwave_expr_impl/scalar/
bitwise_op.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use std::any::type_name;
15use std::fmt::Debug;
16use std::ops::{BitAnd, BitOr, BitXor, Not};
17
18use num_traits::{CheckedShl, CheckedShr};
19use risingwave_expr::{ExprError, Result, function};
20
// A conscious decision is made here for shl and shr to diverge from PostgreSQL.
// If the shift overflows (the shift amount is rejected by `checked_shl`/`checked_shr`, e.g. shifting
// an int4 by 32), we return an overflow error instead of truncating to zero, because PostgreSQL's
// result in that case comes from undefined behaviour in C. If the RHS is negative, we return an
// error instead of producing an unexpected answer. If PG had clearly defined behavior rather than
// relying on UB in C, we would follow it even when it differs from Rust std.
26#[function("bitwise_shift_left(int2, int2) -> int2")]
27#[function("bitwise_shift_left(int2, int4) -> int2")]
28#[function("bitwise_shift_left(int4, int2) -> int4")]
29#[function("bitwise_shift_left(int4, int4) -> int4")]
30#[function("bitwise_shift_left(int8, int2) -> int8")]
31#[function("bitwise_shift_left(int8, int4) -> int8")]
32pub fn general_shl<T1, T2>(l: T1, r: T2) -> Result<T1>
33where
34    T1: CheckedShl + Debug,
35    T2: TryInto<u32> + Debug,
36{
37    general_shift(l, r, |a, b| {
38        a.checked_shl(b).ok_or(ExprError::NumericOutOfRange)
39    })
40}
41
42#[function("bitwise_shift_right(int2, int2) -> int2")]
43#[function("bitwise_shift_right(int2, int4) -> int2")]
44#[function("bitwise_shift_right(int4, int2) -> int4")]
45#[function("bitwise_shift_right(int4, int4) -> int4")]
46#[function("bitwise_shift_right(int8, int2) -> int8")]
47#[function("bitwise_shift_right(int8, int4) -> int8")]
48pub fn general_shr<T1, T2>(l: T1, r: T2) -> Result<T1>
49where
50    T1: CheckedShr + Debug,
51    T2: TryInto<u32> + Debug,
52{
53    general_shift(l, r, |a, b| {
54        a.checked_shr(b).ok_or(ExprError::NumericOutOfRange)
55    })
56}
57
58#[inline(always)]
59fn general_shift<T1, T2, F>(l: T1, r: T2, atm: F) -> Result<T1>
60where
61    T1: Debug,
62    T2: TryInto<u32> + Debug,
63    F: FnOnce(T1, u32) -> Result<T1>,
64{
65    // TODO: We need to improve the error message
66    let r: u32 = r
67        .try_into()
68        .map_err(|_| ExprError::CastOutOfRange(type_name::<u32>()))?;
69    atm(l, r)
70}
71
72#[function("bitwise_and(*int, *int) -> auto")]
73pub fn general_bitand<T1, T2, T3>(l: T1, r: T2) -> T3
74where
75    T1: Into<T3> + Debug,
76    T2: Into<T3> + Debug,
77    T3: BitAnd<Output = T3>,
78{
79    l.into() & r.into()
80}
81
82#[function("bitwise_or(*int, *int) -> auto")]
83pub fn general_bitor<T1, T2, T3>(l: T1, r: T2) -> T3
84where
85    T1: Into<T3> + Debug,
86    T2: Into<T3> + Debug,
87    T3: BitOr<Output = T3>,
88{
89    l.into() | r.into()
90}
91
92#[function("bitwise_xor(*int, *int) -> auto")]
93pub fn general_bitxor<T1, T2, T3>(l: T1, r: T2) -> T3
94where
95    T1: Into<T3> + Debug,
96    T2: Into<T3> + Debug,
97    T3: BitXor<Output = T3>,
98{
99    l.into() ^ r.into()
100}
101
102#[function("bitwise_not(*int) -> auto")]
103pub fn general_bitnot<T1: Not<Output = T1>>(expr: T1) -> T1 {
104    !expr
105}
106
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use super::*;

    #[test]
    fn test_bitwise() {
        // check the boundary
        assert_eq!(general_shl::<i32, i32>(1i32, 0i32).unwrap(), 1i32);
        assert_eq!(general_shl::<i64, i32>(1i64, 31i32).unwrap(), 2147483648i64);
        assert_matches!(
            general_shl::<i32, i32>(1i32, 32i32).unwrap_err(),
            ExprError::NumericOutOfRange,
        );
        assert_eq!(
            general_shr::<i64, i32>(-2147483648i64, 31i32).unwrap(),
            -1i64
        );
        assert_eq!(general_shr::<i64, i32>(1i64, 0i32).unwrap(), 1i64);
        // truth table
        assert_eq!(
            general_bitand::<u32, u32, u64>(0b0011u32, 0b0101u32),
            0b1u64
        );
        assert_eq!(
            general_bitor::<u32, u32, u64>(0b0011u32, 0b0101u32),
            0b0111u64
        );
        assert_eq!(
            general_bitxor::<u32, u32, u64>(0b0011u32, 0b0101u32),
            0b0110u64
        );
        assert_eq!(general_bitnot::<i32>(0b01i32), -2i32);
    }

    #[test]
    fn test_shift_edge_cases() {
        // A shift amount equal to the bit width is rejected for right shifts as well.
        assert_matches!(
            general_shr::<i32, i32>(1i32, 32i32).unwrap_err(),
            ExprError::NumericOutOfRange,
        );
        // Same boundary for int2.
        assert_matches!(
            general_shl::<i16, i16>(1i16, 16i16).unwrap_err(),
            ExprError::NumericOutOfRange,
        );
        // Only the shift amount is checked: shifting a bit into the sign position is not an error.
        assert_eq!(general_shl::<i16, i32>(1i16, 15i32).unwrap(), i16::MIN);
        // A negative shift amount cannot be converted to `u32` and is reported as an error.
        assert_matches!(
            general_shl::<i32, i32>(1i32, -1i32).unwrap_err(),
            ExprError::CastOutOfRange(_),
        );
        assert_matches!(
            general_shr::<i64, i16>(1i64, -1i16).unwrap_err(),
            ExprError::CastOutOfRange(_),
        );
    }

    #[test]
    fn test_bitwise_mixed_width() {
        // The narrower signed operand is sign-extended into the wider result type.
        assert_eq!(
            general_bitand::<i16, i32, i32>(-1i16, 0x1234i32),
            0x1234i32
        );
        assert_eq!(general_bitor::<i16, i64, i64>(-1i16, 0i64), -1i64);
        assert_eq!(general_bitxor::<i32, i64, i64>(-1i32, -1i64), 0i64);
    }
}