1use risingwave_common::cast::{parse_bytes_hex, parse_bytes_traditional};
16use risingwave_expr::{ExprError, Result, function};
17use thiserror_ext::AsReport;
18
/// Error message reported when base64 input stops in the middle of a 4-symbol group.
const PARSE_BASE64_INVALID_END: &str = "invalid base64 end sequence";
/// Error message reported when a `=` padding symbol shows up at a position where it is not
/// accepted.
const PARSE_BASE64_INVALID_PADDING: &str = "unexpected \"=\" while decoding base64 sequence";
/// Base64 alphabet: the symbol for the 6-bit value `i` is `PARSE_BASE64_ALPHABET[i]`.
const PARSE_BASE64_ALPHABET: &[u8] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// Bytes that the base64 decoder skips over: CR, LF, space and horizontal tab.
const PARSE_BASE64_IGNORE_BYTES: [u8; 4] = [0x0D, 0x0A, 0x20, 0x09];
/// Reverse lookup for [`PARSE_BASE64_ALPHABET`], indexed by the input byte value
/// (`0x00..=0x7A`, i.e. up to and including `b'z'`).
///
/// An entry holds the 6-bit value of the symbol, or `0x7f` when the byte is not part of the
/// alphabet (this includes `=`).
const PARSE_BASE64_ALPHABET_DECODE_TABLE: [u8; 123] = [
    // 0x00..=0x0F
    0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f,
    // 0x10..=0x1F
    0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f,
    // 0x20..=0x2F: `+` (0x2B) and `/` (0x2F)
    0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x3E, 0x7f, 0x7f, 0x7f, 0x3F,
    // 0x30..=0x3F: digits `0`..=`9`
    0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f,
    // 0x40..=0x4F: `A`..=`O`
    0x7f, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E,
    // 0x50..=0x5F: `P`..=`Z`
    0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f,
    // 0x60..=0x6F: `a`..=`o`
    0x7f, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28,
    // 0x70..=0x7A: `p`..=`z`
    0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33,
];
35
36#[function("encode(bytea, varchar) -> varchar")]
37pub fn encode(data: &[u8], format: &str, writer: &mut impl std::fmt::Write) -> Result<()> {
38 match format {
39 "base64" => {
40 encode_bytes_base64(data, writer)?;
41 }
42 "hex" => {
43 writer.write_str(&hex::encode(data)).unwrap();
44 }
45 "escape" => {
46 encode_bytes_escape(data, writer).unwrap();
47 }
48 _ => {
49 return Err(ExprError::InvalidParam {
50 name: "format",
51 reason: format!("unrecognized encoding: \"{}\"", format).into(),
52 });
53 }
54 }
55 Ok(())
56}
57
58#[function("decode(varchar, varchar) -> bytea")]
59pub fn decode(data: &str, format: &str, writer: &mut impl std::io::Write) -> Result<()> {
60 match format {
61 "base64" => parse_bytes_base64(data, writer),
62 "hex" => parse_bytes_hex(data, writer).map_err(|err| ExprError::Parse(err.into())),
63 "escape" => {
64 parse_bytes_traditional(data, writer).map_err(|err| ExprError::Parse(err.into()))
65 }
66 _ => Err(ExprError::InvalidParam {
67 name: "format",
68 reason: format!("unrecognized encoding: \"{}\"", format).into(),
69 }),
70 }
71}
72
/// Character sets understood by the `convert_from` / `convert_to` functions.
enum CharacterSet {
    /// UTF-8, currently the only supported character set.
    Utf8,
}
76
77impl CharacterSet {
78 fn recognize(encoding: &str) -> Result<Self> {
79 match encoding.to_uppercase().as_str() {
80 "UTF8" | "UTF-8" => Ok(Self::Utf8),
81 _ => Err(ExprError::InvalidParam {
82 name: "encoding",
83 reason: format!("unrecognized encoding: \"{}\"", encoding).into(),
84 }),
85 }
86 }
87}
88
89#[function("convert_from(bytea, varchar) -> varchar")]
90pub fn convert_from(
91 data: &[u8],
92 src_encoding: &str,
93 writer: &mut impl std::fmt::Write,
94) -> Result<()> {
95 match CharacterSet::recognize(src_encoding)? {
96 CharacterSet::Utf8 => {
97 let text = String::from_utf8(data.to_vec()).map_err(|e| ExprError::InvalidParam {
98 name: "data",
99 reason: e.to_report_string().into(),
100 })?;
101 writer.write_str(&text).unwrap();
102 Ok(())
103 }
104 }
105}
106
107#[function("convert_to(varchar, varchar) -> bytea")]
108pub fn convert_to(
109 string: &str,
110 dest_encoding: &str,
111 writer: &mut impl std::io::Write,
112) -> Result<()> {
113 match CharacterSet::recognize(dest_encoding)? {
114 CharacterSet::Utf8 => {
115 writer.write_all(string.as_bytes()).unwrap();
116 Ok(())
117 }
118 }
119}
120
121fn encode_bytes_base64(data: &[u8], writer: &mut impl std::fmt::Write) -> Result<()> {
124 let mut idx: usize = 0;
125 let len = data.len();
126 let mut written = 0;
127 while idx + 2 < len {
128 let i1 = (data[idx] >> 2) & 0b00111111;
129 let i2 = ((data[idx] & 0b00000011) << 4) | ((data[idx + 1] >> 4) & 0b00001111);
130 let i3 = ((data[idx + 1] & 0b00001111) << 2) | ((data[idx + 2] >> 6) & 0b00000011);
131 let i4 = data[idx + 2] & 0b00111111;
132 writer
133 .write_char(PARSE_BASE64_ALPHABET[usize::from(i1)].into())
134 .unwrap();
135 writer
136 .write_char(PARSE_BASE64_ALPHABET[usize::from(i2)].into())
137 .unwrap();
138 writer
139 .write_char(PARSE_BASE64_ALPHABET[usize::from(i3)].into())
140 .unwrap();
141 writer
142 .write_char(PARSE_BASE64_ALPHABET[usize::from(i4)].into())
143 .unwrap();
144
145 written += 4;
146 idx += 3;
147 if written % 76 == 0 {
148 writer.write_char('\n').unwrap();
149 }
150 }
151
152 if idx + 2 == len {
153 let i1 = (data[idx] >> 2) & 0b00111111;
154 let i2 = ((data[idx] & 0b00000011) << 4) | ((data[idx + 1] >> 4) & 0b00001111);
155 let i3 = (data[idx + 1] & 0b00001111) << 2;
156 writer
157 .write_char(PARSE_BASE64_ALPHABET[usize::from(i1)].into())
158 .unwrap();
159 writer
160 .write_char(PARSE_BASE64_ALPHABET[usize::from(i2)].into())
161 .unwrap();
162 writer
163 .write_char(PARSE_BASE64_ALPHABET[usize::from(i3)].into())
164 .unwrap();
165 writer.write_char('=').unwrap();
166 } else if idx + 1 == len {
167 let i1 = (data[idx] >> 2) & 0b00111111;
168 let i2 = (data[idx] & 0b00000011) << 4;
169 writer
170 .write_char(PARSE_BASE64_ALPHABET[usize::from(i1)].into())
171 .unwrap();
172 writer
173 .write_char(PARSE_BASE64_ALPHABET[usize::from(i2)].into())
174 .unwrap();
175 writer.write_char('=').unwrap();
176 writer.write_char('=').unwrap();
177 }
178 Ok(())
179}
180
/// Decodes base64 text in `data` and writes the decoded bytes to `writer`.
///
/// The input is consumed four symbols at a time via `next`, which skips the bytes listed in
/// `PARSE_BASE64_IGNORE_BYTES`. Each symbol is validated with `alphabet_decode`.
///
/// # Errors
/// * `ExprError::Parse` with an "invalid symbol" message from `alphabet_decode`.
/// * `ExprError::Parse(PARSE_BASE64_INVALID_PADDING)` when `=` appears in a position other
///   than the last one or two symbols of a group.
/// * `ExprError::Parse(PARSE_BASE64_INVALID_END)` when the input ends inside a group.
///
/// # Panics
/// Panics if writing to `writer` fails.
fn parse_bytes_base64(data: &str, writer: &mut impl std::io::Write) -> Result<()> {
    let data_bytes = data.as_bytes();

    let mut idx: usize = 0;
    while idx < data.len() {
        // All four symbols are fetched before matching; the arm order below is significant
        // because earlier arms shadow later, more general ones.
        match (
            next(&mut idx, data_bytes),
            next(&mut idx, data_bytes),
            next(&mut idx, data_bytes),
            next(&mut idx, data_bytes),
        ) {
            // Only ignorable bytes were left: decoding is complete.
            (None, None, None, None) => return Ok(()),
            // `xx==`: two symbols carry one byte.
            (Some(d1), Some(d2), Some(b'='), Some(b'=')) => {
                let s1 = alphabet_decode(d1)?;
                let s2 = alphabet_decode(d2)?;
                writer.write_all(&[s1 << 2 | s2 >> 4]).unwrap();
            }
            // `xxx=`: three symbols carry two bytes.
            (Some(d1), Some(d2), Some(d3), Some(b'=')) => {
                let s1 = alphabet_decode(d1)?;
                let s2 = alphabet_decode(d2)?;
                let s3 = alphabet_decode(d3)?;
                writer
                    .write_all(&[s1 << 2 | s2 >> 4, s2 << 4 | s3 >> 2])
                    .unwrap();
            }
            // Padding as the first symbol of a group.
            (Some(b'='), _, _, _) => {
                return Err(ExprError::Parse(PARSE_BASE64_INVALID_PADDING.into()));
            }
            // Padding as the second symbol; the first symbol is validated first so that an
            // invalid-symbol error takes precedence.
            (Some(d1), Some(b'='), _, _) => {
                alphabet_decode(d1)?;
                return Err(ExprError::Parse(PARSE_BASE64_INVALID_PADDING.into()));
            }
            // Padding as the third symbol without a matching `=` as the fourth.
            (Some(d1), Some(d2), Some(b'='), _) => {
                alphabet_decode(d1)?;
                alphabet_decode(d2)?;
                return Err(ExprError::Parse(PARSE_BASE64_INVALID_PADDING.into()));
            }
            // Full group: four symbols carry three bytes.
            (Some(d1), Some(d2), Some(d3), Some(d4)) => {
                let s1 = alphabet_decode(d1)?;
                let s2 = alphabet_decode(d2)?;
                let s3 = alphabet_decode(d3)?;
                let s4 = alphabet_decode(d4)?;
                writer
                    .write_all(&[s1 << 2 | s2 >> 4, s2 << 4 | s3 >> 2, s3 << 6 | s4])
                    .unwrap();
            }
            // Input ended after one, two or three symbols of a group; the symbols seen so
            // far are validated first so that an invalid-symbol error takes precedence.
            (Some(d1), None, None, None) => {
                alphabet_decode(d1)?;
                return Err(ExprError::Parse(PARSE_BASE64_INVALID_END.into()));
            }
            (Some(d1), Some(d2), None, None) => {
                alphabet_decode(d1)?;
                alphabet_decode(d2)?;
                return Err(ExprError::Parse(PARSE_BASE64_INVALID_END.into()));
            }
            (Some(d1), Some(d2), Some(d3), None) => {
                alphabet_decode(d1)?;
                alphabet_decode(d2)?;
                alphabet_decode(d3)?;
                return Err(ExprError::Parse(PARSE_BASE64_INVALID_END.into()));
            }
            // Any remaining combination is reported as a truncated sequence.
            _ => {
                return Err(ExprError::Parse(PARSE_BASE64_INVALID_END.into()));
            }
        }
    }
    Ok(())
}
252
253#[inline]
254fn alphabet_decode(d: u8) -> Result<u8> {
255 if d > 0x7A {
256 Err(ExprError::Parse(
257 format!(
258 "invalid symbol \"{}\" while decoding base64 sequence",
259 std::char::from_u32(d as u32).unwrap()
260 )
261 .into(),
262 ))
263 } else {
264 let p = PARSE_BASE64_ALPHABET_DECODE_TABLE[d as usize];
265 if p == 0x7f {
266 Err(ExprError::Parse(
267 format!(
268 "invalid symbol \"{}\" while decoding base64 sequence",
269 std::char::from_u32(d as u32).unwrap()
270 )
271 .into(),
272 ))
273 } else {
274 Ok(p)
275 }
276 }
277}
278
279#[inline]
280fn next(idx: &mut usize, data: &[u8]) -> Option<u8> {
281 while *idx < data.len() && PARSE_BASE64_IGNORE_BYTES.contains(&data[*idx]) {
282 *idx += 1;
283 }
284 if *idx < data.len() {
285 let d1 = data[*idx];
286 *idx += 1;
287 Some(d1)
288 } else {
289 None
290 }
291}
292
/// Writes `data` to `writer` in the `escape` text format.
///
/// * A zero byte and every byte with the high bit set (`0x80..=0xff`) become a backslash
///   followed by exactly three octal digits (e.g. `\000`, `\377`).
/// * A backslash is doubled.
/// * Every other byte is written as the character with the same code point.
///
/// # Errors
/// Returns the `std::fmt::Error` reported by `writer`; every write, including the octal
/// escape, is propagated with `?` rather than panicking.
fn encode_bytes_escape(data: &[u8], writer: &mut impl std::fmt::Write) -> std::fmt::Result {
    for &byte in data {
        match byte {
            b'\0' | 0x80..=0xff => write!(writer, "\\{:03o}", byte)?,
            b'\\' => writer.write_str("\\\\")?,
            _ => writer.write_char(char::from(byte))?,
        }
    }
    Ok(())
}
308
#[cfg(test)]
mod tests {
    use super::{decode, encode};

    /// Encodes each sample, compares the text with the expected encoding, then decodes the
    /// text again and checks that the original bytes come back.
    #[test]
    fn test_encdec() {
        // (plain bytes, format name, expected encoded text as bytes)
        let samples: [(&[u8], &str, &[u8]); 5] = [
            (r#"ABCDE"#.as_bytes(), "base64", r#"QUJDREU="#.as_bytes()),
            (r#"\""#.as_bytes(), "escape", r#"\\""#.as_bytes()),
            (b"\x00\x40\x41\x42\xff", "escape", r"\000@AB\377".as_bytes()),
            (
                "aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeefffffff".as_bytes(),
                "base64",
                "YWFhYWFhYWFhYWJiYmJiYmJiYmJjY2NjY2NjY2NjZGRkZGRkZGRkZGVlZWVlZWVlZWVmZmZmZmZm\n"
                    .as_bytes(),
            ),
            (
                "aabbccddee".as_bytes(),
                "hex",
                "61616262636364646565".as_bytes(),
            ),
        ];

        for (plain, format, expected_text) in samples {
            let mut text = String::new();
            assert!(encode(plain, format, &mut text).is_ok());
            println!("{}", text);
            assert_eq!(text.as_bytes(), expected_text);

            let mut round_trip = Vec::new();
            decode(text.as_str(), format, &mut round_trip).unwrap();
            assert_eq!(plain, round_trip.as_slice());
        }
    }
}
342}