risingwave_expr_impl/scalar/
encdec.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Write;
16
17use risingwave_common::cast::{parse_bytes_hex, parse_bytes_traditional};
18use risingwave_expr::{ExprError, Result, function};
19use thiserror_ext::AsReport;
20
// Error message used when the base64 input ends in the middle of a 4-symbol group.
const PARSE_BASE64_INVALID_END: &str = "invalid base64 end sequence";
// Error message used when '=' padding shows up at a position where it is not allowed.
const PARSE_BASE64_INVALID_PADDING: &str = "unexpected \"=\" while decoding base64 sequence";
// The 64 base64 symbols, indexed by their 6-bit value: 'A'..'Z', 'a'..'z', '0'..'9', '+', '/'.
const PARSE_BASE64_ALPHABET: &[u8] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
// Bytes skipped while decoding: carriage return, newline, space and tab.
const PARSE_BASE64_IGNORE_BYTES: [u8; 4] = [0x0D, 0x0A, 0x20, 0x09];
// Reverse lookup of `PARSE_BASE64_ALPHABET`, indexed by the ASCII code of the input byte.
// Each entry is the symbol's 6-bit value, e.g. 'A' (0x41) -> 0 and 'B' (0x42) -> 1.
// 0x7f marks a byte that is not a base64 symbol (this includes '=').
// The table has 123 entries, so it covers codes 0x00..=0x7A ('z'); larger bytes must be
// rejected before indexing.
const PARSE_BASE64_ALPHABET_DECODE_TABLE: [u8; 123] = [
    0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f,
    0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f,
    0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x3E, 0x7f, 0x7f, 0x7f, 0x3F,
    0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f,
    0x7f, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E,
    0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f,
    0x7f, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28,
    0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33,
];
37
38#[function("encode(bytea, varchar) -> varchar")]
39pub fn encode(data: &[u8], format: &str, writer: &mut impl Write) -> Result<()> {
40    match format {
41        "base64" => {
42            encode_bytes_base64(data, writer)?;
43        }
44        "hex" => {
45            writer.write_str(&hex::encode(data)).unwrap();
46        }
47        "escape" => {
48            encode_bytes_escape(data, writer).unwrap();
49        }
50        _ => {
51            return Err(ExprError::InvalidParam {
52                name: "format",
53                reason: format!("unrecognized encoding: \"{}\"", format).into(),
54            });
55        }
56    }
57    Ok(())
58}
59
60#[function("decode(varchar, varchar) -> bytea")]
61pub fn decode(data: &str, format: &str) -> Result<Box<[u8]>> {
62    match format {
63        "base64" => Ok(parse_bytes_base64(data)?.into()),
64        "hex" => Ok(parse_bytes_hex(data)
65            .map_err(|err| ExprError::Parse(err.into()))?
66            .into()),
67        "escape" => Ok(parse_bytes_traditional(data)
68            .map_err(|err| ExprError::Parse(err.into()))?
69            .into()),
70        _ => Err(ExprError::InvalidParam {
71            name: "format",
72            reason: format!("unrecognized encoding: \"{}\"", format).into(),
73        }),
74    }
75}
76
77enum CharacterSet {
78    Utf8,
79}
80
81impl CharacterSet {
82    fn recognize(encoding: &str) -> Result<Self> {
83        match encoding.to_uppercase().as_str() {
84            "UTF8" | "UTF-8" => Ok(Self::Utf8),
85            _ => Err(ExprError::InvalidParam {
86                name: "encoding",
87                reason: format!("unrecognized encoding: \"{}\"", encoding).into(),
88            }),
89        }
90    }
91}
92
93#[function("convert_from(bytea, varchar) -> varchar")]
94pub fn convert_from(data: &[u8], src_encoding: &str, writer: &mut impl Write) -> Result<()> {
95    match CharacterSet::recognize(src_encoding)? {
96        CharacterSet::Utf8 => {
97            let text = String::from_utf8(data.to_vec()).map_err(|e| ExprError::InvalidParam {
98                name: "data",
99                reason: e.to_report_string().into(),
100            })?;
101            writer.write_str(&text).unwrap();
102            Ok(())
103        }
104    }
105}
106
107#[function("convert_to(varchar, varchar) -> bytea")]
108pub fn convert_to(string: &str, dest_encoding: &str) -> Result<Box<[u8]>> {
109    match CharacterSet::recognize(dest_encoding)? {
110        CharacterSet::Utf8 => Ok(string.as_bytes().into()),
111    }
112}
113
114// According to https://www.postgresql.org/docs/current/functions-binarystring.html#ENCODE-FORMAT-BASE64
115// We need to split newlines when the output length is greater than or equal to 76
116fn encode_bytes_base64(data: &[u8], writer: &mut impl Write) -> Result<()> {
117    let mut idx: usize = 0;
118    let len = data.len();
119    let mut written = 0;
120    while idx + 2 < len {
121        let i1 = (data[idx] >> 2) & 0b00111111;
122        let i2 = ((data[idx] & 0b00000011) << 4) | ((data[idx + 1] >> 4) & 0b00001111);
123        let i3 = ((data[idx + 1] & 0b00001111) << 2) | ((data[idx + 2] >> 6) & 0b00000011);
124        let i4 = data[idx + 2] & 0b00111111;
125        writer
126            .write_char(PARSE_BASE64_ALPHABET[usize::from(i1)].into())
127            .unwrap();
128        writer
129            .write_char(PARSE_BASE64_ALPHABET[usize::from(i2)].into())
130            .unwrap();
131        writer
132            .write_char(PARSE_BASE64_ALPHABET[usize::from(i3)].into())
133            .unwrap();
134        writer
135            .write_char(PARSE_BASE64_ALPHABET[usize::from(i4)].into())
136            .unwrap();
137
138        written += 4;
139        idx += 3;
140        if written % 76 == 0 {
141            writer.write_char('\n').unwrap();
142        }
143    }
144
145    if idx + 2 == len {
146        let i1 = (data[idx] >> 2) & 0b00111111;
147        let i2 = ((data[idx] & 0b00000011) << 4) | ((data[idx + 1] >> 4) & 0b00001111);
148        let i3 = (data[idx + 1] & 0b00001111) << 2;
149        writer
150            .write_char(PARSE_BASE64_ALPHABET[usize::from(i1)].into())
151            .unwrap();
152        writer
153            .write_char(PARSE_BASE64_ALPHABET[usize::from(i2)].into())
154            .unwrap();
155        writer
156            .write_char(PARSE_BASE64_ALPHABET[usize::from(i3)].into())
157            .unwrap();
158        writer.write_char('=').unwrap();
159    } else if idx + 1 == len {
160        let i1 = (data[idx] >> 2) & 0b00111111;
161        let i2 = (data[idx] & 0b00000011) << 4;
162        writer
163            .write_char(PARSE_BASE64_ALPHABET[usize::from(i1)].into())
164            .unwrap();
165        writer
166            .write_char(PARSE_BASE64_ALPHABET[usize::from(i2)].into())
167            .unwrap();
168        writer.write_char('=').unwrap();
169        writer.write_char('=').unwrap();
170    }
171    Ok(())
172}
173
174// According to https://www.postgresql.org/docs/current/functions-binarystring.html#ENCODE-FORMAT-BASE64
175// parse_bytes_base64 need ignores carriage-return[0x0D], newline[0x0A], space[0x20], and tab[0x09].
176// When decode is supplied invalid base64 data, including incorrect trailing padding, return error.
177fn parse_bytes_base64(data: &str) -> Result<Vec<u8>> {
178    let mut out = Vec::new();
179    let data_bytes = data.as_bytes();
180
181    let mut idx: usize = 0;
182    while idx < data.len() {
183        match (
184            next(&mut idx, data_bytes),
185            next(&mut idx, data_bytes),
186            next(&mut idx, data_bytes),
187            next(&mut idx, data_bytes),
188        ) {
189            (None, None, None, None) => return Ok(out),
190            (Some(d1), Some(d2), Some(b'='), Some(b'=')) => {
191                let s1 = alphabet_decode(d1)?;
192                let s2 = alphabet_decode(d2)?;
193                out.push(s1 << 2 | s2 >> 4);
194            }
195            (Some(d1), Some(d2), Some(d3), Some(b'=')) => {
196                let s1 = alphabet_decode(d1)?;
197                let s2 = alphabet_decode(d2)?;
198                let s3 = alphabet_decode(d3)?;
199                out.push(s1 << 2 | s2 >> 4);
200                out.push(s2 << 4 | s3 >> 2);
201            }
202            (Some(b'='), _, _, _) => {
203                return Err(ExprError::Parse(PARSE_BASE64_INVALID_PADDING.into()));
204            }
205            (Some(d1), Some(b'='), _, _) => {
206                alphabet_decode(d1)?;
207                return Err(ExprError::Parse(PARSE_BASE64_INVALID_PADDING.into()));
208            }
209            (Some(d1), Some(d2), Some(b'='), _) => {
210                alphabet_decode(d1)?;
211                alphabet_decode(d2)?;
212                return Err(ExprError::Parse(PARSE_BASE64_INVALID_PADDING.into()));
213            }
214            (Some(d1), Some(d2), Some(d3), Some(d4)) => {
215                let s1 = alphabet_decode(d1)?;
216                let s2 = alphabet_decode(d2)?;
217                let s3 = alphabet_decode(d3)?;
218                let s4 = alphabet_decode(d4)?;
219                out.push(s1 << 2 | s2 >> 4);
220                out.push(s2 << 4 | s3 >> 2);
221                out.push(s3 << 6 | s4);
222            }
223            (Some(d1), None, None, None) => {
224                alphabet_decode(d1)?;
225                return Err(ExprError::Parse(PARSE_BASE64_INVALID_END.into()));
226            }
227            (Some(d1), Some(d2), None, None) => {
228                alphabet_decode(d1)?;
229                alphabet_decode(d2)?;
230                return Err(ExprError::Parse(PARSE_BASE64_INVALID_END.into()));
231            }
232            (Some(d1), Some(d2), Some(d3), None) => {
233                alphabet_decode(d1)?;
234                alphabet_decode(d2)?;
235                alphabet_decode(d3)?;
236                return Err(ExprError::Parse(PARSE_BASE64_INVALID_END.into()));
237            }
238            _ => {
239                return Err(ExprError::Parse(PARSE_BASE64_INVALID_END.into()));
240            }
241        }
242    }
243    Ok(out)
244}
245
246#[inline]
247fn alphabet_decode(d: u8) -> Result<u8> {
248    if d > 0x7A {
249        Err(ExprError::Parse(
250            format!(
251                "invalid symbol \"{}\" while decoding base64 sequence",
252                std::char::from_u32(d as u32).unwrap()
253            )
254            .into(),
255        ))
256    } else {
257        let p = PARSE_BASE64_ALPHABET_DECODE_TABLE[d as usize];
258        if p == 0x7f {
259            Err(ExprError::Parse(
260                format!(
261                    "invalid symbol \"{}\" while decoding base64 sequence",
262                    std::char::from_u32(d as u32).unwrap()
263                )
264                .into(),
265            ))
266        } else {
267            Ok(p)
268        }
269    }
270}
271
272#[inline]
273fn next(idx: &mut usize, data: &[u8]) -> Option<u8> {
274    while *idx < data.len() && PARSE_BASE64_IGNORE_BYTES.contains(&data[*idx]) {
275        *idx += 1;
276    }
277    if *idx < data.len() {
278        let d1 = data[*idx];
279        *idx += 1;
280        Some(d1)
281    } else {
282        None
283    }
284}
285
/// Writes `data` to `writer` in the `escape` format.
///
/// According to https://www.postgresql.org/docs/current/functions-binarystring.html#ENCODE-FORMAT-ESCAPE
/// zero bytes and bytes with the high bit set become three-digit octal escapes (`\nnn`),
/// a backslash is doubled, and every other byte is written as the character with the same
/// code.
///
/// # Errors
///
/// Propagates any error reported by `writer`.
fn encode_bytes_escape(data: &[u8], writer: &mut impl Write) -> std::fmt::Result {
    for &b in data {
        match b {
            // Propagate writer failures with `?`, like the other branches, instead of
            // panicking.
            b'\0' | 0x80..=0xff => write!(writer, "\\{:03o}", b)?,
            b'\\' => writer.write_str("\\\\")?,
            _ => writer.write_char(char::from(b))?,
        }
    }
    Ok(())
}
301
#[cfg(test)]
mod tests {
    use super::{decode, encode};

    /// Each case is encoded, compared against the expected text, then decoded back and
    /// compared against the original bytes.
    #[test]
    fn test_encdec() {
        let cases = [
            (r#"ABCDE"#.as_bytes(), "base64", r#"QUJDREU="#.as_bytes()),
            (r#"\""#.as_bytes(), "escape", r#"\\""#.as_bytes()),
            (b"\x00\x40\x41\x42\xff", "escape", r"\000@AB\377".as_bytes()),
            // 57 input bytes produce exactly 76 base64 characters, followed by a newline.
            (
                "aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeefffffff".as_bytes(),
                "base64",
                "YWFhYWFhYWFhYWJiYmJiYmJiYmJjY2NjY2NjY2NjZGRkZGRkZGRkZGVlZWVlZWVlZWVmZmZmZmZm\n"
                    .as_bytes(),
            ),
            (
                "aabbccddee".as_bytes(),
                "hex",
                "61616262636364646565".as_bytes(),
            ),
        ];

        for (ori, format, encoded) in cases {
            let mut w = String::new();
            // `unwrap` (rather than `assert!(..is_ok())`) shows the error on failure.
            encode(ori, format, &mut w).unwrap();
            assert_eq!(w.as_bytes(), encoded);
            let res = decode(w.as_str(), format).unwrap();
            assert_eq!(ori, res.as_ref());
        }
    }

    /// CR, LF, space and tab inside base64 text are skipped by the decoder.
    #[test]
    fn test_decode_base64_ignores_whitespace() {
        let res = decode("QUJD\r\n RE\tU=", "base64").unwrap();
        assert_eq!(&*res, b"ABCDE".as_slice());
    }

    /// Base64 text that stops in the middle of a 4-symbol group is rejected.
    #[test]
    fn test_decode_base64_truncated() {
        assert!(decode("QUJDREU", "base64").is_err());
    }

    /// Unknown format names are rejected by both directions.
    #[test]
    fn test_unrecognized_format() {
        let mut w = String::new();
        assert!(encode(b"abc", "rot13", &mut w).is_err());
        assert!(decode("abc", "rot13").is_err());
    }
}