risingwave_expr_impl/scalar/
encdec.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::cast::{parse_bytes_hex, parse_bytes_traditional};
16use risingwave_expr::{ExprError, Result, function};
17use thiserror_ext::AsReport;
18
/// Error text reported when base64 input stops in the middle of a 4-symbol group.
const PARSE_BASE64_INVALID_END: &str = "invalid base64 end sequence";
/// Error text reported when a `=` padding symbol shows up where it is not allowed.
const PARSE_BASE64_INVALID_PADDING: &str = "unexpected \"=\" while decoding base64 sequence";
/// The 64 base64 symbols, indexed by the 6-bit value each one encodes.
const PARSE_BASE64_ALPHABET: &[u8] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// Bytes skipped while decoding base64: carriage return, newline, space and tab.
const PARSE_BASE64_IGNORE_BYTES: [u8; 4] = [b'\r', b'\n', b' ', b'\t'];
/// Reverse lookup from an ASCII byte (0..=0x7A, i.e. up to `z`) to its 6-bit base64 value,
/// e.g. `'A'`/0x41 -> 0 and `'B'`/0x42 -> 1. Bytes that are not part of the alphabet map to the
/// sentinel `0x7f`.
const PARSE_BASE64_ALPHABET_DECODE_TABLE: [u8; 123] = {
    // Start with every slot marked invalid, then fill in the slots of the 64 real symbols.
    let mut table = [0x7f_u8; 123];
    let mut value = 0;
    while value < PARSE_BASE64_ALPHABET.len() {
        // `value` is always below 64, so narrowing it to `u8` cannot truncate.
        table[PARSE_BASE64_ALPHABET[value] as usize] = value as u8;
        value += 1;
    }
    table
};
35
36#[function("encode(bytea, varchar) -> varchar")]
37pub fn encode(data: &[u8], format: &str, writer: &mut impl std::fmt::Write) -> Result<()> {
38    match format {
39        "base64" => {
40            encode_bytes_base64(data, writer)?;
41        }
42        "hex" => {
43            writer.write_str(&hex::encode(data)).unwrap();
44        }
45        "escape" => {
46            encode_bytes_escape(data, writer).unwrap();
47        }
48        _ => {
49            return Err(ExprError::InvalidParam {
50                name: "format",
51                reason: format!("unrecognized encoding: \"{}\"", format).into(),
52            });
53        }
54    }
55    Ok(())
56}
57
58#[function("decode(varchar, varchar) -> bytea")]
59pub fn decode(data: &str, format: &str, writer: &mut impl std::io::Write) -> Result<()> {
60    match format {
61        "base64" => parse_bytes_base64(data, writer),
62        "hex" => parse_bytes_hex(data, writer).map_err(|err| ExprError::Parse(err.into())),
63        "escape" => {
64            parse_bytes_traditional(data, writer).map_err(|err| ExprError::Parse(err.into()))
65        }
66        _ => Err(ExprError::InvalidParam {
67            name: "format",
68            reason: format!("unrecognized encoding: \"{}\"", format).into(),
69        }),
70    }
71}
72
73enum CharacterSet {
74    Utf8,
75}
76
77impl CharacterSet {
78    fn recognize(encoding: &str) -> Result<Self> {
79        match encoding.to_uppercase().as_str() {
80            "UTF8" | "UTF-8" => Ok(Self::Utf8),
81            _ => Err(ExprError::InvalidParam {
82                name: "encoding",
83                reason: format!("unrecognized encoding: \"{}\"", encoding).into(),
84            }),
85        }
86    }
87}
88
89#[function("convert_from(bytea, varchar) -> varchar")]
90pub fn convert_from(
91    data: &[u8],
92    src_encoding: &str,
93    writer: &mut impl std::fmt::Write,
94) -> Result<()> {
95    match CharacterSet::recognize(src_encoding)? {
96        CharacterSet::Utf8 => {
97            let text = String::from_utf8(data.to_vec()).map_err(|e| ExprError::InvalidParam {
98                name: "data",
99                reason: e.to_report_string().into(),
100            })?;
101            writer.write_str(&text).unwrap();
102            Ok(())
103        }
104    }
105}
106
107#[function("convert_to(varchar, varchar) -> bytea")]
108pub fn convert_to(
109    string: &str,
110    dest_encoding: &str,
111    writer: &mut impl std::io::Write,
112) -> Result<()> {
113    match CharacterSet::recognize(dest_encoding)? {
114        CharacterSet::Utf8 => {
115            writer.write_all(string.as_bytes()).unwrap();
116            Ok(())
117        }
118    }
119}
120
121// According to https://www.postgresql.org/docs/current/functions-binarystring.html#ENCODE-FORMAT-BASE64
122// We need to split newlines when the output length is greater than or equal to 76
123fn encode_bytes_base64(data: &[u8], writer: &mut impl std::fmt::Write) -> Result<()> {
124    let mut idx: usize = 0;
125    let len = data.len();
126    let mut written = 0;
127    while idx + 2 < len {
128        let i1 = (data[idx] >> 2) & 0b00111111;
129        let i2 = ((data[idx] & 0b00000011) << 4) | ((data[idx + 1] >> 4) & 0b00001111);
130        let i3 = ((data[idx + 1] & 0b00001111) << 2) | ((data[idx + 2] >> 6) & 0b00000011);
131        let i4 = data[idx + 2] & 0b00111111;
132        writer
133            .write_char(PARSE_BASE64_ALPHABET[usize::from(i1)].into())
134            .unwrap();
135        writer
136            .write_char(PARSE_BASE64_ALPHABET[usize::from(i2)].into())
137            .unwrap();
138        writer
139            .write_char(PARSE_BASE64_ALPHABET[usize::from(i3)].into())
140            .unwrap();
141        writer
142            .write_char(PARSE_BASE64_ALPHABET[usize::from(i4)].into())
143            .unwrap();
144
145        written += 4;
146        idx += 3;
147        if written % 76 == 0 {
148            writer.write_char('\n').unwrap();
149        }
150    }
151
152    if idx + 2 == len {
153        let i1 = (data[idx] >> 2) & 0b00111111;
154        let i2 = ((data[idx] & 0b00000011) << 4) | ((data[idx + 1] >> 4) & 0b00001111);
155        let i3 = (data[idx + 1] & 0b00001111) << 2;
156        writer
157            .write_char(PARSE_BASE64_ALPHABET[usize::from(i1)].into())
158            .unwrap();
159        writer
160            .write_char(PARSE_BASE64_ALPHABET[usize::from(i2)].into())
161            .unwrap();
162        writer
163            .write_char(PARSE_BASE64_ALPHABET[usize::from(i3)].into())
164            .unwrap();
165        writer.write_char('=').unwrap();
166    } else if idx + 1 == len {
167        let i1 = (data[idx] >> 2) & 0b00111111;
168        let i2 = (data[idx] & 0b00000011) << 4;
169        writer
170            .write_char(PARSE_BASE64_ALPHABET[usize::from(i1)].into())
171            .unwrap();
172        writer
173            .write_char(PARSE_BASE64_ALPHABET[usize::from(i2)].into())
174            .unwrap();
175        writer.write_char('=').unwrap();
176        writer.write_char('=').unwrap();
177    }
178    Ok(())
179}
180
181// According to https://www.postgresql.org/docs/current/functions-binarystring.html#ENCODE-FORMAT-BASE64
182// parse_bytes_base64 need ignores carriage-return[0x0D], newline[0x0A], space[0x20], and tab[0x09].
183// When decode is supplied invalid base64 data, including incorrect trailing padding, return error.
184fn parse_bytes_base64(data: &str, writer: &mut impl std::io::Write) -> Result<()> {
185    let data_bytes = data.as_bytes();
186
187    let mut idx: usize = 0;
188    while idx < data.len() {
189        match (
190            next(&mut idx, data_bytes),
191            next(&mut idx, data_bytes),
192            next(&mut idx, data_bytes),
193            next(&mut idx, data_bytes),
194        ) {
195            (None, None, None, None) => return Ok(()),
196            (Some(d1), Some(d2), Some(b'='), Some(b'=')) => {
197                let s1 = alphabet_decode(d1)?;
198                let s2 = alphabet_decode(d2)?;
199                writer.write_all(&[s1 << 2 | s2 >> 4]).unwrap();
200            }
201            (Some(d1), Some(d2), Some(d3), Some(b'=')) => {
202                let s1 = alphabet_decode(d1)?;
203                let s2 = alphabet_decode(d2)?;
204                let s3 = alphabet_decode(d3)?;
205                writer
206                    .write_all(&[s1 << 2 | s2 >> 4, s2 << 4 | s3 >> 2])
207                    .unwrap();
208            }
209            (Some(b'='), _, _, _) => {
210                return Err(ExprError::Parse(PARSE_BASE64_INVALID_PADDING.into()));
211            }
212            (Some(d1), Some(b'='), _, _) => {
213                alphabet_decode(d1)?;
214                return Err(ExprError::Parse(PARSE_BASE64_INVALID_PADDING.into()));
215            }
216            (Some(d1), Some(d2), Some(b'='), _) => {
217                alphabet_decode(d1)?;
218                alphabet_decode(d2)?;
219                return Err(ExprError::Parse(PARSE_BASE64_INVALID_PADDING.into()));
220            }
221            (Some(d1), Some(d2), Some(d3), Some(d4)) => {
222                let s1 = alphabet_decode(d1)?;
223                let s2 = alphabet_decode(d2)?;
224                let s3 = alphabet_decode(d3)?;
225                let s4 = alphabet_decode(d4)?;
226                writer
227                    .write_all(&[s1 << 2 | s2 >> 4, s2 << 4 | s3 >> 2, s3 << 6 | s4])
228                    .unwrap();
229            }
230            (Some(d1), None, None, None) => {
231                alphabet_decode(d1)?;
232                return Err(ExprError::Parse(PARSE_BASE64_INVALID_END.into()));
233            }
234            (Some(d1), Some(d2), None, None) => {
235                alphabet_decode(d1)?;
236                alphabet_decode(d2)?;
237                return Err(ExprError::Parse(PARSE_BASE64_INVALID_END.into()));
238            }
239            (Some(d1), Some(d2), Some(d3), None) => {
240                alphabet_decode(d1)?;
241                alphabet_decode(d2)?;
242                alphabet_decode(d3)?;
243                return Err(ExprError::Parse(PARSE_BASE64_INVALID_END.into()));
244            }
245            _ => {
246                return Err(ExprError::Parse(PARSE_BASE64_INVALID_END.into()));
247            }
248        }
249    }
250    Ok(())
251}
252
253#[inline]
254fn alphabet_decode(d: u8) -> Result<u8> {
255    if d > 0x7A {
256        Err(ExprError::Parse(
257            format!(
258                "invalid symbol \"{}\" while decoding base64 sequence",
259                std::char::from_u32(d as u32).unwrap()
260            )
261            .into(),
262        ))
263    } else {
264        let p = PARSE_BASE64_ALPHABET_DECODE_TABLE[d as usize];
265        if p == 0x7f {
266            Err(ExprError::Parse(
267                format!(
268                    "invalid symbol \"{}\" while decoding base64 sequence",
269                    std::char::from_u32(d as u32).unwrap()
270                )
271                .into(),
272            ))
273        } else {
274            Ok(p)
275        }
276    }
277}
278
279#[inline]
280fn next(idx: &mut usize, data: &[u8]) -> Option<u8> {
281    while *idx < data.len() && PARSE_BASE64_IGNORE_BYTES.contains(&data[*idx]) {
282        *idx += 1;
283    }
284    if *idx < data.len() {
285        let d1 = data[*idx];
286        *idx += 1;
287        Some(d1)
288    } else {
289        None
290    }
291}
292
/// Writes `data` to `writer` in the PostgreSQL `escape` text format.
///
/// Per https://www.postgresql.org/docs/current/functions-binarystring.html#ENCODE-FORMAT-ESCAPE
/// a zero byte and every byte with the high bit set become a three-digit octal escape (`\nnn`),
/// a backslash is doubled, and every other byte is written as the corresponding character.
///
/// # Errors
/// Returns the `std::fmt::Error` reported by `writer`, whichever kind of byte was being written.
fn encode_bytes_escape(data: &[u8], writer: &mut impl std::fmt::Write) -> std::fmt::Result {
    for &byte in data {
        match byte {
            // Propagate writer failures here as well instead of panicking, so that all three
            // arms report errors the same way.
            b'\0' | 0x80..=0xff => write!(writer, "\\{:03o}", byte)?,
            b'\\' => writer.write_str("\\\\")?,
            _ => writer.write_char(char::from(byte))?,
        }
    }
    Ok(())
}
308
#[cfg(test)]
mod tests {
    use super::{decode, encode};

    /// Encodes `original` in `format`, checks the text against `expected`, then decodes the text
    /// again and checks that the original bytes come back.
    fn assert_round_trip(original: &[u8], format: &str, expected: &[u8]) {
        let mut text = String::new();
        assert!(encode(original, format, &mut text).is_ok());
        println!("{}", text);
        assert_eq!(text.as_bytes(), expected);

        let mut decoded = Vec::new();
        decode(text.as_str(), format, &mut decoded).unwrap();
        assert_eq!(original, decoded.as_slice());
    }

    #[test]
    fn test_encdec() {
        assert_round_trip(r#"ABCDE"#.as_bytes(), "base64", r#"QUJDREU="#.as_bytes());
        assert_round_trip(r#"\""#.as_bytes(), "escape", r#"\\""#.as_bytes());
        assert_round_trip(b"\x00\x40\x41\x42\xff", "escape", r"\000@AB\377".as_bytes());
        // 57 input bytes encode to exactly 76 characters, so a line break follows them.
        assert_round_trip(
            "aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeefffffff".as_bytes(),
            "base64",
            "YWFhYWFhYWFhYWJiYmJiYmJiYmJjY2NjY2NjY2NjZGRkZGRkZGRkZGVlZWVlZWVlZWVmZmZmZmZm\n"
                .as_bytes(),
        );
        assert_round_trip(
            "aabbccddee".as_bytes(),
            "hex",
            "61616262636364646565".as_bytes(),
        );
    }
}
342}