1use std::collections::HashMap;
16use std::future::Future;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::time::Instant;
20
21use bytes::Bytes;
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
23use parking_lot::Mutex;
24use risingwave_common::config::HbaEntry;
25use risingwave_common::id::DatabaseId;
26use risingwave_common::types::DataType;
27use risingwave_common::util::runtime::BackgroundShutdownRuntime;
28use risingwave_common::util::tokio_util::sync::CancellationToken;
29use risingwave_sqlparser::ast::Statement;
30use serde::Deserialize;
31use thiserror_ext::AsReport;
32
33use crate::error::{PsqlError, PsqlResult};
34use crate::ldap_auth::LdapAuthenticator;
35use crate::net::{AddressRef, Listener, TcpKeepalive};
36use crate::pg_field_descriptor::PgFieldDescriptor;
37use crate::pg_message::TransactionStatus;
38use crate::pg_protocol::{ConnectionContext, PgByteStream, PgProtocol};
39use crate::pg_response::{PgResponse, ValuesStream};
40use crate::types::Format;
41
/// Boxed, type-erased error that is `Send + Sync`; used as the common error currency here.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;
/// First component of a [`SessionId`].
type ProcessId = i32;
/// Second component of a [`SessionId`].
type SecretKey = i32;
/// Identifies a session as a `(ProcessId, SecretKey)` pair.
pub type SessionId = (ProcessId, SecretKey);
46
/// Creates, cancels work in, and tears down the sessions served by this server.
///
/// An implementation is shared behind an `Arc` by `pg_serve` and by every connection task,
/// hence the `Send + Sync + 'static` bounds.
pub trait SessionManager: Send + Sync + 'static {
    /// Error returned by the manager; it is also the error type of the sessions it creates.
    type Error: Into<BoxedError>;
    /// Concrete session type handed out by this manager.
    type Session: Session<Error = Self::Error>;

    /// Creates a session for `database_id` without a user name or peer address.
    // NOTE(review): presumably intended for internal (non-client) use — confirm with callers.
    fn create_dummy_session(
        &self,
        database_id: DatabaseId,
    ) -> Result<Arc<Self::Session>, Self::Error>;

    /// Creates a session for `user_name` on `database`, for a client connecting from
    /// `peer_addr`.
    fn connect(
        &self,
        database: &str,
        user_name: &str,
        peer_addr: AddressRef,
    ) -> Result<Arc<Self::Session>, Self::Error>;

    /// Requests cancellation of the queries running in the session identified by `session_id`.
    fn cancel_queries_in_session(&self, session_id: SessionId);

    /// Requests cancellation of the creating jobs in the session identified by `session_id`.
    fn cancel_creating_jobs_in_session(&self, session_id: SessionId);

    /// Notifies the manager that `session` has ended.
    fn end_session(&self, session: &Self::Session);

    /// Hook awaited by `pg_serve` once the acceptor runtime has been dropped.
    ///
    /// The default implementation does nothing.
    fn shutdown(&self) -> impl Future<Output = ()> + Send {
        async {}
    }
}
78
/// A single client session: runs statements, manages prepared statements and portals, and
/// exposes per-session state such as configuration and transaction status.
pub trait Session: Send + Sync {
    /// Error type produced by session operations.
    type Error: Into<BoxedError>;
    /// Stream of result rows carried by a [`PgResponse`].
    type ValuesStream: ValuesStream;
    /// Handle produced by [`Session::parse`] and consumed by [`Session::bind`].
    type PreparedStatement: Send + Clone + 'static;
    /// Handle produced by [`Session::bind`] and consumed by [`Session::execute`].
    type Portal: Send + Clone + std::fmt::Display + 'static;

    /// Runs a single, already-parsed statement and returns its response, with results
    /// encoded in `format`.
    fn run_one_query(
        self: Arc<Self>,
        stmt: Statement,
        format: Format,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, Self::Error>> + Send;

    /// Prepares `sql` with the given (optionally specified) parameter types.
    // NOTE(review): `sql` being `None` presumably represents an empty statement — confirm.
    fn parse(
        self: Arc<Self>,
        sql: Option<Statement>,
        params_types: Vec<Option<DataType>>,
    ) -> impl Future<Output = Result<Self::PreparedStatement, Self::Error>> + Send;

    /// Resolves to the next notice message for this session.
    fn next_notice(self: &Arc<Self>) -> impl Future<Output = String> + Send;

    /// Binds parameter values and formats to a prepared statement, producing a portal.
    fn bind(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
        params: Vec<Option<Bytes>>,
        param_formats: Vec<Format>,
        result_formats: Vec<Format>,
    ) -> Result<Self::Portal, Self::Error>;

    /// Executes a bound portal and returns its response.
    fn execute(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, Self::Error>> + Send;

    /// Describes a prepared statement as a pair of data types and field descriptors.
    // NOTE(review): presumably (parameter types, result columns) — confirm.
    fn describe_statement(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
    ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), Self::Error>;

    /// Describes a portal as a list of field descriptors.
    fn describe_portal(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> Result<Vec<PgFieldDescriptor>, Self::Error>;

    /// Returns the authenticator used to verify this session's user.
    fn user_authenticator(&self) -> &UserAuthenticator;

    /// Returns this session's identifier.
    fn id(&self) -> SessionId;

    /// Reads the session configuration entry `key`.
    fn get_config(&self, key: &str) -> Result<String, Self::Error>;

    /// Sets the session configuration entry `key` to `value`.
    // NOTE(review): the returned `String` is presumably the resulting value — confirm.
    fn set_config(&self, key: &str, value: String) -> Result<String, Self::Error>;

    /// Returns the session's current transaction status.
    fn transaction_status(&self) -> TransactionStatus;

    /// Registers `sql` as the statement being executed and returns a guard for that
    /// execution context.
    fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard;

    /// Checks the session against its idle-in-transaction timeout.
    // NOTE(review): presumably returns `Err` once the timeout has been exceeded — confirm.
    fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()>;
}
143
/// Record of a statement execution.
///
/// When an `ExecContext` is dropped, its `Drop` implementation stores the current instant
/// into `last_idle_instant`.
pub struct ExecContext {
    /// SQL text associated with this execution.
    pub running_sql: Arc<str>,
    /// Instant recorded for this execution.
    // NOTE(review): presumably the start (or last activity) time of the statement — confirm.
    pub last_instant: Instant,
    /// Shared slot that receives `Some(Instant::now())` when this context is dropped.
    pub last_idle_instant: Arc<Mutex<Option<Instant>>>,
}
153
154pub struct ExecContextGuard(#[allow(dead_code)] Arc<ExecContext>);
157
158impl ExecContextGuard {
159 pub fn new(exec_context: Arc<ExecContext>) -> Self {
160 Self(exec_context)
161 }
162}
163
164impl Drop for ExecContext {
165 fn drop(&mut self) {
166 *self.last_idle_instant.lock() = Some(Instant::now());
167 }
168}
169
/// How a user's credentials are verified by [`UserAuthenticator::authenticate`].
#[derive(Debug, Clone)]
pub enum UserAuthenticator {
    /// No verification: any supplied password is accepted.
    None,
    /// Expected password bytes; the supplied password must be equal to them.
    ClearText(Vec<u8>),
    /// Expected encrypted password plus a 4-byte salt.
    Md5WithSalt {
        /// Bytes the supplied password is compared against for equality.
        encrypted_password: Vec<u8>,
        /// Salt bytes; not consulted by `authenticate` itself.
        // NOTE(review): presumably sent to the client by the protocol layer — confirm.
        salt: [u8; 4],
    },
    /// The supplied password is treated as a JWT and validated against a remote JWKS.
    OAuth {
        /// Must contain the `jwks_url` and `issuer` entries; every remaining entry has to
        /// match a string claim of the same name in the token.
        metadata: HashMap<String, String>,
        /// Cluster id from which the expected token audience is derived.
        cluster_id: String,
    },
    /// LDAP authentication for the given user name, configured by the HBA entry.
    Ldap(String, HbaEntry),
}
187
/// Key set deserialized from the JSON document served at the JWKS URL.
#[derive(Debug, Deserialize)]
struct Jwks {
    /// Keys in the set; one is selected by matching `kid` against the JWT header.
    keys: Vec<Jwk>,
}
195
/// A single key of a [`Jwks`] document, restricted to the fields used for RSA verification.
#[derive(Debug, Deserialize)]
struct Jwk {
    /// Key id, compared with the `kid` of the JWT header.
    kid: String,
    /// Algorithm name; parsed with `Algorithm::from_str` and compared with the header's `alg`.
    alg: String,
    /// RSA modulus component, passed to `DecodingKey::from_rsa_components`.
    n: String,
    /// RSA exponent component, passed to `DecodingKey::from_rsa_components`.
    e: String,
}
205
206async fn validate_jwt(
207 jwt: &str,
208 jwks_url: &str,
209 issuer: &str,
210 cluster_id: &str,
211 metadata: &HashMap<String, String>,
212) -> Result<bool, BoxedError> {
213 let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;
214 validate_jwt_with_jwks(jwt, &jwks, issuer, cluster_id, metadata)
215}
216
/// Builds the JWT audience string expected for the cluster identified by `cluster_id`.
fn audience_from_cluster_id(cluster_id: &str) -> String {
    format!("urn:risingwave:cluster:{cluster_id}")
}
220
221fn validate_jwt_with_jwks(
222 jwt: &str,
223 jwks: &Jwks,
224 issuer: &str,
225 cluster_id: &str,
226 metadata: &HashMap<String, String>,
227) -> Result<bool, BoxedError> {
228 let header = decode_header(jwt)?;
229
230 let kid = header.kid.ok_or("JWT header missing 'kid' field")?;
232 let jwk = jwks
233 .keys
234 .iter()
235 .find(|k| k.kid == kid)
236 .ok_or(format!("No matching key found in JWKS for kid: '{}'", kid))?;
237
238 if Algorithm::from_str(&jwk.alg)? != header.alg {
240 return Err("alg in jwt header does not match with alg in jwk".into());
241 }
242
243 let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
245 let mut validation = Validation::new(header.alg);
246 validation.set_issuer(&[issuer]);
247 validation.set_audience(&[audience_from_cluster_id(cluster_id)]); validation.set_required_spec_claims(&["exp", "iss", "aud"]);
249 let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;
250
251 if !metadata.iter().all(
253 |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
254 ) {
255 return Err("metadata in jwt does not match with metadata declared with user".into());
256 }
257 Ok(true)
258}
259
260impl UserAuthenticator {
261 pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> {
262 let success = match self {
263 UserAuthenticator::None => true,
264 UserAuthenticator::ClearText(text) => password == text,
265 UserAuthenticator::Md5WithSalt {
266 encrypted_password, ..
267 } => encrypted_password == password,
268 UserAuthenticator::OAuth {
269 metadata,
270 cluster_id,
271 } => {
272 let mut metadata = metadata.clone();
273 let jwks_url = metadata.remove("jwks_url").unwrap();
274 let issuer = metadata.remove("issuer").unwrap();
275 validate_jwt(
276 &String::from_utf8_lossy(password),
277 &jwks_url,
278 &issuer,
279 cluster_id,
280 &metadata,
281 )
282 .await
283 .map_err(PsqlError::StartupError)?
284 }
285 UserAuthenticator::Ldap(user_name, hba_entry) => {
286 let ldap_auth = LdapAuthenticator::new(hba_entry)?;
287 let password_str = String::from_utf8_lossy(password).into_owned();
289 ldap_auth.authenticate(user_name, &password_str).await?
290 }
291 };
292 if !success {
293 return Err(PsqlError::PasswordError);
294 }
295 Ok(())
296 }
297}
298
299pub async fn pg_serve(
303 addr: &str,
304 tcp_keepalive: TcpKeepalive,
305 session_mgr: Arc<impl SessionManager>,
306 context: ConnectionContext,
307 shutdown: CancellationToken,
308) -> Result<(), BoxedError> {
309 let listener = Listener::bind(addr).await?;
310 tracing::info!(addr, "server started");
311
312 let acceptor_runtime = BackgroundShutdownRuntime::from({
313 let mut builder = tokio::runtime::Builder::new_multi_thread();
314 builder.worker_threads(1);
315 builder
316 .thread_name("rw-acceptor")
317 .enable_all()
318 .build()
319 .unwrap()
320 });
321
322 #[cfg(not(madsim))]
323 let worker_runtime = tokio::runtime::Handle::current();
324 #[cfg(madsim)]
325 let worker_runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();
326 let session_mgr_clone = session_mgr.clone();
327 let f = async move {
328 loop {
329 let conn_ret = listener.accept(&tcp_keepalive).await;
330 match conn_ret {
331 Ok((stream, peer_addr)) => {
332 tracing::info!(%peer_addr, "accept connection");
333 worker_runtime.spawn(handle_connection(
334 stream,
335 session_mgr_clone.clone(),
336 Arc::new(peer_addr),
337 context.clone(),
338 ));
339 }
340
341 Err(e) => {
342 tracing::error!(error = %e.as_report(), "failed to accept connection",);
343 }
344 }
345 }
346 };
347 acceptor_runtime.spawn(f);
348
349 shutdown.cancelled().await;
351
352 drop(acceptor_runtime);
354 session_mgr.shutdown().await;
356
357 Ok(())
358}
359
360pub async fn handle_connection<S, SM>(
361 stream: S,
362 session_mgr: Arc<SM>,
363 peer_addr: AddressRef,
364 context: ConnectionContext,
365) where
366 S: PgByteStream,
367 SM: SessionManager,
368{
369 PgProtocol::new(stream, session_mgr, peer_addr, context)
370 .run()
371 .await;
372}
373#[cfg(test)]
374mod tests {
375 use std::sync::Arc;
376 use std::time::Instant;
377
378 use bytes::Bytes;
379 use futures::StreamExt;
380 use futures::stream::BoxStream;
381 use risingwave_common::id::DatabaseId;
382 use risingwave_common::types::DataType;
383 use risingwave_common::util::tokio_util::sync::CancellationToken;
384 use risingwave_sqlparser::ast::Statement;
385 use tokio_postgres::NoTls;
386
387 use crate::error::PsqlResult;
388 use crate::memory_manager::MessageMemoryManager;
389 use crate::pg_field_descriptor::PgFieldDescriptor;
390 use crate::pg_message::TransactionStatus;
391 use crate::pg_protocol::ConnectionContext;
392 use crate::pg_response::{PgResponse, RowSetResult, StatementType};
393 use crate::pg_server::{
394 BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
395 UserAuthenticator, pg_serve,
396 };
397 use crate::types;
398 use crate::types::Row;
399
/// Stateless session manager used by the server tests.
struct MockSessionManager {}
/// Stateless session that answers every query with a canned response.
struct MockSession {}
402
403 impl SessionManager for MockSessionManager {
404 type Error = BoxedError;
405 type Session = MockSession;
406
407 fn create_dummy_session(
408 &self,
409 _database_id: DatabaseId,
410 ) -> Result<Arc<Self::Session>, Self::Error> {
411 unimplemented!()
412 }
413
414 fn connect(
415 &self,
416 _database: &str,
417 _user_name: &str,
418 _peer_addr: crate::net::AddressRef,
419 ) -> Result<Arc<Self::Session>, Self::Error> {
420 Ok(Arc::new(MockSession {}))
421 }
422
423 fn cancel_queries_in_session(&self, _session_id: SessionId) {
424 todo!()
425 }
426
427 fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
428 todo!()
429 }
430
431 fn end_session(&self, _session: &Self::Session) {}
432 }
433
434 impl Session for MockSession {
435 type Error = BoxedError;
436 type Portal = String;
437 type PreparedStatement = String;
438 type ValuesStream = BoxStream<'static, RowSetResult>;
439
440 async fn run_one_query(
441 self: Arc<Self>,
442 _stmt: Statement,
443 _format: types::Format,
444 ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, Self::Error> {
445 Ok(PgResponse::builder(StatementType::SELECT)
446 .values(
447 futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
448 .boxed(),
449 vec![
450 PgFieldDescriptor::new("".to_owned(), 1043, -1);
453 1
454 ],
455 )
456 .into())
457 }
458
459 async fn parse(
460 self: Arc<Self>,
461 _sql: Option<Statement>,
462 _params_types: Vec<Option<DataType>>,
463 ) -> Result<String, Self::Error> {
464 Ok(String::new())
465 }
466
467 fn bind(
468 self: Arc<Self>,
469 _prepare_statement: String,
470 _params: Vec<Option<Bytes>>,
471 _param_formats: Vec<types::Format>,
472 _result_formats: Vec<types::Format>,
473 ) -> Result<String, Self::Error> {
474 Ok(String::new())
475 }
476
477 async fn execute(
478 self: Arc<Self>,
479 _portal: String,
480 ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, Self::Error> {
481 Ok(PgResponse::builder(StatementType::SELECT)
482 .values(
483 futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
484 .boxed(),
485 vec![
486 PgFieldDescriptor::new("".to_owned(), 1043, -1);
489 1
490 ],
491 )
492 .into())
493 }
494
495 fn describe_statement(
496 self: Arc<Self>,
497 _statement: String,
498 ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), Self::Error> {
499 Ok((
500 vec![],
501 vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)],
502 ))
503 }
504
505 fn describe_portal(
506 self: Arc<Self>,
507 _portal: String,
508 ) -> Result<Vec<PgFieldDescriptor>, Self::Error> {
509 Ok(vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)])
510 }
511
512 fn user_authenticator(&self) -> &UserAuthenticator {
513 &UserAuthenticator::None
514 }
515
516 fn id(&self) -> SessionId {
517 (0, 0)
518 }
519
520 fn get_config(&self, key: &str) -> Result<String, Self::Error> {
521 match key {
522 "timezone" => Ok("UTC".to_owned()),
523 _ => Err(format!("Unknown config key: {key}").into()),
524 }
525 }
526
527 fn set_config(&self, _key: &str, _value: String) -> Result<String, Self::Error> {
528 Ok("".to_owned())
529 }
530
531 async fn next_notice(self: &Arc<Self>) -> String {
532 std::future::pending().await
533 }
534
535 fn transaction_status(&self) -> TransactionStatus {
536 TransactionStatus::Idle
537 }
538
539 fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
540 let exec_context = Arc::new(ExecContext {
541 running_sql: sql,
542 last_instant: Instant::now(),
543 last_idle_instant: Default::default(),
544 });
545 ExecContextGuard::new(exec_context)
546 }
547
548 fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
549 Ok(())
550 }
551 }
552
553 async fn do_test_query(bind_addr: impl Into<String>, pg_config: impl Into<String>) {
554 let bind_addr = bind_addr.into();
555 let pg_config = pg_config.into();
556
557 let session_mgr = MockSessionManager {};
558 tokio::spawn(async move {
559 pg_serve(
560 &bind_addr,
561 socket2::TcpKeepalive::new(),
562 Arc::new(session_mgr),
563 ConnectionContext {
564 tls_config: None,
565 redact_sql_option_keywords: None,
566 message_memory_manager: MessageMemoryManager::new(u64::MAX, u64::MAX, u64::MAX)
567 .into(),
568 },
569 CancellationToken::new(), )
571 .await
572 });
573 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
575
576 let (client, connection) = tokio_postgres::connect(&pg_config, NoTls).await.unwrap();
578
579 tokio::spawn(async move {
582 if let Err(e) = connection.await {
583 eprintln!("connection error: {}", e);
584 }
585 });
586
587 let rows = client
588 .simple_query("SELECT ''")
589 .await
590 .expect("Error executing query");
591 assert_eq!(rows.len(), 2);
593
594 let rows = client
595 .query("SELECT ''", &[])
596 .await
597 .expect("Error executing query");
598 assert_eq!(rows.len(), 1);
599 }
600
601 #[tokio::test]
602 async fn test_query_tcp() {
603 do_test_query("127.0.0.1:10000", "host=localhost port=10000").await;
604 }
605
606 #[cfg(not(madsim))]
607 #[tokio::test]
608 async fn test_query_unix() {
609 let port: i16 = 10000;
610 let dir = tempfile::TempDir::new().unwrap();
611 let sock = dir.path().join(format!(".s.PGSQL.{port}"));
612
613 do_test_query(
614 format!("unix:{}", sock.to_str().unwrap()),
615 format!("host={} port={}", dir.path().to_str().unwrap(), port),
616 )
617 .await;
618 }
619
620 mod jwt_validation_tests {
621 use std::collections::HashMap;
622 use std::time::{SystemTime, UNIX_EPOCH};
623
624 use base64::Engine;
625 use jsonwebtoken::{Algorithm, EncodingKey, Header};
626 use rsa::pkcs1::EncodeRsaPrivateKey;
627 use rsa::traits::PublicKeyParts;
628 use rsa::{RsaPrivateKey, RsaPublicKey};
629 use serde_json::json;
630
631 use crate::pg_server::{Jwk, Jwks, validate_jwt_with_jwks};
632
633 fn create_test_rsa_keys() -> (RsaPrivateKey, RsaPublicKey) {
634 let mut rng = rand::thread_rng();
635 let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("failed to generate a key");
636 let public_key = RsaPublicKey::from(&private_key);
637 (private_key, public_key)
638 }
639
640 fn create_test_jwks(public_key: &RsaPublicKey, kid: &str, alg: &str) -> Jwks {
641 let n = base64::engine::general_purpose::URL_SAFE_NO_PAD
642 .encode(public_key.n().to_bytes_be());
643 let e = base64::engine::general_purpose::URL_SAFE_NO_PAD
644 .encode(public_key.e().to_bytes_be());
645
646 Jwks {
647 keys: vec![Jwk {
648 kid: kid.to_owned(),
649 alg: alg.to_owned(),
650 n,
651 e,
652 }],
653 }
654 }
655
656 fn create_jwt_token(
657 private_key: &RsaPrivateKey,
658 kid: &str,
659 algorithm: Algorithm,
660 issuer: &str,
661 audience: Option<&str>,
662 exp: u64,
663 additional_claims: HashMap<String, serde_json::Value>,
664 ) -> String {
665 let mut header = Header::new(algorithm);
666 header.kid = Some(kid.to_owned());
667
668 let mut claims = json!({
669 "iss": issuer,
670 "exp": exp,
671 });
672
673 if let Some(aud) = audience {
674 claims["aud"] = json!(aud);
675 }
676
677 for (key, value) in additional_claims {
678 claims[key] = value;
679 }
680
681 let encoding_key = EncodingKey::from_rsa_pem(
682 private_key
683 .to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
684 .unwrap()
685 .as_bytes(),
686 )
687 .unwrap();
688
689 jsonwebtoken::encode(&header, &claims, &encoding_key).unwrap()
690 }
691
/// Returns the Unix timestamp, in seconds, one hour from now.
fn get_future_timestamp() -> u64 {
    let now_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();
    now_secs + 3600
}
699
/// Returns the Unix timestamp, in seconds, one hour ago.
fn get_past_timestamp() -> u64 {
    let now_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();
    now_secs - 3600
}
707
/// A token issued for another cluster's audience must be rejected.
#[test]
fn test_jwt_with_invalid_audience() {
    let (signing_key, verifying_key) = create_test_rsa_keys();
    let key_set = create_test_jwks(&verifying_key, "test-kid", "RS256");

    let token = create_jwt_token(
        &signing_key,
        "test-kid",
        Algorithm::RS256,
        "https://test-issuer.com",
        Some("urn:risingwave:cluster:wrong-cluster-id"),
        get_future_timestamp(),
        HashMap::new(),
    );

    let no_metadata = HashMap::new();
    let outcome = validate_jwt_with_jwks(
        &token,
        &key_set,
        "https://test-issuer.com",
        "test-cluster-id",
        &no_metadata,
    );

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("InvalidAudience"));
}
736
/// A token without any `aud` claim must be rejected as missing a required claim.
#[test]
fn test_jwt_with_missing_audience() {
    let (signing_key, verifying_key) = create_test_rsa_keys();
    let key_set = create_test_jwks(&verifying_key, "test-kid", "RS256");

    // `None` leaves the `aud` claim out of the token entirely.
    let token = create_jwt_token(
        &signing_key,
        "test-kid",
        Algorithm::RS256,
        "https://test-issuer.com",
        None,
        get_future_timestamp(),
        HashMap::new(),
    );

    let no_metadata = HashMap::new();
    let outcome = validate_jwt_with_jwks(
        &token,
        &key_set,
        "https://test-issuer.com",
        "test-cluster-id",
        &no_metadata,
    );

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("Missing required claim: aud"));
}
765
/// A token whose `iss` differs from the expected issuer must be rejected.
#[test]
fn test_jwt_with_invalid_issuer() {
    let (signing_key, verifying_key) = create_test_rsa_keys();
    let key_set = create_test_jwks(&verifying_key, "test-kid", "RS256");

    let token = create_jwt_token(
        &signing_key,
        "test-kid",
        Algorithm::RS256,
        "https://wrong-issuer.com",
        Some("urn:risingwave:cluster:test-cluster-id"),
        get_future_timestamp(),
        HashMap::new(),
    );

    let no_metadata = HashMap::new();
    let outcome = validate_jwt_with_jwks(
        &token,
        &key_set,
        "https://test-issuer.com",
        "test-cluster-id",
        &no_metadata,
    );

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("InvalidIssuer"));
}
794
/// A token whose `kid` is absent from the key set must be rejected, naming the missing id.
#[test]
fn test_jwt_with_kid_not_found_in_jwks() {
    let (signing_key, verifying_key) = create_test_rsa_keys();
    // The key set only knows "different-kid" while the token claims "missing-kid".
    let key_set = create_test_jwks(&verifying_key, "different-kid", "RS256");

    let token = create_jwt_token(
        &signing_key,
        "missing-kid",
        Algorithm::RS256,
        "https://test-issuer.com",
        Some("urn:risingwave:cluster:test-cluster-id"),
        get_future_timestamp(),
        HashMap::new(),
    );

    let no_metadata = HashMap::new();
    let outcome = validate_jwt_with_jwks(
        &token,
        &key_set,
        "https://test-issuer.com",
        "test-cluster-id",
        &no_metadata,
    );

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("No matching key found in JWKS for kid: 'missing-kid'"));
}
827
/// A token whose `exp` lies an hour in the past must be rejected as expired.
#[test]
fn test_jwt_with_expired_token() {
    let (signing_key, verifying_key) = create_test_rsa_keys();
    let key_set = create_test_jwks(&verifying_key, "test-kid", "RS256");

    let token = create_jwt_token(
        &signing_key,
        "test-kid",
        Algorithm::RS256,
        "https://test-issuer.com",
        Some("urn:risingwave:cluster:test-cluster-id"),
        get_past_timestamp(),
        HashMap::new(),
    );

    let no_metadata = HashMap::new();
    let outcome = validate_jwt_with_jwks(
        &token,
        &key_set,
        "https://test-issuer.com",
        "test-cluster-id",
        &no_metadata,
    );

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("ExpiredSignature"));
}
856
/// A token signed by a key other than the one published in the key set must be rejected.
#[test]
fn test_jwt_with_invalid_signature() {
    // Two independent key pairs: one is published, the other one signs the token.
    let (_, published_key) = create_test_rsa_keys();
    let (unrelated_signing_key, _) = create_test_rsa_keys();
    let key_set = create_test_jwks(&published_key, "test-kid", "RS256");

    let token = create_jwt_token(
        &unrelated_signing_key,
        "test-kid",
        Algorithm::RS256,
        "https://test-issuer.com",
        Some("urn:risingwave:cluster:test-cluster-id"),
        get_future_timestamp(),
        HashMap::new(),
    );

    let no_metadata = HashMap::new();
    let outcome = validate_jwt_with_jwks(
        &token,
        &key_set,
        "https://test-issuer.com",
        "test-cluster-id",
        &no_metadata,
    );

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("InvalidSignature"));
}
887
/// A token carrying every declared metadata entry as a matching claim is accepted; claims
/// that are not declared as metadata do not matter.
#[test]
fn test_metadata_validation_success() {
    let (signing_key, verifying_key) = create_test_rsa_keys();
    let key_set = create_test_jwks(&verifying_key, "test-kid", "RS256");

    let expected_metadata = HashMap::from([
        ("role".to_owned(), "admin".to_owned()),
        ("department".to_owned(), "security".to_owned()),
    ]);

    let token_claims = HashMap::from([
        ("role".to_owned(), json!("admin")),
        ("department".to_owned(), json!("security")),
        ("extra_claim".to_owned(), json!("ignored")),
    ]);

    let token = create_jwt_token(
        &signing_key,
        "test-kid",
        Algorithm::RS256,
        "https://test-issuer.com",
        Some("urn:risingwave:cluster:test-cluster-id"),
        get_future_timestamp(),
        token_claims,
    );

    let outcome = validate_jwt_with_jwks(
        &token,
        &key_set,
        "https://test-issuer.com",
        "test-cluster-id",
        &expected_metadata,
    );

    assert!(outcome.unwrap());
}
922
/// A token whose claim disagrees with a declared metadata entry must be rejected with the
/// metadata-mismatch error.
#[test]
fn test_metadata_validation_failure() {
    let (signing_key, verifying_key) = create_test_rsa_keys();
    let key_set = create_test_jwks(&verifying_key, "test-kid", "RS256");

    let expected_metadata = HashMap::from([
        ("role".to_owned(), "admin".to_owned()),
        ("department".to_owned(), "security".to_owned()),
    ]);

    // `role` deliberately differs from the declared metadata.
    let token_claims = HashMap::from([
        ("role".to_owned(), json!("user")),
        ("department".to_owned(), json!("security")),
    ]);

    let token = create_jwt_token(
        &signing_key,
        "test-kid",
        Algorithm::RS256,
        "https://test-issuer.com",
        Some("urn:risingwave:cluster:test-cluster-id"),
        get_future_timestamp(),
        token_claims,
    );

    let outcome = validate_jwt_with_jwks(
        &token,
        &key_set,
        "https://test-issuer.com",
        "test-cluster-id",
        &expected_metadata,
    );

    let message = outcome.unwrap_err().to_string();
    assert_eq!(
        message,
        "metadata in jwt does not match with metadata declared with user"
    );
}
960 }
961}