risingwave_expr_impl/scalar/encrypt.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::sync::LazyLock;
17
18use openssl::error::ErrorStack;
19use openssl::symm::{Cipher, Crypter, Mode as CipherMode};
20use regex::Regex;
21use risingwave_expr::{ExprError, Result, function};
22
/// Block cipher algorithm named at the start of the type string.
///
/// Only AES is supported; the concrete key size is chosen later from the key length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Algorithm {
    Aes,
}
27
/// Block cipher mode of operation.
///
/// When the type string does not name a mode, `Cbc` is used.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Mode {
    /// Cipher block chaining.
    Cbc,
    /// Electronic codebook: every block is encrypted independently.
    Ecb,
}
/// Padding applied to the final block.
///
/// When the type string does not name a padding, `Pkcs` is used.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Padding {
    /// PKCS padding; input of any length is accepted.
    Pkcs,
    /// No padding; the data length must be a multiple of the cipher block size,
    /// otherwise OpenSSL reports an error when the operation is finalized.
    None,
}
38
39#[derive(Clone)]
40pub struct CipherConfig {
41    algorithm: Algorithm,
42    mode: Mode,
43    cipher: Cipher,
44    padding: Padding,
45    crypt_key: Vec<u8>,
46}
47
48/// Because `Cipher` is not `Debug`, we include algorithm, key length and mode manually.
49impl std::fmt::Debug for CipherConfig {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        f.debug_struct("CipherConfig")
52            .field("algorithm", &self.algorithm)
53            .field("key_len", &self.crypt_key.len())
54            .field("mode", &self.mode)
55            .field("padding", &self.padding)
56            .finish()
57    }
58}
59
60static CIPHER_CONFIG_RE: LazyLock<Regex> =
61    LazyLock::new(|| Regex::new(r"^(aes)(?:-(cbc|ecb))?(?:/pad:(pkcs|none))?$").unwrap());
62
63impl CipherConfig {
64    fn parse_cipher_config(key: &[u8], input: &str) -> Result<CipherConfig> {
65        let Some(caps) = CIPHER_CONFIG_RE.captures(input) else {
66            return Err(ExprError::InvalidParam {
67                name: "mode",
68                reason: format!(
69                    "invalid mode: {}, expect pattern algorithm[-mode][/pad:padding]",
70                    input
71                )
72                .into(),
73            });
74        };
75
76        let algorithm = match caps.get(1).map(|s| s.as_str()) {
77            Some("aes") => Algorithm::Aes,
78            algo => {
79                return Err(ExprError::InvalidParam {
80                    name: "mode",
81                    reason: format!("expect aes for algorithm, but got: {:?}", algo).into(),
82                });
83            }
84        };
85
86        let mode = match caps.get(2).map(|m| m.as_str()) {
87            Some("cbc") | None => Mode::Cbc, // Default to Cbc if not specified
88            Some("ecb") => Mode::Ecb,
89            Some(mode) => {
90                return Err(ExprError::InvalidParam {
91                    name: "mode",
92                    reason: format!("expect cbc or ecb for mode, but got: {}", mode).into(),
93                });
94            }
95        };
96
97        let padding = match caps.get(3).map(|m| m.as_str()) {
98            Some("pkcs") | None => Padding::Pkcs, // Default to Pkcs if not specified
99            Some("none") => Padding::None,
100            Some(padding) => {
101                return Err(ExprError::InvalidParam {
102                    name: "mode",
103                    reason: format!("expect pkcs or none for padding, but got: {}", padding).into(),
104                });
105            }
106        };
107
108        let cipher = match (&algorithm, key.len(), &mode) {
109            (Algorithm::Aes, 16, Mode::Cbc) => Cipher::aes_128_cbc(),
110            (Algorithm::Aes, 16, Mode::Ecb) => Cipher::aes_128_ecb(),
111            (Algorithm::Aes, 24, Mode::Cbc) => Cipher::aes_192_cbc(),
112            (Algorithm::Aes, 24, Mode::Ecb) => Cipher::aes_192_ecb(),
113            (Algorithm::Aes, 32, Mode::Cbc) => Cipher::aes_256_cbc(),
114            (Algorithm::Aes, 32, Mode::Ecb) => Cipher::aes_256_ecb(),
115            (Algorithm::Aes, n, Mode::Cbc | Mode::Ecb) => {
116                return Err(ExprError::InvalidParam {
117                    name: "key",
118                    reason: format!("invalid key length: {}, expect 16, 24 or 32", n).into(),
119                });
120            }
121        };
122
123        Ok(CipherConfig {
124            algorithm,
125            mode,
126            cipher,
127            padding,
128            crypt_key: key.to_vec(),
129        })
130    }
131
132    fn eval(&self, input: &[u8], stage: CryptographyStage) -> Result<Box<[u8]>, CryptographyError> {
133        let operation = match stage {
134            CryptographyStage::Encrypt => CipherMode::Encrypt,
135            CryptographyStage::Decrypt => CipherMode::Decrypt,
136        };
137        self.eval_inner(input, operation)
138            .map_err(|reason| CryptographyError { stage, reason })
139    }
140
141    fn eval_inner(
142        &self,
143        input: &[u8],
144        operation: CipherMode,
145    ) -> std::result::Result<Box<[u8]>, ErrorStack> {
146        let mut decrypter = Crypter::new(self.cipher, operation, self.crypt_key.as_ref(), None)?;
147        let enable_padding = match self.padding {
148            Padding::Pkcs => true,
149            Padding::None => false,
150        };
151        decrypter.pad(enable_padding);
152        let mut decrypt = vec![0; input.len() + self.cipher.block_size()];
153        let count = decrypter.update(input, &mut decrypt)?;
154        let rest = decrypter.finalize(&mut decrypt[count..])?;
155        decrypt.truncate(count + rest);
156        Ok(decrypt.into())
157    }
158}
159
160/// from [pg doc](https://www.postgresql.org/docs/current/pgcrypto.html#PGCRYPTO-RAW-ENC-FUNCS)
161#[function(
162    "decrypt(bytea, bytea, varchar) -> bytea",
163    prebuild = "CipherConfig::parse_cipher_config($1, $2)?"
164)]
165fn decrypt(data: &[u8], config: &CipherConfig) -> Result<Box<[u8]>, CryptographyError> {
166    config.eval(data, CryptographyStage::Decrypt)
167}
168
169#[function(
170    "encrypt(bytea, bytea, varchar) -> bytea",
171    prebuild = "CipherConfig::parse_cipher_config($1, $2)?"
172)]
173fn encrypt(data: &[u8], config: &CipherConfig) -> Result<Box<[u8]>, CryptographyError> {
174    config.eval(data, CryptographyStage::Encrypt)
175}
176
/// Direction of a cipher operation, recorded in [`CryptographyError`].
///
/// The `Debug` form (`Encrypt` / `Decrypt`) appears verbatim in the error message,
/// so variant names are user-visible.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CryptographyStage {
    Encrypt,
    Decrypt,
}
182
183#[derive(Debug, thiserror::Error)]
184#[error("{stage:?} stage, reason: {reason}")]
185struct CryptographyError {
186    pub stage: CryptographyStage,
187    #[source]
188    pub reason: openssl::error::ErrorStack,
189}
190
#[cfg(test)]
mod test {
    use super::*;

    /// AES-128 key `00 01 .. 0f` (FIPS-197 Appendix C.1).
    const KEY_128: &[u8] = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f";

    /// Single-block plaintext `00 11 22 .. ff` used by the FIPS-197 Appendix C vectors.
    const PLAINTEXT_BLOCK: &[u8] =
        b"\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xaa\xbb\xcc\xdd\xee\xff";

    /// FIPS-197 Appendix C.1 ciphertext of `PLAINTEXT_BLOCK` under `KEY_128`.
    const CIPHERTEXT_128: &[u8] =
        b"\x69\xc4\xe0\xd8\x6a\x7b\x04\x30\xd8\xcd\xb7\x80\x70\xb4\xc5\x5a";

    /// Parses `mode` with `key` and encrypts `data`, panicking on any error.
    fn encrypt_with(data: &[u8], key: &[u8], mode: &str) -> Box<[u8]> {
        let config = CipherConfig::parse_cipher_config(key, mode).unwrap();
        encrypt(data, &config).unwrap()
    }

    /// Parses `mode` with `key` and decrypts `data`, panicking on any error.
    fn decrypt_with(data: &[u8], key: &[u8], mode: &str) -> Box<[u8]> {
        let config = CipherConfig::parse_cipher_config(key, mode).unwrap();
        decrypt(data, &config).unwrap()
    }

    #[test]
    fn test_decrypt() {
        let data = b"hello world";
        let config = CipherConfig::parse_cipher_config(KEY_128, "aes").unwrap();
        let encrypted = encrypt(data, &config).unwrap();
        // 11 bytes are PKCS-padded up to a single 16-byte AES block.
        assert_eq!(encrypted.len(), 16);

        let decrypted = decrypt(&encrypted, &config).unwrap();
        assert_eq!(decrypted, (*data).into());
    }

    #[test]
    fn encrypt_testcase() {
        // Known-answer test: a plain round trip would still pass if the wrong cipher
        // variant were selected, so pin the exact ciphertext.
        let encrypted = encrypt_with(PLAINTEXT_BLOCK, KEY_128, "aes-ecb/pad:none");
        assert_eq!(&*encrypted, CIPHERTEXT_128);

        let decrypted = decrypt_with(&encrypted, KEY_128, "aes-ecb/pad:none");
        assert_eq!(&*decrypted, PLAINTEXT_BLOCK);
    }

    #[test]
    fn test_key_length_selects_aes_variant() {
        // FIPS-197 Appendix C.2 / C.3: the key is the byte sequence 00 01 02 ...
        let key_192: Vec<u8> = (0u8..24).collect();
        let key_256: Vec<u8> = (0u8..32).collect();
        assert_eq!(
            &*encrypt_with(PLAINTEXT_BLOCK, &key_192, "aes-ecb/pad:none"),
            b"\xdd\xa9\x7c\xa4\x86\x4c\xdf\xe0\x6e\xaf\x70\xa0\xec\x0d\x71\x91"
        );
        assert_eq!(
            &*encrypt_with(PLAINTEXT_BLOCK, &key_256, "aes-ecb/pad:none"),
            b"\x8e\xa2\xb7\xca\x51\x67\x45\xbf\xea\xfc\x49\x90\x4b\x49\x60\x89"
        );
    }

    #[test]
    fn test_pkcs_padding_adds_block() {
        // A full-block input gains one whole block of PKCS padding.
        let encrypted = encrypt_with(PLAINTEXT_BLOCK, KEY_128, "aes-ecb");
        assert_eq!(encrypted.len(), 32);
        // ECB encrypts blocks independently, so the first block matches the KAT.
        assert_eq!(&encrypted[..16], CIPHERTEXT_128);
        assert_eq!(&*decrypt_with(&encrypted, KEY_128, "aes-ecb"), PLAINTEXT_BLOCK);
    }

    #[test]
    fn test_no_padding_rejects_partial_block() {
        let config = CipherConfig::parse_cipher_config(KEY_128, "aes-ecb/pad:none").unwrap();
        assert!(encrypt(b"hello world", &config).is_err());
    }

    #[test]
    fn test_parse_cipher_config() {
        let config = CipherConfig::parse_cipher_config(KEY_128, "aes-ecb/pad:none").unwrap();
        assert_eq!(config.algorithm, Algorithm::Aes);
        assert_eq!(config.mode, Mode::Ecb);
        assert_eq!(config.padding, Padding::None);

        let config = CipherConfig::parse_cipher_config(KEY_128, "aes-cbc/pad:pkcs").unwrap();
        assert_eq!(config.algorithm, Algorithm::Aes);
        assert_eq!(config.mode, Mode::Cbc);
        assert_eq!(config.padding, Padding::Pkcs);

        // Mode and padding default to cbc / pkcs.
        let config = CipherConfig::parse_cipher_config(KEY_128, "aes").unwrap();
        assert_eq!(config.algorithm, Algorithm::Aes);
        assert_eq!(config.mode, Mode::Cbc);
        assert_eq!(config.padding, Padding::Pkcs);

        for bad_mode in ["cbc", "aes-cfb", "aes/pad:zero", "aes-cbc/pad:pkcs/x", ""] {
            assert!(
                CipherConfig::parse_cipher_config(KEY_128, bad_mode).is_err(),
                "type string {bad_mode:?} should be rejected"
            );
        }

        // Only 16-, 24- and 32-byte keys are accepted.
        for len in [0usize, 15, 17, 33] {
            let key = vec![0u8; len];
            assert!(
                CipherConfig::parse_cipher_config(&key, "aes").is_err(),
                "key length {len} should be rejected"
            );
        }
    }

    #[test]
    fn test_debug_omits_key() {
        let config = CipherConfig::parse_cipher_config(KEY_128, "aes-ecb/pad:none").unwrap();
        assert_eq!(
            format!("{:?}", config),
            "CipherConfig { algorithm: Aes, key_len: 16, mode: Ecb, padding: None }"
        );
    }
}