risingwave_expr_impl/scalar/
bytea_bits.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_expr::{ExprError, Result, function};
16
17/// Extracts n'th bit from binary string.
18///
19/// # Example
20///
21/// ```slt
22/// query T
23/// SELECT get_bit('\x1234567890'::bytea, 30);
24/// ----
25/// 1
26/// ```
27#[function("get_bit(bytea, int8) -> int4")]
28pub fn get_bit(bytes: &[u8], n: i64) -> Result<i32> {
29    let max_sz = (bytes.len() * 8) as i64;
30    if n < 0 || n >= max_sz {
31        return Err(ExprError::InvalidParam {
32            name: "get_bit",
33            reason: format!("index {} out of valid range, 0..{}", n, max_sz - 1).into(),
34        });
35    }
36    let index = n / 8;
37    let byte = bytes[index as usize];
38    Ok(((byte >> (n % 8)) & 1) as i32)
39}
40
41/// Sets n'th bit in binary string to newvalue.
42///
43/// # Example
44///
45/// ```slt
46/// query T
47/// SELECT set_bit('\x1234567890'::bytea, 30, 0);
48/// ----
49/// \x1234563890
50/// ```
51#[function("set_bit(bytea, int8, int4) -> bytea")]
52pub fn set_bit(bytes: &[u8], n: i64, value: i32, writer: &mut impl std::io::Write) -> Result<()> {
53    let max_sz = (bytes.len() * 8) as i64;
54    if n < 0 || n >= max_sz {
55        return Err(ExprError::InvalidParam {
56            name: "set_bit",
57            reason: format!("index {} out of valid range, 0..{}", n, max_sz - 1).into(),
58        });
59    }
60
61    if value != 0 && value != 1 {
62        return Err(ExprError::InvalidParam {
63            name: "set_bit",
64            reason: format!("value {} is invalid, new bit must be 0 or 1", value).into(),
65        });
66    }
67
68    let index = (n / 8) as usize;
69    let bit_pos = (n % 8) as u32;
70    let orig = bytes[index];
71    let mask = 1u8 << bit_pos;
72    let new_byte = if value != 0 {
73        orig | mask
74    } else {
75        orig & !mask
76    };
77
78    if index > 0 {
79        writer.write_all(&bytes[..index]).unwrap();
80    }
81    writer.write_all(&[new_byte]).unwrap();
82    if index + 1 < bytes.len() {
83        writer.write_all(&bytes[index + 1..]).unwrap();
84    }
85    Ok(())
86}
87
88/// Extracts n'th byte from binary string.
89///
90/// # Example
91///
92/// ```slt
93/// query T
94/// SELECT get_byte('\x1234567890'::bytea, 4);
95/// ----
96/// 144
97/// ```
98#[function("get_byte(bytea, int4) -> int4")]
99pub fn get_byte(bytes: &[u8], n: i32) -> Result<i32> {
100    let max_sz = bytes.len() as i32;
101    if n < 0 || n >= max_sz {
102        return Err(ExprError::InvalidParam {
103            name: "get_byte",
104            reason: format!("index {} out of valid range, 0..{}", n, max_sz - 1).into(),
105        });
106    }
107    Ok(bytes[n as usize].into())
108}
109
110/// Sets n'th byte in binary string to newvalue.
111///
112/// # Example
113///
114/// ```slt
115/// query T
116/// SELECT set_byte('\x1234567890'::bytea, 4, 64);
117/// ----
118/// \x1234567840
119/// ```
120#[function("set_byte(bytea, int4, int4) -> bytea")]
121pub fn set_byte(bytes: &[u8], n: i32, value: i32, writer: &mut impl std::io::Write) -> Result<()> {
122    let max_sz = bytes.len() as i32;
123    if n < 0 || n >= max_sz {
124        return Err(ExprError::InvalidParam {
125            name: "set_byte",
126            reason: format!("index {} out of valid range, 0..{}", n, max_sz - 1).into(),
127        });
128    }
129
130    let index = n as usize;
131    if index > 0 {
132        writer.write_all(&bytes[..index]).unwrap();
133    }
134    writer.write_all(&[value as u8]).unwrap();
135    if index + 1 < bytes.len() {
136        writer.write_all(&bytes[index + 1..]).unwrap();
137    }
138    Ok(())
139}
140
141/// Returns the number of bits set in the binary string
142///
143/// # Example
144///
145/// ```slt
146/// query T
147/// SELECT bit_count('\x1234567890'::bytea);
148/// ----
149/// 15
150/// ```
151#[function("bit_count(bytea) -> int8")]
152pub fn bit_count(bytes: &[u8]) -> i64 {
153    let mut ans = 0;
154    for byte in bytes {
155        ans += byte.count_ones();
156    }
157    ans.into()
158}
159
160/// Reverses the bytes in the binary string.
161///
162/// # Example
163///
164/// ```slt
165/// query T
166/// SELECT reverse('\x1234567890'::bytea);
167/// ----
168/// \x9078563412
169/// ```
170#[function("reverse(bytea) -> bytea")]
171pub fn reverse_bytea(bytes: &[u8], writer: &mut impl std::io::Write) {
172    for byte in bytes.iter().rev() {
173        writer.write_all(&[*byte]).unwrap();
174    }
175}