1use std::fmt::Debug;
16use std::sync::LazyLock;
17
18use openssl::error::ErrorStack;
19use openssl::symm::{Cipher, Crypter, Mode as CipherMode};
20use regex::Regex;
21use risingwave_expr::{ExprError, Result, function};
22
/// Block-cipher algorithm named by the first component of the cipher type string.
///
/// Only AES is accepted; `CIPHER_CONFIG_RE` rejects any other algorithm name.
#[derive(Debug, Clone, PartialEq)]
enum Algorithm {
    Aes,
}
27
/// Block chaining mode from the optional `-mode` part of the cipher type string.
///
/// When the type string omits the mode, `parse_cipher_config` selects `Cbc`.
#[derive(Debug, Clone, PartialEq)]
enum Mode {
    /// Cipher-block chaining; run with an all-zero IV by `eval_inner`.
    Cbc,
    /// Electronic codebook; the cipher takes no IV.
    Ecb,
}
/// Padding scheme from the optional `/pad:padding` part of the cipher type string.
///
/// When the type string omits padding, `parse_cipher_config` selects `Pkcs`.
#[derive(Debug, Clone, PartialEq)]
enum Padding {
    /// OpenSSL padding enabled (`Crypter::pad(true)`).
    Pkcs,
    /// OpenSSL padding disabled (`Crypter::pad(false)`).
    None,
}
38
/// Fully resolved encryption settings built from a key and a cipher type string
/// (see `CipherConfig::parse_cipher_config`).
///
/// `Debug` is implemented by hand so that the key bytes are not printed.
#[derive(Clone)]
pub struct CipherConfig {
    /// Algorithm parsed from the type string.
    algorithm: Algorithm,
    /// Chaining mode parsed from the type string (defaults to CBC).
    mode: Mode,
    /// Concrete OpenSSL cipher chosen from algorithm, key length, and mode.
    cipher: Cipher,
    /// Padding parsed from the type string (defaults to PKCS).
    padding: Padding,
    /// Raw key bytes; its length (16/24/32) determines the AES key size.
    crypt_key: Vec<u8>,
}
47
48impl std::fmt::Debug for CipherConfig {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 f.debug_struct("CipherConfig")
52 .field("algorithm", &self.algorithm)
53 .field("key_len", &self.crypt_key.len())
54 .field("mode", &self.mode)
55 .field("padding", &self.padding)
56 .finish()
57 }
58}
59
60static CIPHER_CONFIG_RE: LazyLock<Regex> =
61 LazyLock::new(|| Regex::new(r"^(aes)(?:-(cbc|ecb))?(?:/pad:(pkcs|none))?$").unwrap());
62
63impl CipherConfig {
64 fn parse_cipher_config(key: &[u8], input: &str) -> Result<CipherConfig> {
65 let Some(caps) = CIPHER_CONFIG_RE.captures(input) else {
66 return Err(ExprError::InvalidParam {
67 name: "mode",
68 reason: format!(
69 "invalid mode: {}, expect pattern algorithm[-mode][/pad:padding]",
70 input
71 )
72 .into(),
73 });
74 };
75
76 let algorithm = match caps.get(1).map(|s| s.as_str()) {
77 Some("aes") => Algorithm::Aes,
78 algo => {
79 return Err(ExprError::InvalidParam {
80 name: "mode",
81 reason: format!("expect aes for algorithm, but got: {:?}", algo).into(),
82 });
83 }
84 };
85
86 let mode = match caps.get(2).map(|m| m.as_str()) {
87 Some("cbc") | None => Mode::Cbc, Some("ecb") => Mode::Ecb,
89 Some(mode) => {
90 return Err(ExprError::InvalidParam {
91 name: "mode",
92 reason: format!("expect cbc or ecb for mode, but got: {}", mode).into(),
93 });
94 }
95 };
96
97 let padding = match caps.get(3).map(|m| m.as_str()) {
98 Some("pkcs") | None => Padding::Pkcs, Some("none") => Padding::None,
100 Some(padding) => {
101 return Err(ExprError::InvalidParam {
102 name: "mode",
103 reason: format!("expect pkcs or none for padding, but got: {}", padding).into(),
104 });
105 }
106 };
107
108 let cipher = match (&algorithm, key.len(), &mode) {
109 (Algorithm::Aes, 16, Mode::Cbc) => Cipher::aes_128_cbc(),
110 (Algorithm::Aes, 16, Mode::Ecb) => Cipher::aes_128_ecb(),
111 (Algorithm::Aes, 24, Mode::Cbc) => Cipher::aes_192_cbc(),
112 (Algorithm::Aes, 24, Mode::Ecb) => Cipher::aes_192_ecb(),
113 (Algorithm::Aes, 32, Mode::Cbc) => Cipher::aes_256_cbc(),
114 (Algorithm::Aes, 32, Mode::Ecb) => Cipher::aes_256_ecb(),
115 (Algorithm::Aes, n, Mode::Cbc | Mode::Ecb) => {
116 return Err(ExprError::InvalidParam {
117 name: "key",
118 reason: format!("invalid key length: {}, expect 16, 24 or 32", n).into(),
119 });
120 }
121 };
122
123 Ok(CipherConfig {
124 algorithm,
125 mode,
126 cipher,
127 padding,
128 crypt_key: key.to_vec(),
129 })
130 }
131
132 fn eval(&self, input: &[u8], stage: CryptographyStage) -> Result<Box<[u8]>, CryptographyError> {
133 let operation = match stage {
134 CryptographyStage::Encrypt => CipherMode::Encrypt,
135 CryptographyStage::Decrypt => CipherMode::Decrypt,
136 };
137 self.eval_inner(input, operation)
138 .map_err(|reason| CryptographyError { stage, reason })
139 }
140
141 fn eval_inner(
142 &self,
143 input: &[u8],
144 operation: CipherMode,
145 ) -> std::result::Result<Box<[u8]>, ErrorStack> {
146 let iv = self.cipher.iv_len().map(|len| vec![0u8; len]);
149 let mut decrypter = Crypter::new(
150 self.cipher,
151 operation,
152 self.crypt_key.as_ref(),
153 iv.as_deref(),
154 )?;
155 let enable_padding = match self.padding {
156 Padding::Pkcs => true,
157 Padding::None => false,
158 };
159 decrypter.pad(enable_padding);
160 let mut decrypt = vec![0; input.len() + self.cipher.block_size()];
161 let count = decrypter.update(input, &mut decrypt)?;
162 let rest = decrypter.finalize(&mut decrypt[count..])?;
163 decrypt.truncate(count + rest);
164 Ok(decrypt.into())
165 }
166}
167
168#[function(
170 "decrypt(bytea, bytea, varchar) -> bytea",
171 prebuild = "CipherConfig::parse_cipher_config($1, $2)?"
172)]
173fn decrypt(data: &[u8], config: &CipherConfig) -> Result<Box<[u8]>, CryptographyError> {
174 config.eval(data, CryptographyStage::Decrypt)
175}
176
177#[function(
178 "encrypt(bytea, bytea, varchar) -> bytea",
179 prebuild = "CipherConfig::parse_cipher_config($1, $2)?"
180)]
181fn encrypt(data: &[u8], config: &CipherConfig) -> Result<Box<[u8]>, CryptographyError> {
182 config.eval(data, CryptographyStage::Encrypt)
183}
184
/// Direction of a cipher run.
///
/// Used to pick the OpenSSL `Mode` and to tag errors with the stage that failed. Its
/// `Debug` form (`Encrypt` / `Decrypt`) appears in [`CryptographyError`]'s message.
#[derive(Debug)]
enum CryptographyStage {
    Encrypt,
    Decrypt,
}
190
/// Error from an encrypt/decrypt run, tagging the OpenSSL error stack with the stage.
///
/// NOTE(review): the Display message interpolates `{reason}` while `#[source]` also
/// exposes the same error. Reporters that walk the source chain may therefore print the
/// OpenSSL error twice. Consider dropping `{reason}` from the message, after checking
/// whether any user-facing error expectations depend on the current text.
#[derive(Debug, thiserror::Error)]
#[error("{stage:?} stage, reason: {reason}")]
struct CryptographyError {
    /// Which direction (encrypt or decrypt) failed.
    pub stage: CryptographyStage,
    /// Underlying OpenSSL error stack.
    #[source]
    pub reason: openssl::error::ErrorStack,
}
198
#[cfg(test)]
mod test {
    use super::*;

    /// 16-byte AES-128 key `00 01 .. 0f`, shared by the tests below.
    const KEY_128: &[u8] = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f";

    #[test]
    fn test_decrypt() {
        let data = b"hello world";
        let mode = "aes";

        let config = CipherConfig::parse_cipher_config(KEY_128, mode).unwrap();
        let encrypted = encrypt(data, &config).unwrap();

        // "aes" defaults to CBC + PKCS padding; with a 16-byte key and the all-zero IV this
        // is AES-128-CBC of one padded block.
        let expected: &[u8] = &[
            0x92, 0x76, 0xfd, 0xf3, 0x84, 0xf3, 0x85, 0x18, 0xfa, 0x6c, 0x83, 0x10, 0xf1, 0x91,
            0x67, 0x8d,
        ];
        assert_eq!(encrypted.as_ref(), expected);

        let decrypted = decrypt(&encrypted, &config).unwrap();
        assert_eq!(decrypted, (*data).into());
    }

    #[test]
    fn encrypt_testcase() {
        let encrypt_wrapper = |data: &[u8], key: &[u8], mode: &str| -> Box<[u8]> {
            let config = CipherConfig::parse_cipher_config(key, mode).unwrap();
            encrypt(data, &config).unwrap()
        };
        let decrypt_wrapper = |data: &[u8], key: &[u8], mode: &str| -> Box<[u8]> {
            let config = CipherConfig::parse_cipher_config(key, mode).unwrap();
            decrypt(data, &config).unwrap()
        };

        // FIPS-197 Appendix C.1 (AES-128) known-answer vector. It is exactly one block, so
        // ECB without padding must reproduce the published ciphertext. A round trip alone
        // would also pass if `encrypt` were the identity.
        let plaintext = b"\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xaa\xbb\xcc\xdd\xee\xff";
        let expected: &[u8] = &[
            0x69, 0xc4, 0xe0, 0xd8, 0x6a, 0x7b, 0x04, 0x30, 0xd8, 0xcd, 0xb7, 0x80, 0x70, 0xb4,
            0xc5, 0x5a,
        ];

        let encrypted = encrypt_wrapper(plaintext, KEY_128, "aes-ecb/pad:none");
        assert_eq!(encrypted.as_ref(), expected);

        let decrypted = decrypt_wrapper(&encrypted, KEY_128, "aes-ecb/pad:none");
        assert_eq!(&decrypted[..], &plaintext[..]);
    }

    #[test]
    fn test_no_padding_rejects_partial_block() {
        // With padding disabled, OpenSSL refuses to finalize an incomplete block.
        for mode in ["aes-cbc/pad:none", "aes-ecb/pad:none"] {
            let config = CipherConfig::parse_cipher_config(KEY_128, mode).unwrap();
            assert!(encrypt(b"hello world", &config).is_err());
        }
    }

    #[test]
    fn test_invalid_key_length() {
        for len in [0usize, 1, 15, 17, 23, 25, 31, 33] {
            let key = vec![0u8; len];
            assert!(CipherConfig::parse_cipher_config(&key, "aes").is_err());
        }
        for len in [16usize, 24, 32] {
            let key = vec![0u8; len];
            assert!(CipherConfig::parse_cipher_config(&key, "aes").is_ok());
        }
    }

    #[test]
    fn test_parse_cipher_config() {
        let key = KEY_128;

        let mode_1 = "aes-ecb/pad:none";
        let config = CipherConfig::parse_cipher_config(key, mode_1).unwrap();
        assert_eq!(config.algorithm, Algorithm::Aes);
        assert_eq!(config.mode, Mode::Ecb);
        assert_eq!(config.padding, Padding::None);

        let mode_2 = "aes-cbc/pad:pkcs";
        let config = CipherConfig::parse_cipher_config(key, mode_2).unwrap();
        assert_eq!(config.algorithm, Algorithm::Aes);
        assert_eq!(config.mode, Mode::Cbc);
        assert_eq!(config.padding, Padding::Pkcs);

        let mode_3 = "aes";
        let config = CipherConfig::parse_cipher_config(key, mode_3).unwrap();
        assert_eq!(config.algorithm, Algorithm::Aes);
        assert_eq!(config.mode, Mode::Cbc);
        assert_eq!(config.padding, Padding::Pkcs);

        // Each optional component falls back to its default independently.
        let config = CipherConfig::parse_cipher_config(key, "aes-ecb").unwrap();
        assert_eq!(config.mode, Mode::Ecb);
        assert_eq!(config.padding, Padding::Pkcs);

        let config = CipherConfig::parse_cipher_config(key, "aes/pad:none").unwrap();
        assert_eq!(config.mode, Mode::Cbc);
        assert_eq!(config.padding, Padding::None);

        let mode_4 = "cbc";
        assert!(CipherConfig::parse_cipher_config(key, mode_4).is_err());
        assert!(CipherConfig::parse_cipher_config(key, "aes-ctr").is_err());
        assert!(CipherConfig::parse_cipher_config(key, "aes-cbc/pad:zero").is_err());
        assert!(CipherConfig::parse_cipher_config(key, "").is_err());
    }
}