risingwave_expr_impl/scalar/
encrypt.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::sync::LazyLock;
17
18use openssl::error::ErrorStack;
19use openssl::symm::{Cipher, Crypter, Mode as CipherMode};
20use regex::Regex;
21use risingwave_expr::{ExprError, Result, function};
22
/// Cipher algorithm named by the leading component of a cipher specification.
///
/// The type is a fieldless tag, so it is `Copy` and `Eq` in addition to the comparison and
/// formatting traits used by the tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Algorithm {
    /// AES; the concrete key-size variant is chosen from the key length when the
    /// configuration is parsed.
    Aes,
}
27
/// Block-cipher mode of operation: the optional `-mode` component of a cipher specification.
///
/// The type is a fieldless tag, so it is `Copy` and `Eq` in addition to the comparison and
/// formatting traits used by the tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Mode {
    /// Cipher block chaining; also used when the specification names no mode.
    Cbc,
    /// Electronic codebook.
    Ecb,
}
/// Padding scheme: the optional `/pad:` component of a cipher specification.
///
/// The type is a fieldless tag, so it is `Copy` and `Eq` in addition to the comparison and
/// formatting traits used by the tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Padding {
    /// PKCS padding; also used when the specification names no padding.
    Pkcs,
    /// Padding is switched off on the underlying crypter.
    None,
}
38
39#[derive(Clone)]
40pub struct CipherConfig {
41    algorithm: Algorithm,
42    mode: Mode,
43    cipher: Cipher,
44    padding: Padding,
45    crypt_key: Vec<u8>,
46}
47
48/// Because `Cipher` is not `Debug`, we include algorithm, key length and mode manually.
49impl std::fmt::Debug for CipherConfig {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        f.debug_struct("CipherConfig")
52            .field("algorithm", &self.algorithm)
53            .field("key_len", &self.crypt_key.len())
54            .field("mode", &self.mode)
55            .field("padding", &self.padding)
56            .finish()
57    }
58}
59
60static CIPHER_CONFIG_RE: LazyLock<Regex> =
61    LazyLock::new(|| Regex::new(r"^(aes)(?:-(cbc|ecb))?(?:/pad:(pkcs|none))?$").unwrap());
62
63impl CipherConfig {
64    fn parse_cipher_config(key: &[u8], input: &str) -> Result<CipherConfig> {
65        let Some(caps) = CIPHER_CONFIG_RE.captures(input) else {
66            return Err(ExprError::InvalidParam {
67                name: "mode",
68                reason: format!(
69                    "invalid mode: {}, expect pattern algorithm[-mode][/pad:padding]",
70                    input
71                )
72                .into(),
73            });
74        };
75
76        let algorithm = match caps.get(1).map(|s| s.as_str()) {
77            Some("aes") => Algorithm::Aes,
78            algo => {
79                return Err(ExprError::InvalidParam {
80                    name: "mode",
81                    reason: format!("expect aes for algorithm, but got: {:?}", algo).into(),
82                });
83            }
84        };
85
86        let mode = match caps.get(2).map(|m| m.as_str()) {
87            Some("cbc") | None => Mode::Cbc, // Default to Cbc if not specified
88            Some("ecb") => Mode::Ecb,
89            Some(mode) => {
90                return Err(ExprError::InvalidParam {
91                    name: "mode",
92                    reason: format!("expect cbc or ecb for mode, but got: {}", mode).into(),
93                });
94            }
95        };
96
97        let padding = match caps.get(3).map(|m| m.as_str()) {
98            Some("pkcs") | None => Padding::Pkcs, // Default to Pkcs if not specified
99            Some("none") => Padding::None,
100            Some(padding) => {
101                return Err(ExprError::InvalidParam {
102                    name: "mode",
103                    reason: format!("expect pkcs or none for padding, but got: {}", padding).into(),
104                });
105            }
106        };
107
108        let cipher = match (&algorithm, key.len(), &mode) {
109            (Algorithm::Aes, 16, Mode::Cbc) => Cipher::aes_128_cbc(),
110            (Algorithm::Aes, 16, Mode::Ecb) => Cipher::aes_128_ecb(),
111            (Algorithm::Aes, 24, Mode::Cbc) => Cipher::aes_192_cbc(),
112            (Algorithm::Aes, 24, Mode::Ecb) => Cipher::aes_192_ecb(),
113            (Algorithm::Aes, 32, Mode::Cbc) => Cipher::aes_256_cbc(),
114            (Algorithm::Aes, 32, Mode::Ecb) => Cipher::aes_256_ecb(),
115            (Algorithm::Aes, n, Mode::Cbc | Mode::Ecb) => {
116                return Err(ExprError::InvalidParam {
117                    name: "key",
118                    reason: format!("invalid key length: {}, expect 16, 24 or 32", n).into(),
119                });
120            }
121        };
122
123        Ok(CipherConfig {
124            algorithm,
125            mode,
126            cipher,
127            padding,
128            crypt_key: key.to_vec(),
129        })
130    }
131
132    fn eval(&self, input: &[u8], stage: CryptographyStage) -> Result<Box<[u8]>, CryptographyError> {
133        let operation = match stage {
134            CryptographyStage::Encrypt => CipherMode::Encrypt,
135            CryptographyStage::Decrypt => CipherMode::Decrypt,
136        };
137        self.eval_inner(input, operation)
138            .map_err(|reason| CryptographyError { stage, reason })
139    }
140
141    fn eval_inner(
142        &self,
143        input: &[u8],
144        operation: CipherMode,
145    ) -> std::result::Result<Box<[u8]>, ErrorStack> {
146        // Default the IV to all-zeros when the cipher requires one, to match pgcrypto:
147        // https://github.com/postgres/postgres/blob/REL_18_3/contrib/pgcrypto/pgcrypto.c#L325
148        let iv = self.cipher.iv_len().map(|len| vec![0u8; len]);
149        let mut decrypter = Crypter::new(
150            self.cipher,
151            operation,
152            self.crypt_key.as_ref(),
153            iv.as_deref(),
154        )?;
155        let enable_padding = match self.padding {
156            Padding::Pkcs => true,
157            Padding::None => false,
158        };
159        decrypter.pad(enable_padding);
160        let mut decrypt = vec![0; input.len() + self.cipher.block_size()];
161        let count = decrypter.update(input, &mut decrypt)?;
162        let rest = decrypter.finalize(&mut decrypt[count..])?;
163        decrypt.truncate(count + rest);
164        Ok(decrypt.into())
165    }
166}
167
168/// from [pg doc](https://www.postgresql.org/docs/current/pgcrypto.html#PGCRYPTO-RAW-ENC-FUNCS)
169#[function(
170    "decrypt(bytea, bytea, varchar) -> bytea",
171    prebuild = "CipherConfig::parse_cipher_config($1, $2)?"
172)]
173fn decrypt(data: &[u8], config: &CipherConfig) -> Result<Box<[u8]>, CryptographyError> {
174    config.eval(data, CryptographyStage::Decrypt)
175}
176
177#[function(
178    "encrypt(bytea, bytea, varchar) -> bytea",
179    prebuild = "CipherConfig::parse_cipher_config($1, $2)?"
180)]
181fn encrypt(data: &[u8], config: &CipherConfig) -> Result<Box<[u8]>, CryptographyError> {
182    config.eval(data, CryptographyStage::Encrypt)
183}
184
/// Direction of a cipher operation, recorded in `CryptographyError` so that a failure
/// reports which operation was running.
///
/// Derives the same comparison and copy traits as the other fieldless enums in this file;
/// the `Debug` output is interpolated into the error message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CryptographyStage {
    /// Plaintext is being turned into ciphertext.
    Encrypt,
    /// Ciphertext is being turned back into plaintext.
    Decrypt,
}
190
191#[derive(Debug, thiserror::Error)]
192#[error("{stage:?} stage, reason: {reason}")]
193struct CryptographyError {
194    pub stage: CryptographyStage,
195    #[source]
196    pub reason: openssl::error::ErrorStack,
197}
198
#[cfg(test)]
mod test {
    use super::*;

    /// 16-byte AES-128 key `00 01 .. 0f` shared by the known-answer tests.
    const KEY_128: &[u8] = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f";

    /// Default specification (`aes` => CBC / PKCS): pins the ciphertext and round-trips it.
    #[test]
    fn test_decrypt() {
        let data = b"hello world";
        let config = CipherConfig::parse_cipher_config(KEY_128, "aes").unwrap();
        let encrypted = encrypt(data, &config).unwrap();

        // Pin the zero-IV default: AES-128-CBC / PKCS, key 00..0F, IV all-zero.
        let expected: &[u8] = &[
            0x92, 0x76, 0xfd, 0xf3, 0x84, 0xf3, 0x85, 0x18, 0xfa, 0x6c, 0x83, 0x10, 0xf1, 0x91,
            0x67, 0x8d,
        ];
        assert_eq!(encrypted.as_ref(), expected);

        let decrypted = decrypt(&encrypted, &config).unwrap();
        assert_eq!(&decrypted[..], &data[..]);
    }

    /// Unpadded ECB over a single block: checks the ciphertext against the published AES-128
    /// example vector instead of relying on the round trip alone.
    #[test]
    fn encrypt_testcase() {
        let plaintext: &[u8] =
            b"\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xaa\xbb\xcc\xdd\xee\xff";
        let config = CipherConfig::parse_cipher_config(KEY_128, "aes-ecb/pad:none").unwrap();

        let encrypted = encrypt(plaintext, &config).unwrap();
        // FIPS-197 Appendix C.1 (AES-128) known-answer vector for this key and plaintext.
        let expected: &[u8] = &[
            0x69, 0xc4, 0xe0, 0xd8, 0x6a, 0x7b, 0x04, 0x30, 0xd8, 0xcd, 0xb7, 0x80, 0x70, 0xb4,
            0xc5, 0x5a,
        ];
        assert_eq!(encrypted.as_ref(), expected);

        let decrypted = decrypt(&encrypted, &config).unwrap();
        assert_eq!(&decrypted[..], plaintext);
    }

    /// Every supported key length and mode must round-trip with the default padding.
    #[test]
    fn test_round_trip_all_key_sizes() {
        let key = [0x42u8; 32];
        let data = b"risingwave";

        for key_len in [16, 24, 32] {
            for mode in ["aes-cbc", "aes-ecb"] {
                let config = CipherConfig::parse_cipher_config(&key[..key_len], mode).unwrap();
                let encrypted = encrypt(data, &config).unwrap();
                // Ten bytes of input are padded up to one full AES block.
                assert_eq!(encrypted.len(), 16, "key_len={key_len}, mode={mode}");

                let decrypted = decrypt(&encrypted, &config).unwrap();
                assert_eq!(&decrypted[..], &data[..], "key_len={key_len}, mode={mode}");
            }
        }
    }

    #[test]
    fn test_parse_cipher_config() {
        let mode_1 = "aes-ecb/pad:none";
        let config = CipherConfig::parse_cipher_config(KEY_128, mode_1).unwrap();
        assert_eq!(config.algorithm, Algorithm::Aes);
        assert_eq!(config.mode, Mode::Ecb);
        assert_eq!(config.padding, Padding::None);

        let mode_2 = "aes-cbc/pad:pkcs";
        let config = CipherConfig::parse_cipher_config(KEY_128, mode_2).unwrap();
        assert_eq!(config.algorithm, Algorithm::Aes);
        assert_eq!(config.mode, Mode::Cbc);
        assert_eq!(config.padding, Padding::Pkcs);

        let mode_3 = "aes";
        let config = CipherConfig::parse_cipher_config(KEY_128, mode_3).unwrap();
        assert_eq!(config.algorithm, Algorithm::Aes);
        assert_eq!(config.mode, Mode::Cbc);
        assert_eq!(config.padding, Padding::Pkcs);

        let mode_4 = "cbc";
        assert!(CipherConfig::parse_cipher_config(KEY_128, mode_4).is_err());
    }

    /// Malformed specifications and unsupported key lengths are rejected at parse time.
    #[test]
    fn test_rejects_invalid_config() {
        for spec in ["", "aes-ctr", "aes/pad:zero", "aes-cbc/pad:pkcs ", "aes-"] {
            assert!(
                CipherConfig::parse_cipher_config(KEY_128, spec).is_err(),
                "specification {spec:?} should be rejected"
            );
        }

        for key_len in [0, 15, 17, 33] {
            let key = vec![0u8; key_len];
            assert!(
                CipherConfig::parse_cipher_config(&key, "aes").is_err(),
                "key length {key_len} should be rejected"
            );
        }
    }

    /// With padding disabled, input that is not a whole number of blocks must fail rather
    /// than be silently padded or truncated.
    #[test]
    fn test_no_padding_rejects_partial_block() {
        let config = CipherConfig::parse_cipher_config(KEY_128, "aes-ecb/pad:none").unwrap();
        assert!(encrypt(b"hello", &config).is_err());
    }
}