risingwave_expr_impl/scalar/
encrypt.rs1use std::fmt::Debug;
16use std::sync::LazyLock;
17
18use openssl::error::ErrorStack;
19use openssl::symm::{Cipher, Crypter, Mode as CipherMode};
20use regex::Regex;
21use risingwave_expr::{ExprError, Result, function};
22
/// Cipher family named at the start of the `mode` argument; only AES is supported.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Algorithm {
    Aes,
}
27
/// Block chaining mode selected by the optional `-cbc` / `-ecb` suffix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Mode {
    /// Cipher block chaining; used when no mode suffix is given.
    Cbc,
    /// Electronic codebook.
    Ecb,
}
/// Padding scheme selected by the optional `/pad:...` suffix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Padding {
    /// PKCS padding; used when no padding suffix is given.
    Pkcs,
    /// Padding disabled on the OpenSSL context.
    None,
}
38
39#[derive(Clone)]
40pub struct CipherConfig {
41 algorithm: Algorithm,
42 mode: Mode,
43 cipher: Cipher,
44 padding: Padding,
45 crypt_key: Vec<u8>,
46}
47
48impl std::fmt::Debug for CipherConfig {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 f.debug_struct("CipherConfig")
52 .field("algorithm", &self.algorithm)
53 .field("key_len", &self.crypt_key.len())
54 .field("mode", &self.mode)
55 .field("padding", &self.padding)
56 .finish()
57 }
58}
59
60static CIPHER_CONFIG_RE: LazyLock<Regex> =
61 LazyLock::new(|| Regex::new(r"^(aes)(?:-(cbc|ecb))?(?:/pad:(pkcs|none))?$").unwrap());
62
63impl CipherConfig {
64 fn parse_cipher_config(key: &[u8], input: &str) -> Result<CipherConfig> {
65 let Some(caps) = CIPHER_CONFIG_RE.captures(input) else {
66 return Err(ExprError::InvalidParam {
67 name: "mode",
68 reason: format!(
69 "invalid mode: {}, expect pattern algorithm[-mode][/pad:padding]",
70 input
71 )
72 .into(),
73 });
74 };
75
76 let algorithm = match caps.get(1).map(|s| s.as_str()) {
77 Some("aes") => Algorithm::Aes,
78 algo => {
79 return Err(ExprError::InvalidParam {
80 name: "mode",
81 reason: format!("expect aes for algorithm, but got: {:?}", algo).into(),
82 });
83 }
84 };
85
86 let mode = match caps.get(2).map(|m| m.as_str()) {
87 Some("cbc") | None => Mode::Cbc, Some("ecb") => Mode::Ecb,
89 Some(mode) => {
90 return Err(ExprError::InvalidParam {
91 name: "mode",
92 reason: format!("expect cbc or ecb for mode, but got: {}", mode).into(),
93 });
94 }
95 };
96
97 let padding = match caps.get(3).map(|m| m.as_str()) {
98 Some("pkcs") | None => Padding::Pkcs, Some("none") => Padding::None,
100 Some(padding) => {
101 return Err(ExprError::InvalidParam {
102 name: "mode",
103 reason: format!("expect pkcs or none for padding, but got: {}", padding).into(),
104 });
105 }
106 };
107
108 let cipher = match (&algorithm, key.len(), &mode) {
109 (Algorithm::Aes, 16, Mode::Cbc) => Cipher::aes_128_cbc(),
110 (Algorithm::Aes, 16, Mode::Ecb) => Cipher::aes_128_ecb(),
111 (Algorithm::Aes, 24, Mode::Cbc) => Cipher::aes_192_cbc(),
112 (Algorithm::Aes, 24, Mode::Ecb) => Cipher::aes_192_ecb(),
113 (Algorithm::Aes, 32, Mode::Cbc) => Cipher::aes_256_cbc(),
114 (Algorithm::Aes, 32, Mode::Ecb) => Cipher::aes_256_ecb(),
115 (Algorithm::Aes, n, Mode::Cbc | Mode::Ecb) => {
116 return Err(ExprError::InvalidParam {
117 name: "key",
118 reason: format!("invalid key length: {}, expect 16, 24 or 32", n).into(),
119 });
120 }
121 };
122
123 Ok(CipherConfig {
124 algorithm,
125 mode,
126 cipher,
127 padding,
128 crypt_key: key.to_vec(),
129 })
130 }
131
132 fn eval(&self, input: &[u8], stage: CryptographyStage) -> Result<Box<[u8]>, CryptographyError> {
133 let operation = match stage {
134 CryptographyStage::Encrypt => CipherMode::Encrypt,
135 CryptographyStage::Decrypt => CipherMode::Decrypt,
136 };
137 self.eval_inner(input, operation)
138 .map_err(|reason| CryptographyError { stage, reason })
139 }
140
141 fn eval_inner(
142 &self,
143 input: &[u8],
144 operation: CipherMode,
145 ) -> std::result::Result<Box<[u8]>, ErrorStack> {
146 let mut decrypter = Crypter::new(self.cipher, operation, self.crypt_key.as_ref(), None)?;
147 let enable_padding = match self.padding {
148 Padding::Pkcs => true,
149 Padding::None => false,
150 };
151 decrypter.pad(enable_padding);
152 let mut decrypt = vec![0; input.len() + self.cipher.block_size()];
153 let count = decrypter.update(input, &mut decrypt)?;
154 let rest = decrypter.finalize(&mut decrypt[count..])?;
155 decrypt.truncate(count + rest);
156 Ok(decrypt.into())
157 }
158}
159
160#[function(
162 "decrypt(bytea, bytea, varchar) -> bytea",
163 prebuild = "CipherConfig::parse_cipher_config($1, $2)?"
164)]
165fn decrypt(data: &[u8], config: &CipherConfig) -> Result<Box<[u8]>, CryptographyError> {
166 config.eval(data, CryptographyStage::Decrypt)
167}
168
169#[function(
170 "encrypt(bytea, bytea, varchar) -> bytea",
171 prebuild = "CipherConfig::parse_cipher_config($1, $2)?"
172)]
173fn encrypt(data: &[u8], config: &CipherConfig) -> Result<Box<[u8]>, CryptographyError> {
174 config.eval(data, CryptographyStage::Encrypt)
175}
176
/// Direction of a cipher operation.
///
/// Its `Debug` form (`Encrypt` / `Decrypt`) is interpolated into `CryptographyError` messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CryptographyStage {
    Encrypt,
    Decrypt,
}
182
183#[derive(Debug, thiserror::Error)]
184#[error("{stage:?} stage, reason: {reason}")]
185struct CryptographyError {
186 pub stage: CryptographyStage,
187 #[source]
188 pub reason: openssl::error::ErrorStack,
189}
190
#[cfg(test)]
mod test {
    use super::*;

    /// Key of `len` bytes counting up from zero (`00 01 02 ...`), as used by FIPS-197.
    fn seq_key(len: u8) -> Vec<u8> {
        (0..len).collect()
    }

    /// Parses a config that the test expects to be valid.
    fn parse(key: &[u8], mode: &str) -> CipherConfig {
        CipherConfig::parse_cipher_config(key, mode).unwrap()
    }

    #[test]
    fn test_decrypt() {
        let data = b"hello world";
        let config = parse(&seq_key(16), "aes");

        let encrypted = encrypt(data, &config).unwrap();
        // Default padding is PKCS, so 11 bytes are padded up to one full AES block.
        assert_eq!(encrypted.len(), 16);

        let decrypted = decrypt(&encrypted, &config).unwrap();
        assert_eq!(decrypted, (*data).into());
    }

    #[test]
    fn encrypt_testcase() {
        // FIPS-197 Appendix C known-answer vectors: one plaintext block, sequential keys.
        let plaintext: &[u8] = b"\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xaa\xbb\xcc\xdd\xee\xff";
        let vectors: [(u8, &[u8]); 3] = [
            (16, b"\x69\xc4\xe0\xd8\x6a\x7b\x04\x30\xd8\xcd\xb7\x80\x70\xb4\xc5\x5a"),
            (24, b"\xdd\xa9\x7c\xa4\x86\x4c\xdf\xe0\x6e\xaf\x70\xa0\xec\x0d\x71\x91"),
            (32, b"\x8e\xa2\xb7\xca\x51\x67\x45\xbf\xea\xfc\x49\x90\x4b\x49\x60\x89"),
        ];

        for (key_len, expected) in vectors {
            let config = parse(&seq_key(key_len), "aes-ecb/pad:none");

            let encrypted = encrypt(plaintext, &config).unwrap();
            assert_eq!(&*encrypted, expected, "AES key length {key_len}");

            let decrypted = decrypt(&encrypted, &config).unwrap();
            assert_eq!(&*decrypted, plaintext, "AES key length {key_len}");
        }
    }

    #[test]
    fn test_empty_input_round_trip() {
        let config = parse(&seq_key(16), "aes");

        let encrypted = encrypt(b"", &config).unwrap();
        // PKCS padding always adds at least one byte, so empty input becomes one full block.
        assert_eq!(encrypted.len(), 16);

        assert!(decrypt(&encrypted, &config).unwrap().is_empty());
    }

    #[test]
    fn test_unpadded_input_must_be_block_aligned() {
        // With padding disabled, an 11-byte input cannot be completed into a block.
        let config = parse(&seq_key(16), "aes-ecb/pad:none");
        assert!(encrypt(b"hello world", &config).is_err());
    }

    #[test]
    fn test_invalid_key_length() {
        for len in [0u8, 15, 17, 31, 33] {
            assert!(
                CipherConfig::parse_cipher_config(&seq_key(len), "aes").is_err(),
                "key length {len} should be rejected"
            );
        }
    }

    #[test]
    fn test_parse_cipher_config() {
        let key = seq_key(16);

        let config = parse(&key, "aes-ecb/pad:none");
        assert_eq!(config.algorithm, Algorithm::Aes);
        assert_eq!(config.mode, Mode::Ecb);
        assert_eq!(config.padding, Padding::None);

        let config = parse(&key, "aes-cbc/pad:pkcs");
        assert_eq!(config.algorithm, Algorithm::Aes);
        assert_eq!(config.mode, Mode::Cbc);
        assert_eq!(config.padding, Padding::Pkcs);

        // Both mode and padding fall back to their defaults independently.
        let config = parse(&key, "aes");
        assert_eq!(config.algorithm, Algorithm::Aes);
        assert_eq!(config.mode, Mode::Cbc);
        assert_eq!(config.padding, Padding::Pkcs);

        let config = parse(&key, "aes-ecb");
        assert_eq!(config.mode, Mode::Ecb);
        assert_eq!(config.padding, Padding::Pkcs);

        let config = parse(&key, "aes/pad:none");
        assert_eq!(config.mode, Mode::Cbc);
        assert_eq!(config.padding, Padding::None);

        for bad in ["", "cbc", "aes-cfb", "aes/pad:zero", "aes-cbc/pad:pkcs/extra"] {
            assert!(
                CipherConfig::parse_cipher_config(&key, bad).is_err(),
                "spec {bad:?} should be rejected"
            );
        }
    }

    #[test]
    fn test_debug_hides_key() {
        let config = parse(&seq_key(16), "aes-ecb/pad:none");
        assert_eq!(
            format!("{:?}", config),
            "CipherConfig { algorithm: Aes, key_len: 16, mode: Ecb, padding: None }"
        );
    }
}