1use std::fmt::Write;
16
17use risingwave_common::cast::{parse_bytes_hex, parse_bytes_traditional};
18use risingwave_expr::{ExprError, Result, function};
19use thiserror_ext::AsReport;
20
/// Error text used when base64 input stops in the middle of a four-symbol group.
const PARSE_BASE64_INVALID_END: &str = "invalid base64 end sequence";
/// Error text used when a `=` padding symbol shows up at a position where it is not accepted.
const PARSE_BASE64_INVALID_PADDING: &str = "unexpected \"=\" while decoding base64 sequence";
/// The 64 base64 symbols; a symbol's position in this slice is its six-bit value.
const PARSE_BASE64_ALPHABET: &[u8] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// Bytes skipped by the base64 decoder: carriage return, line feed, space and horizontal tab.
const PARSE_BASE64_IGNORE_BYTES: [u8; 4] = *b"\r\n \t";
/// Reverse lookup from an input byte (`0x00..=0x7A`) to its six-bit value.
///
/// Bytes that are not part of [`PARSE_BASE64_ALPHABET`] map to the sentinel `0x7f`.
/// The table is derived from the alphabet at compile time so the two cannot drift apart.
const PARSE_BASE64_ALPHABET_DECODE_TABLE: [u8; 123] = {
    let mut table = [0x7f_u8; 123];
    let mut value = 0;
    // `while` instead of `for`: iterators are not usable in const evaluation.
    while value < PARSE_BASE64_ALPHABET.len() {
        // `value` is below 64, so narrowing it to `u8` is lossless.
        table[PARSE_BASE64_ALPHABET[value] as usize] = value as u8;
        value += 1;
    }
    table
};
37
38#[function("encode(bytea, varchar) -> varchar")]
39pub fn encode(data: &[u8], format: &str, writer: &mut impl Write) -> Result<()> {
40 match format {
41 "base64" => {
42 encode_bytes_base64(data, writer)?;
43 }
44 "hex" => {
45 writer.write_str(&hex::encode(data)).unwrap();
46 }
47 "escape" => {
48 encode_bytes_escape(data, writer).unwrap();
49 }
50 _ => {
51 return Err(ExprError::InvalidParam {
52 name: "format",
53 reason: format!("unrecognized encoding: \"{}\"", format).into(),
54 });
55 }
56 }
57 Ok(())
58}
59
60#[function("decode(varchar, varchar) -> bytea")]
61pub fn decode(data: &str, format: &str) -> Result<Box<[u8]>> {
62 match format {
63 "base64" => Ok(parse_bytes_base64(data)?.into()),
64 "hex" => Ok(parse_bytes_hex(data)
65 .map_err(|err| ExprError::Parse(err.into()))?
66 .into()),
67 "escape" => Ok(parse_bytes_traditional(data)
68 .map_err(|err| ExprError::Parse(err.into()))?
69 .into()),
70 _ => Err(ExprError::InvalidParam {
71 name: "format",
72 reason: format!("unrecognized encoding: \"{}\"", format).into(),
73 }),
74 }
75}
76
77enum CharacterSet {
78 Utf8,
79}
80
81impl CharacterSet {
82 fn recognize(encoding: &str) -> Result<Self> {
83 match encoding.to_uppercase().as_str() {
84 "UTF8" | "UTF-8" => Ok(Self::Utf8),
85 _ => Err(ExprError::InvalidParam {
86 name: "encoding",
87 reason: format!("unrecognized encoding: \"{}\"", encoding).into(),
88 }),
89 }
90 }
91}
92
93#[function("convert_from(bytea, varchar) -> varchar")]
94pub fn convert_from(data: &[u8], src_encoding: &str, writer: &mut impl Write) -> Result<()> {
95 match CharacterSet::recognize(src_encoding)? {
96 CharacterSet::Utf8 => {
97 let text = String::from_utf8(data.to_vec()).map_err(|e| ExprError::InvalidParam {
98 name: "data",
99 reason: e.to_report_string().into(),
100 })?;
101 writer.write_str(&text).unwrap();
102 Ok(())
103 }
104 }
105}
106
107#[function("convert_to(varchar, varchar) -> bytea")]
108pub fn convert_to(string: &str, dest_encoding: &str) -> Result<Box<[u8]>> {
109 match CharacterSet::recognize(dest_encoding)? {
110 CharacterSet::Utf8 => Ok(string.as_bytes().into()),
111 }
112}
113
114fn encode_bytes_base64(data: &[u8], writer: &mut impl Write) -> Result<()> {
117 let mut idx: usize = 0;
118 let len = data.len();
119 let mut written = 0;
120 while idx + 2 < len {
121 let i1 = (data[idx] >> 2) & 0b00111111;
122 let i2 = ((data[idx] & 0b00000011) << 4) | ((data[idx + 1] >> 4) & 0b00001111);
123 let i3 = ((data[idx + 1] & 0b00001111) << 2) | ((data[idx + 2] >> 6) & 0b00000011);
124 let i4 = data[idx + 2] & 0b00111111;
125 writer
126 .write_char(PARSE_BASE64_ALPHABET[usize::from(i1)].into())
127 .unwrap();
128 writer
129 .write_char(PARSE_BASE64_ALPHABET[usize::from(i2)].into())
130 .unwrap();
131 writer
132 .write_char(PARSE_BASE64_ALPHABET[usize::from(i3)].into())
133 .unwrap();
134 writer
135 .write_char(PARSE_BASE64_ALPHABET[usize::from(i4)].into())
136 .unwrap();
137
138 written += 4;
139 idx += 3;
140 if written % 76 == 0 {
141 writer.write_char('\n').unwrap();
142 }
143 }
144
145 if idx + 2 == len {
146 let i1 = (data[idx] >> 2) & 0b00111111;
147 let i2 = ((data[idx] & 0b00000011) << 4) | ((data[idx + 1] >> 4) & 0b00001111);
148 let i3 = (data[idx + 1] & 0b00001111) << 2;
149 writer
150 .write_char(PARSE_BASE64_ALPHABET[usize::from(i1)].into())
151 .unwrap();
152 writer
153 .write_char(PARSE_BASE64_ALPHABET[usize::from(i2)].into())
154 .unwrap();
155 writer
156 .write_char(PARSE_BASE64_ALPHABET[usize::from(i3)].into())
157 .unwrap();
158 writer.write_char('=').unwrap();
159 } else if idx + 1 == len {
160 let i1 = (data[idx] >> 2) & 0b00111111;
161 let i2 = (data[idx] & 0b00000011) << 4;
162 writer
163 .write_char(PARSE_BASE64_ALPHABET[usize::from(i1)].into())
164 .unwrap();
165 writer
166 .write_char(PARSE_BASE64_ALPHABET[usize::from(i2)].into())
167 .unwrap();
168 writer.write_char('=').unwrap();
169 writer.write_char('=').unwrap();
170 }
171 Ok(())
172}
173
174fn parse_bytes_base64(data: &str) -> Result<Vec<u8>> {
178 let mut out = Vec::new();
179 let data_bytes = data.as_bytes();
180
181 let mut idx: usize = 0;
182 while idx < data.len() {
183 match (
184 next(&mut idx, data_bytes),
185 next(&mut idx, data_bytes),
186 next(&mut idx, data_bytes),
187 next(&mut idx, data_bytes),
188 ) {
189 (None, None, None, None) => return Ok(out),
190 (Some(d1), Some(d2), Some(b'='), Some(b'=')) => {
191 let s1 = alphabet_decode(d1)?;
192 let s2 = alphabet_decode(d2)?;
193 out.push(s1 << 2 | s2 >> 4);
194 }
195 (Some(d1), Some(d2), Some(d3), Some(b'=')) => {
196 let s1 = alphabet_decode(d1)?;
197 let s2 = alphabet_decode(d2)?;
198 let s3 = alphabet_decode(d3)?;
199 out.push(s1 << 2 | s2 >> 4);
200 out.push(s2 << 4 | s3 >> 2);
201 }
202 (Some(b'='), _, _, _) => {
203 return Err(ExprError::Parse(PARSE_BASE64_INVALID_PADDING.into()));
204 }
205 (Some(d1), Some(b'='), _, _) => {
206 alphabet_decode(d1)?;
207 return Err(ExprError::Parse(PARSE_BASE64_INVALID_PADDING.into()));
208 }
209 (Some(d1), Some(d2), Some(b'='), _) => {
210 alphabet_decode(d1)?;
211 alphabet_decode(d2)?;
212 return Err(ExprError::Parse(PARSE_BASE64_INVALID_PADDING.into()));
213 }
214 (Some(d1), Some(d2), Some(d3), Some(d4)) => {
215 let s1 = alphabet_decode(d1)?;
216 let s2 = alphabet_decode(d2)?;
217 let s3 = alphabet_decode(d3)?;
218 let s4 = alphabet_decode(d4)?;
219 out.push(s1 << 2 | s2 >> 4);
220 out.push(s2 << 4 | s3 >> 2);
221 out.push(s3 << 6 | s4);
222 }
223 (Some(d1), None, None, None) => {
224 alphabet_decode(d1)?;
225 return Err(ExprError::Parse(PARSE_BASE64_INVALID_END.into()));
226 }
227 (Some(d1), Some(d2), None, None) => {
228 alphabet_decode(d1)?;
229 alphabet_decode(d2)?;
230 return Err(ExprError::Parse(PARSE_BASE64_INVALID_END.into()));
231 }
232 (Some(d1), Some(d2), Some(d3), None) => {
233 alphabet_decode(d1)?;
234 alphabet_decode(d2)?;
235 alphabet_decode(d3)?;
236 return Err(ExprError::Parse(PARSE_BASE64_INVALID_END.into()));
237 }
238 _ => {
239 return Err(ExprError::Parse(PARSE_BASE64_INVALID_END.into()));
240 }
241 }
242 }
243 Ok(out)
244}
245
246#[inline]
247fn alphabet_decode(d: u8) -> Result<u8> {
248 if d > 0x7A {
249 Err(ExprError::Parse(
250 format!(
251 "invalid symbol \"{}\" while decoding base64 sequence",
252 std::char::from_u32(d as u32).unwrap()
253 )
254 .into(),
255 ))
256 } else {
257 let p = PARSE_BASE64_ALPHABET_DECODE_TABLE[d as usize];
258 if p == 0x7f {
259 Err(ExprError::Parse(
260 format!(
261 "invalid symbol \"{}\" while decoding base64 sequence",
262 std::char::from_u32(d as u32).unwrap()
263 )
264 .into(),
265 ))
266 } else {
267 Ok(p)
268 }
269 }
270}
271
272#[inline]
273fn next(idx: &mut usize, data: &[u8]) -> Option<u8> {
274 while *idx < data.len() && PARSE_BASE64_IGNORE_BYTES.contains(&data[*idx]) {
275 *idx += 1;
276 }
277 if *idx < data.len() {
278 let d1 = data[*idx];
279 *idx += 1;
280 Some(d1)
281 } else {
282 None
283 }
284}
285
/// Writes `data` to `writer` in the `escape` text format.
///
/// A zero byte and any byte with the high bit set are written as a backslash followed by
/// three octal digits, a backslash is doubled, and every other byte is written as the
/// character with the same code.
///
/// # Errors
///
/// Returns the writer's [`std::fmt::Error`] as soon as any write fails. This includes the
/// octal-escape branch, which propagates the error like the other branches instead of
/// panicking.
fn encode_bytes_escape(data: &[u8], writer: &mut impl Write) -> std::fmt::Result {
    for &byte in data {
        match byte {
            b'\0' | 0x80..=0xff => write!(writer, "\\{:03o}", byte)?,
            b'\\' => writer.write_str("\\\\")?,
            _ => writer.write_char(char::from(byte))?,
        }
    }
    Ok(())
}
301
#[cfg(test)]
mod tests {
    use super::{decode, encode};

    /// Encodes `raw` with `format`, checks the text against `expected`, then decodes the
    /// text again and checks that the original bytes come back.
    fn assert_round_trip(raw: &[u8], format: &str, expected: &str) {
        let mut encoded = String::new();
        assert!(encode(raw, format, &mut encoded).is_ok());
        println!("{}", encoded);
        assert_eq!(encoded.as_bytes(), expected.as_bytes());
        let decoded = decode(encoded.as_str(), format).unwrap();
        assert_eq!(raw, decoded.as_ref());
    }

    #[test]
    fn test_encdec() {
        // Padded final group.
        assert_round_trip(r#"ABCDE"#.as_bytes(), "base64", r#"QUJDREU="#);
        // Backslash is doubled, the quote is written literally.
        assert_round_trip(r#"\""#.as_bytes(), "escape", r#"\\""#);
        // Zero byte and high-bit byte become octal escapes.
        assert_round_trip(b"\x00\x40\x41\x42\xff", "escape", r"\000@AB\377");
        // 57 input bytes fill exactly one 76-symbol line, which is followed by a line feed.
        assert_round_trip(
            "aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeefffffff".as_bytes(),
            "base64",
            "YWFhYWFhYWFhYWJiYmJiYmJiYmJjY2NjY2NjY2NjZGRkZGRkZGRkZGVlZWVlZWVlZWVmZmZmZmZm\n",
        );
        assert_round_trip("aabbccddee".as_bytes(), "hex", "61616262636364646565");
    }
}