risingwave_expr_impl/scalar/
bytea_bits.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_expr::{ExprError, Result, function};
16
17/// Extracts n'th bit from binary string.
18///
19/// # Example
20///
21/// ```slt
22/// query T
23/// SELECT get_bit('\x1234567890'::bytea, 30);
24/// ----
25/// 1
26/// ```
27#[function("get_bit(bytea, int8) -> int4")]
28pub fn get_bit(bytes: &[u8], n: i64) -> Result<i32> {
29    let max_sz = (bytes.len() * 8) as i64;
30    if n < 0 || n >= max_sz {
31        return Err(ExprError::InvalidParam {
32            name: "get_bit",
33            reason: format!("index {} out of valid range, 0..{}", n, max_sz - 1).into(),
34        });
35    }
36    let index = n / 8;
37    let byte = bytes[index as usize];
38    Ok(((byte >> (n % 8)) & 1) as i32)
39}
40
41/// Sets n'th bit in binary string to newvalue.
42///
43/// # Example
44///
45/// ```slt
46/// query T
47/// SELECT set_bit('\x1234567890'::bytea, 30, 0);
48/// ----
49/// \x1234563890
50/// ```
51#[function("set_bit(bytea, int8, int4) -> bytea")]
52pub fn set_bit(bytes: &[u8], n: i64, value: i32) -> Result<Box<[u8]>> {
53    let max_sz = (bytes.len() * 8) as i64;
54    if n < 0 || n >= max_sz {
55        return Err(ExprError::InvalidParam {
56            name: "set_bit",
57            reason: format!("index {} out of valid range, 0..{}", n, max_sz - 1).into(),
58        });
59    }
60
61    if value != 0 && value != 1 {
62        return Err(ExprError::InvalidParam {
63            name: "set_bit",
64            reason: format!("value {} is invalid, new bit must be 0 or 1", value).into(),
65        });
66    }
67
68    let mut buf = bytes.to_vec();
69    let index = (n / 8) as usize;
70    let bit_pos = (n % 8) as u8;
71
72    if value != 0 {
73        buf[index] |= 1 << bit_pos;
74    } else {
75        buf[index] &= !(1 << bit_pos);
76    }
77    Ok(buf.into_boxed_slice())
78}
79
80/// Extracts n'th byte from binary string.
81///
82/// # Example
83///
84/// ```slt
85/// query T
86/// SELECT get_byte('\x1234567890'::bytea, 4);
87/// ----
88/// 144
89/// ```
90#[function("get_byte(bytea, int4) -> int4")]
91pub fn get_byte(bytes: &[u8], n: i32) -> Result<i32> {
92    let max_sz = bytes.len() as i32;
93    if n < 0 || n >= max_sz {
94        return Err(ExprError::InvalidParam {
95            name: "get_byte",
96            reason: format!("index {} out of valid range, 0..{}", n, max_sz - 1).into(),
97        });
98    }
99    Ok(bytes[n as usize].into())
100}
101
102/// Sets n'th byte in binary string to newvalue.
103///
104/// # Example
105///
106/// ```slt
107/// query T
108/// SELECT set_byte('\x1234567890'::bytea, 4, 64);
109/// ----
110/// \x1234567840
111/// ```
112#[function("set_byte(bytea, int4, int4) -> bytea")]
113pub fn set_byte(bytes: &[u8], n: i32, value: i32) -> Result<Box<[u8]>> {
114    let max_sz = bytes.len() as i32;
115    if n < 0 || n >= max_sz {
116        return Err(ExprError::InvalidParam {
117            name: "set_byte",
118            reason: format!("index {} out of valid range, 0..{}", n, max_sz - 1).into(),
119        });
120    }
121    let mut buf = bytes.to_vec();
122    buf[n as usize] = value as u8;
123    Ok(buf.into_boxed_slice())
124}
125
126/// Returns the number of bits set in the binary string
127///
128/// # Example
129///
130/// ```slt
131/// query T
132/// SELECT bit_count('\x1234567890'::bytea);
133/// ----
134/// 15
135/// ```
136#[function("bit_count(bytea) -> int8")]
137pub fn bit_count(bytes: &[u8]) -> i64 {
138    let mut ans = 0;
139    for byte in bytes {
140        ans += byte.count_ones();
141    }
142    ans.into()
143}