risingwave_expr_impl/scalar/
bytea_bits.rs1use risingwave_expr::{ExprError, Result, function};
16
17#[function("get_bit(bytea, int8) -> int4")]
28pub fn get_bit(bytes: &[u8], n: i64) -> Result<i32> {
29 let max_sz = (bytes.len() * 8) as i64;
30 if n < 0 || n >= max_sz {
31 return Err(ExprError::InvalidParam {
32 name: "get_bit",
33 reason: format!("index {} out of valid range, 0..{}", n, max_sz - 1).into(),
34 });
35 }
36 let index = n / 8;
37 let byte = bytes[index as usize];
38 Ok(((byte >> (n % 8)) & 1) as i32)
39}
40
41#[function("set_bit(bytea, int8, int4) -> bytea")]
52pub fn set_bit(bytes: &[u8], n: i64, value: i32, writer: &mut impl std::io::Write) -> Result<()> {
53 let max_sz = (bytes.len() * 8) as i64;
54 if n < 0 || n >= max_sz {
55 return Err(ExprError::InvalidParam {
56 name: "set_bit",
57 reason: format!("index {} out of valid range, 0..{}", n, max_sz - 1).into(),
58 });
59 }
60
61 if value != 0 && value != 1 {
62 return Err(ExprError::InvalidParam {
63 name: "set_bit",
64 reason: format!("value {} is invalid, new bit must be 0 or 1", value).into(),
65 });
66 }
67
68 let index = (n / 8) as usize;
69 let bit_pos = (n % 8) as u32;
70 let orig = bytes[index];
71 let mask = 1u8 << bit_pos;
72 let new_byte = if value != 0 {
73 orig | mask
74 } else {
75 orig & !mask
76 };
77
78 if index > 0 {
79 writer.write_all(&bytes[..index]).unwrap();
80 }
81 writer.write_all(&[new_byte]).unwrap();
82 if index + 1 < bytes.len() {
83 writer.write_all(&bytes[index + 1..]).unwrap();
84 }
85 Ok(())
86}
87
88#[function("get_byte(bytea, int4) -> int4")]
99pub fn get_byte(bytes: &[u8], n: i32) -> Result<i32> {
100 let max_sz = bytes.len() as i32;
101 if n < 0 || n >= max_sz {
102 return Err(ExprError::InvalidParam {
103 name: "get_byte",
104 reason: format!("index {} out of valid range, 0..{}", n, max_sz - 1).into(),
105 });
106 }
107 Ok(bytes[n as usize].into())
108}
109
110#[function("set_byte(bytea, int4, int4) -> bytea")]
121pub fn set_byte(bytes: &[u8], n: i32, value: i32, writer: &mut impl std::io::Write) -> Result<()> {
122 let max_sz = bytes.len() as i32;
123 if n < 0 || n >= max_sz {
124 return Err(ExprError::InvalidParam {
125 name: "set_byte",
126 reason: format!("index {} out of valid range, 0..{}", n, max_sz - 1).into(),
127 });
128 }
129
130 let index = n as usize;
131 if index > 0 {
132 writer.write_all(&bytes[..index]).unwrap();
133 }
134 writer.write_all(&[value as u8]).unwrap();
135 if index + 1 < bytes.len() {
136 writer.write_all(&bytes[index + 1..]).unwrap();
137 }
138 Ok(())
139}
140
141#[function("bit_count(bytea) -> int8")]
152pub fn bit_count(bytes: &[u8]) -> i64 {
153 let mut ans = 0;
154 for byte in bytes {
155 ans += byte.count_ones();
156 }
157 ans.into()
158}
159
160#[function("reverse(bytea) -> bytea")]
171pub fn reverse_bytea(bytes: &[u8], writer: &mut impl std::io::Write) {
172 for byte in bytes.iter().rev() {
173 writer.write_all(&[*byte]).unwrap();
174 }
175}