1use std::collections::HashMap;
16use std::future::Future;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::time::Instant;
20
21use bytes::Bytes;
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
23use parking_lot::Mutex;
24use risingwave_common::config::HbaEntry;
25use risingwave_common::id::DatabaseId;
26use risingwave_common::types::DataType;
27use risingwave_common::util::runtime::BackgroundShutdownRuntime;
28use risingwave_common::util::tokio_util::sync::CancellationToken;
29use risingwave_sqlparser::ast::Statement;
30use serde::Deserialize;
31use thiserror_ext::AsReport;
32
33use crate::error::{PsqlError, PsqlResult};
34use crate::ldap_auth::LdapAuthenticator;
35use crate::net::{AddressRef, Listener, TcpKeepalive};
36use crate::pg_field_descriptor::PgFieldDescriptor;
37use crate::pg_message::TransactionStatus;
38use crate::pg_protocol::{ConnectionContext, PgByteStream, PgProtocol};
39use crate::pg_response::{PgResponse, ValuesStream};
40use crate::types::Format;
41
/// Type-erased error that can be sent and shared across threads; the common
/// error type returned by the free functions in this module.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;
/// First component of a [`SessionId`].
type ProcessId = i32;
/// Second component of a [`SessionId`].
type SecretKey = i32;
/// Identifier of a session, expressed as a `(ProcessId, SecretKey)` pair.
pub type SessionId = (ProcessId, SecretKey);
46
/// Creates and manages [`Session`]s on behalf of the server.
///
/// [`pg_serve`] hands an `Arc` of the implementor to every accepted
/// connection and calls [`SessionManager::shutdown`] once the server stops.
pub trait SessionManager: Send + Sync + 'static {
    /// Error type of the manager; must be convertible into [`BoxedError`].
    type Error: Into<BoxedError>;
    /// Concrete session type; it shares the manager's error type.
    type Session: Session<Error = Self::Error>;

    /// Creates a session for the given database without a database name,
    /// user name or peer address.
    // NOTE(review): the intended use of a "dummy" session is not visible in this file.
    fn create_dummy_session(
        &self,
        database_id: DatabaseId,
    ) -> Result<Arc<Self::Session>, Self::Error>;

    /// Creates a session for `user_name` on `database`, for the client at `peer_addr`.
    fn connect(
        &self,
        database: &str,
        user_name: &str,
        peer_addr: AddressRef,
    ) -> Result<Arc<Self::Session>, Self::Error>;

    /// Requests cancellation of the queries of the session identified by `session_id`.
    fn cancel_queries_in_session(&self, session_id: SessionId);

    /// Requests cancellation of the creating jobs of the session identified by `session_id`.
    fn cancel_creating_jobs_in_session(&self, session_id: SessionId);

    /// Notifies the manager that `session` has ended.
    fn end_session(&self, session: &Self::Session);

    /// Hook awaited by [`pg_serve`] after the server has stopped.
    /// The default implementation completes immediately without doing anything.
    fn shutdown(&self) -> impl Future<Output = ()> + Send {
        async {}
    }
}
78
/// A single client session as seen by the protocol layer.
///
/// Methods that take `self: Arc<Self>` consume one reference to the session,
/// so callers are expected to clone the `Arc` before calling them.
pub trait Session: Send + Sync {
    /// Error type of the session; must be convertible into [`BoxedError`].
    type Error: Into<BoxedError>;
    /// Stream type carried by the [`PgResponse`] of a query.
    type ValuesStream: ValuesStream;
    /// Handle produced by [`Session::parse`] and consumed by [`Session::bind`].
    type PreparedStatement: Send + Clone + 'static;
    /// Handle produced by [`Session::bind`] and consumed by [`Session::execute`].
    type Portal: Send + Clone + std::fmt::Display + 'static;

    /// Runs a single already-parsed statement and returns its response,
    /// using `format` for the result.
    fn run_one_query(
        self: Arc<Self>,
        stmt: Statement,
        format: Format,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, Self::Error>> + Send;

    /// Builds a prepared statement from an optional statement and the
    /// (possibly unspecified) parameter types.
    fn parse(
        self: Arc<Self>,
        sql: Option<Statement>,
        params_types: Vec<Option<DataType>>,
    ) -> impl Future<Output = Result<Self::PreparedStatement, Self::Error>> + Send;

    /// Returns a future that resolves to the next notice message of this session.
    fn next_notice(self: &Arc<Self>) -> impl Future<Output = String> + Send;

    /// Binds parameter values and formats to a prepared statement, producing a portal.
    fn bind(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
        params: Vec<Option<Bytes>>,
        param_formats: Vec<Format>,
        result_formats: Vec<Format>,
    ) -> Result<Self::Portal, Self::Error>;

    /// Executes a portal and returns its response.
    fn execute(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, Self::Error>> + Send;

    /// Describes a prepared statement as a list of data types plus a list of
    /// field descriptors.
    // NOTE(review): presumably (parameter types, result columns) — confirm with implementors.
    fn describe_statement(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
    ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), Self::Error>;

    /// Describes a portal as a list of field descriptors.
    fn describe_portal(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> Result<Vec<PgFieldDescriptor>, Self::Error>;

    /// Returns the authenticator used to verify this session's user.
    fn user_authenticator(&self) -> &UserAuthenticator;

    /// Returns the identifier of this session.
    fn id(&self) -> SessionId;

    /// Looks up the value of the configuration entry `key`.
    fn get_config(&self, key: &str) -> Result<String, Self::Error>;

    /// Sets the configuration entry `key` to `value` and returns a string.
    // NOTE(review): the meaning of the returned string is not visible in this file.
    fn set_config(&self, key: &str, value: String) -> Result<String, Self::Error>;

    /// Returns the current transaction status of this session.
    fn transaction_status(&self) -> TransactionStatus;

    /// Creates the execution context for `sql` and returns a guard that keeps it alive.
    fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard;

    /// Checks the idle-in-transaction timeout, returning an error when the check fails.
    fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()>;

    /// Returns the name of the session's user.
    fn user(&self) -> String;
}
145
/// State describing one running statement.
///
/// Dropping an `ExecContext` stores the current time into `last_idle_instant`
/// (see its `Drop` implementation).
pub struct ExecContext {
    /// SQL text associated with this context.
    pub running_sql: Arc<str>,
    /// Instant supplied by whoever builds the context.
    // NOTE(review): presumably the time the statement started running — confirm.
    pub last_instant: Instant,
    /// Shared slot that receives `Some(Instant::now())` when this context is dropped.
    pub last_idle_instant: Arc<Mutex<Option<Instant>>>,
}
155
/// Guard holding a strong reference to an [`ExecContext`].
///
/// The wrapped `Arc` is never read (hence `allow(dead_code)`); the guard exists
/// only to keep the context alive until the guard itself is dropped.
pub struct ExecContextGuard(#[allow(dead_code)] Arc<ExecContext>);
159
160impl ExecContextGuard {
161 pub fn new(exec_context: Arc<ExecContext>) -> Self {
162 Self(exec_context)
163 }
164}
165
166impl Drop for ExecContext {
167 fn drop(&mut self) {
168 *self.last_idle_instant.lock() = Some(Instant::now());
169 }
170}
171
/// How a user's credentials are verified; see [`UserAuthenticator::authenticate`].
#[derive(Debug, Clone)]
pub enum UserAuthenticator {
    /// No verification: every password is accepted.
    None,
    /// The supplied password must equal these bytes exactly.
    ClearText(Vec<u8>),
    /// The supplied bytes must equal `encrypted_password`.
    Md5WithSalt {
        /// Expected value that the client-supplied bytes are compared against.
        encrypted_password: Vec<u8>,
        /// Salt associated with `encrypted_password`; not used by `authenticate` itself.
        salt: [u8; 4],
    },
    /// The supplied password is treated as a JWT and validated against a remote JWKS.
    OAuth {
        /// Must contain `jwks_url` and `issuer`; every remaining entry has to
        /// match a string claim of the same name in the token.
        metadata: HashMap<String, String>,
        /// Cluster id from which the expected token audience is derived.
        cluster_id: String,
    },
    /// LDAP authentication for the given user name, configured by the HBA entry.
    Ldap(String, HbaEntry),
}
189
/// JSON Web Key Set as deserialized from the JWKS endpoint response.
#[derive(Debug, Deserialize)]
struct Jwks {
    /// Keys in the set; the one whose `kid` matches the token header is used.
    keys: Vec<Jwk>,
}
197
198#[derive(Debug, Deserialize)]
201struct Jwk {
202 kid: String, alg: String, n: String, e: String, }
207
208async fn validate_jwt(
209 jwt: &str,
210 jwks_url: &str,
211 issuer: &str,
212 cluster_id: &str,
213 metadata: &HashMap<String, String>,
214) -> Result<bool, BoxedError> {
215 let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;
216 validate_jwt_with_jwks(jwt, &jwks, issuer, cluster_id, metadata)
217}
218
/// Builds the JWT audience expected for the cluster identified by `cluster_id`.
fn audience_from_cluster_id(cluster_id: &str) -> String {
    const AUDIENCE_PREFIX: &str = "urn:risingwave:cluster:";
    format!("{AUDIENCE_PREFIX}{cluster_id}")
}
222
223fn validate_jwt_with_jwks(
224 jwt: &str,
225 jwks: &Jwks,
226 issuer: &str,
227 cluster_id: &str,
228 metadata: &HashMap<String, String>,
229) -> Result<bool, BoxedError> {
230 let header = decode_header(jwt)?;
231
232 let kid = header.kid.ok_or("JWT header missing 'kid' field")?;
234 let jwk = jwks
235 .keys
236 .iter()
237 .find(|k| k.kid == kid)
238 .ok_or(format!("No matching key found in JWKS for kid: '{}'", kid))?;
239
240 if Algorithm::from_str(&jwk.alg)? != header.alg {
242 return Err("alg in jwt header does not match with alg in jwk".into());
243 }
244
245 let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
247 let mut validation = Validation::new(header.alg);
248 validation.set_issuer(&[issuer]);
249 validation.set_audience(&[audience_from_cluster_id(cluster_id)]); validation.set_required_spec_claims(&["exp", "iss", "aud"]);
251 let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;
252
253 if !metadata.iter().all(
255 |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
256 ) {
257 return Err("metadata in jwt does not match with metadata declared with user".into());
258 }
259 Ok(true)
260}
261
262impl UserAuthenticator {
263 pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> {
264 let success = match self {
265 UserAuthenticator::None => true,
266 UserAuthenticator::ClearText(text) => password == text,
267 UserAuthenticator::Md5WithSalt {
268 encrypted_password, ..
269 } => encrypted_password == password,
270 UserAuthenticator::OAuth {
271 metadata,
272 cluster_id,
273 } => {
274 let mut metadata = metadata.clone();
275 let jwks_url = metadata.remove("jwks_url").unwrap();
276 let issuer = metadata.remove("issuer").unwrap();
277 validate_jwt(
278 &String::from_utf8_lossy(password),
279 &jwks_url,
280 &issuer,
281 cluster_id,
282 &metadata,
283 )
284 .await
285 .map_err(PsqlError::StartupError)?
286 }
287 UserAuthenticator::Ldap(user_name, hba_entry) => {
288 let ldap_auth = LdapAuthenticator::new(hba_entry)?;
289 let password_str = String::from_utf8_lossy(password).into_owned();
291 ldap_auth.authenticate(user_name, &password_str).await?
292 }
293 };
294 if !success {
295 return Err(PsqlError::PasswordError);
296 }
297 Ok(())
298 }
299}
300
/// Binds a listener on `addr` and serves connections until `shutdown` is cancelled.
///
/// Connections are accepted on a dedicated runtime and each one is handled by
/// [`handle_connection`] on a worker runtime. After `shutdown` fires, the
/// acceptor runtime is dropped and `session_mgr.shutdown()` is awaited.
///
/// # Errors
///
/// Returns an error if binding the listener fails.
///
/// # Panics
///
/// Panics if a tokio runtime cannot be built.
pub async fn pg_serve(
    addr: &str,
    tcp_keepalive: TcpKeepalive,
    session_mgr: Arc<impl SessionManager>,
    context: ConnectionContext,
    shutdown: CancellationToken,
) -> Result<(), BoxedError> {
    let listener = Listener::bind(addr).await?;
    tracing::info!(addr, "server started");

    // Dedicated runtime with a single worker thread that only runs the accept loop.
    let acceptor_runtime = BackgroundShutdownRuntime::from({
        let mut builder = tokio::runtime::Builder::new_multi_thread();
        builder.worker_threads(1);
        builder
            .thread_name("rw-acceptor")
            .enable_all()
            .build()
            .unwrap()
    });

    // Connections are handled on the caller's runtime; under `madsim` a
    // separate multi-thread runtime is built instead.
    #[cfg(not(madsim))]
    let worker_runtime = tokio::runtime::Handle::current();
    #[cfg(madsim)]
    let worker_runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();
    let session_mgr_clone = session_mgr.clone();
    let f = async move {
        loop {
            let conn_ret = listener.accept(&tcp_keepalive).await;
            match conn_ret {
                Ok((stream, peer_addr)) => {
                    tracing::info!(%peer_addr, "accept connection");
                    worker_runtime.spawn(handle_connection(
                        stream,
                        session_mgr_clone.clone(),
                        Arc::new(peer_addr),
                        context.clone(),
                    ));
                }

                // A failed accept is logged and the loop keeps accepting.
                Err(e) => {
                    tracing::error!(error = %e.as_report(), "failed to accept connection",);
                }
            }
        }
    };
    acceptor_runtime.spawn(f);

    // Block until the caller requests shutdown.
    shutdown.cancelled().await;

    // NOTE(review): dropping the runtime presumably stops the accept loop so that
    // no new connection is accepted while sessions shut down — confirm against
    // `BackgroundShutdownRuntime`.
    drop(acceptor_runtime);
    session_mgr.shutdown().await;

    Ok(())
}
361
362pub async fn handle_connection<S, SM>(
363 stream: S,
364 session_mgr: Arc<SM>,
365 peer_addr: AddressRef,
366 context: ConnectionContext,
367) where
368 S: PgByteStream,
369 SM: SessionManager,
370{
371 PgProtocol::new(stream, session_mgr, peer_addr, context)
372 .run()
373 .await;
374}
375#[cfg(test)]
376mod tests {
377 use std::sync::Arc;
378 use std::time::Instant;
379
380 use bytes::Bytes;
381 use futures::StreamExt;
382 use futures::stream::BoxStream;
383 use risingwave_common::id::DatabaseId;
384 use risingwave_common::types::DataType;
385 use risingwave_common::util::tokio_util::sync::CancellationToken;
386 use risingwave_sqlparser::ast::Statement;
387 use tokio_postgres::NoTls;
388
389 use crate::error::PsqlResult;
390 use crate::memory_manager::MessageMemoryManager;
391 use crate::pg_field_descriptor::PgFieldDescriptor;
392 use crate::pg_message::TransactionStatus;
393 use crate::pg_protocol::ConnectionContext;
394 use crate::pg_response::{PgResponse, RowSetResult, StatementType};
395 use crate::pg_server::{
396 BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
397 UserAuthenticator, pg_serve,
398 };
399 use crate::types;
400 use crate::types::Row;
401
    /// Stateless [`SessionManager`] used by the tests; `connect` always succeeds.
    struct MockSessionManager {}
    /// Stateless [`Session`] used by the tests; every query yields canned data.
    struct MockSession {}
404
405 impl SessionManager for MockSessionManager {
406 type Error = BoxedError;
407 type Session = MockSession;
408
409 fn create_dummy_session(
410 &self,
411 _database_id: DatabaseId,
412 ) -> Result<Arc<Self::Session>, Self::Error> {
413 unimplemented!()
414 }
415
416 fn connect(
417 &self,
418 _database: &str,
419 _user_name: &str,
420 _peer_addr: crate::net::AddressRef,
421 ) -> Result<Arc<Self::Session>, Self::Error> {
422 Ok(Arc::new(MockSession {}))
423 }
424
425 fn cancel_queries_in_session(&self, _session_id: SessionId) {
426 todo!()
427 }
428
429 fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
430 todo!()
431 }
432
433 fn end_session(&self, _session: &Self::Session) {}
434 }
435
436 impl Session for MockSession {
437 type Error = BoxedError;
438 type Portal = String;
439 type PreparedStatement = String;
440 type ValuesStream = BoxStream<'static, RowSetResult>;
441
442 async fn run_one_query(
443 self: Arc<Self>,
444 _stmt: Statement,
445 _format: types::Format,
446 ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, Self::Error> {
447 Ok(PgResponse::builder(StatementType::SELECT)
448 .values(
449 futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
450 .boxed(),
451 vec![
452 PgFieldDescriptor::new("".to_owned(), 1043, -1);
455 1
456 ],
457 )
458 .into())
459 }
460
461 async fn parse(
462 self: Arc<Self>,
463 _sql: Option<Statement>,
464 _params_types: Vec<Option<DataType>>,
465 ) -> Result<String, Self::Error> {
466 Ok(String::new())
467 }
468
469 fn bind(
470 self: Arc<Self>,
471 _prepare_statement: String,
472 _params: Vec<Option<Bytes>>,
473 _param_formats: Vec<types::Format>,
474 _result_formats: Vec<types::Format>,
475 ) -> Result<String, Self::Error> {
476 Ok(String::new())
477 }
478
479 async fn execute(
480 self: Arc<Self>,
481 _portal: String,
482 ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, Self::Error> {
483 Ok(PgResponse::builder(StatementType::SELECT)
484 .values(
485 futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
486 .boxed(),
487 vec![
488 PgFieldDescriptor::new("".to_owned(), 1043, -1);
491 1
492 ],
493 )
494 .into())
495 }
496
497 fn describe_statement(
498 self: Arc<Self>,
499 _statement: String,
500 ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), Self::Error> {
501 Ok((
502 vec![],
503 vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)],
504 ))
505 }
506
507 fn describe_portal(
508 self: Arc<Self>,
509 _portal: String,
510 ) -> Result<Vec<PgFieldDescriptor>, Self::Error> {
511 Ok(vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)])
512 }
513
514 fn user_authenticator(&self) -> &UserAuthenticator {
515 &UserAuthenticator::None
516 }
517
518 fn id(&self) -> SessionId {
519 (0, 0)
520 }
521
522 fn get_config(&self, key: &str) -> Result<String, Self::Error> {
523 match key {
524 "timezone" => Ok("UTC".to_owned()),
525 _ => Err(format!("Unknown config key: {key}").into()),
526 }
527 }
528
529 fn set_config(&self, _key: &str, _value: String) -> Result<String, Self::Error> {
530 Ok("".to_owned())
531 }
532
533 async fn next_notice(self: &Arc<Self>) -> String {
534 std::future::pending().await
535 }
536
537 fn transaction_status(&self) -> TransactionStatus {
538 TransactionStatus::Idle
539 }
540
541 fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
542 let exec_context = Arc::new(ExecContext {
543 running_sql: sql,
544 last_instant: Instant::now(),
545 last_idle_instant: Default::default(),
546 });
547 ExecContextGuard::new(exec_context)
548 }
549
550 fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
551 Ok(())
552 }
553
554 fn user(&self) -> String {
555 "mock".to_owned()
556 }
557 }
558
559 async fn do_test_query(bind_addr: impl Into<String>, pg_config: impl Into<String>) {
560 let bind_addr = bind_addr.into();
561 let pg_config = pg_config.into();
562
563 let session_mgr = MockSessionManager {};
564 tokio::spawn(async move {
565 pg_serve(
566 &bind_addr,
567 socket2::TcpKeepalive::new(),
568 Arc::new(session_mgr),
569 ConnectionContext {
570 tls_config: None,
571 redact_sql_option_keywords: None,
572 message_memory_manager: MessageMemoryManager::new(u64::MAX, u64::MAX, u64::MAX)
573 .into(),
574 },
575 CancellationToken::new(), )
577 .await
578 });
579 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
581
582 let (client, connection) = tokio_postgres::connect(&pg_config, NoTls).await.unwrap();
584
585 tokio::spawn(async move {
588 if let Err(e) = connection.await {
589 eprintln!("connection error: {}", e);
590 }
591 });
592
593 let rows = client
594 .simple_query("SELECT ''")
595 .await
596 .expect("Error executing query");
597 assert_eq!(rows.len(), 2);
599
600 let rows = client
601 .query("SELECT ''", &[])
602 .await
603 .expect("Error executing query");
604 assert_eq!(rows.len(), 1);
605 }
606
607 #[tokio::test]
608 async fn test_query_tcp() {
609 do_test_query("127.0.0.1:10000", "host=localhost port=10000").await;
610 }
611
612 #[cfg(not(madsim))]
613 #[tokio::test]
614 async fn test_query_unix() {
615 let port: i16 = 10000;
616 let dir = tempfile::TempDir::new().unwrap();
617 let sock = dir.path().join(format!(".s.PGSQL.{port}"));
618
619 do_test_query(
620 format!("unix:{}", sock.to_str().unwrap()),
621 format!("host={} port={}", dir.path().to_str().unwrap(), port),
622 )
623 .await;
624 }
625
626 mod jwt_validation_tests {
627 use std::collections::HashMap;
628 use std::time::{SystemTime, UNIX_EPOCH};
629
630 use base64::Engine;
631 use jsonwebtoken::{Algorithm, EncodingKey, Header};
632 use rsa::pkcs1::EncodeRsaPrivateKey;
633 use rsa::traits::PublicKeyParts;
634 use rsa::{RsaPrivateKey, RsaPublicKey};
635 use serde_json::json;
636
637 use crate::pg_server::{Jwk, Jwks, validate_jwt_with_jwks};
638
639 fn create_test_rsa_keys() -> (RsaPrivateKey, RsaPublicKey) {
640 let mut rng = rand::thread_rng();
641 let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("failed to generate a key");
642 let public_key = RsaPublicKey::from(&private_key);
643 (private_key, public_key)
644 }
645
646 fn create_test_jwks(public_key: &RsaPublicKey, kid: &str, alg: &str) -> Jwks {
647 let n = base64::engine::general_purpose::URL_SAFE_NO_PAD
648 .encode(public_key.n().to_bytes_be());
649 let e = base64::engine::general_purpose::URL_SAFE_NO_PAD
650 .encode(public_key.e().to_bytes_be());
651
652 Jwks {
653 keys: vec![Jwk {
654 kid: kid.to_owned(),
655 alg: alg.to_owned(),
656 n,
657 e,
658 }],
659 }
660 }
661
662 fn create_jwt_token(
663 private_key: &RsaPrivateKey,
664 kid: &str,
665 algorithm: Algorithm,
666 issuer: &str,
667 audience: Option<&str>,
668 exp: u64,
669 additional_claims: HashMap<String, serde_json::Value>,
670 ) -> String {
671 let mut header = Header::new(algorithm);
672 header.kid = Some(kid.to_owned());
673
674 let mut claims = json!({
675 "iss": issuer,
676 "exp": exp,
677 });
678
679 if let Some(aud) = audience {
680 claims["aud"] = json!(aud);
681 }
682
683 for (key, value) in additional_claims {
684 claims[key] = value;
685 }
686
687 let encoding_key = EncodingKey::from_rsa_pem(
688 private_key
689 .to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
690 .unwrap()
691 .as_bytes(),
692 )
693 .unwrap();
694
695 jsonwebtoken::encode(&header, &claims, &encoding_key).unwrap()
696 }
697
698 fn get_future_timestamp() -> u64 {
699 SystemTime::now()
700 .duration_since(UNIX_EPOCH)
701 .unwrap()
702 .as_secs()
703 + 3600 }
705
706 fn get_past_timestamp() -> u64 {
707 SystemTime::now()
708 .duration_since(UNIX_EPOCH)
709 .unwrap()
710 .as_secs()
711 - 3600 }
713
714 #[test]
715 fn test_jwt_with_invalid_audience() {
716 let (private_key, public_key) = create_test_rsa_keys();
717 let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
718
719 let metadata = HashMap::new();
720
721 let jwt = create_jwt_token(
722 &private_key,
723 "test-kid",
724 Algorithm::RS256,
725 "https://test-issuer.com",
726 Some("urn:risingwave:cluster:wrong-cluster-id"),
727 get_future_timestamp(),
728 HashMap::new(),
729 );
730
731 let result = validate_jwt_with_jwks(
732 &jwt,
733 &jwks,
734 "https://test-issuer.com",
735 "test-cluster-id",
736 &metadata,
737 );
738
739 let error = result.unwrap_err();
740 assert!(error.to_string().contains("InvalidAudience"));
741 }
742
743 #[test]
744 fn test_jwt_with_missing_audience() {
745 let (private_key, public_key) = create_test_rsa_keys();
746 let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
747
748 let metadata = HashMap::new();
749
750 let jwt = create_jwt_token(
751 &private_key,
752 "test-kid",
753 Algorithm::RS256,
754 "https://test-issuer.com",
755 None, get_future_timestamp(),
757 HashMap::new(),
758 );
759
760 let result = validate_jwt_with_jwks(
761 &jwt,
762 &jwks,
763 "https://test-issuer.com",
764 "test-cluster-id",
765 &metadata,
766 );
767
768 let error = result.unwrap_err();
769 assert!(error.to_string().contains("Missing required claim: aud"));
770 }
771
772 #[test]
773 fn test_jwt_with_invalid_issuer() {
774 let (private_key, public_key) = create_test_rsa_keys();
775 let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
776
777 let metadata = HashMap::new();
778
779 let jwt = create_jwt_token(
780 &private_key,
781 "test-kid",
782 Algorithm::RS256,
783 "https://wrong-issuer.com",
784 Some("urn:risingwave:cluster:test-cluster-id"),
785 get_future_timestamp(),
786 HashMap::new(),
787 );
788
789 let result = validate_jwt_with_jwks(
790 &jwt,
791 &jwks,
792 "https://test-issuer.com",
793 "test-cluster-id",
794 &metadata,
795 );
796
797 let error = result.unwrap_err();
798 assert!(error.to_string().contains("InvalidIssuer"));
799 }
800
801 #[test]
802 fn test_jwt_with_kid_not_found_in_jwks() {
803 let (private_key, public_key) = create_test_rsa_keys();
804 let jwks = create_test_jwks(&public_key, "different-kid", "RS256");
805
806 let metadata = HashMap::new();
807
808 let jwt = create_jwt_token(
809 &private_key,
810 "missing-kid",
811 Algorithm::RS256,
812 "https://test-issuer.com",
813 Some("urn:risingwave:cluster:test-cluster-id"),
814 get_future_timestamp(),
815 HashMap::new(),
816 );
817
818 let result = validate_jwt_with_jwks(
819 &jwt,
820 &jwks,
821 "https://test-issuer.com",
822 "test-cluster-id",
823 &metadata,
824 );
825
826 let error = result.unwrap_err();
827 assert!(
828 error
829 .to_string()
830 .contains("No matching key found in JWKS for kid: 'missing-kid'")
831 );
832 }
833
834 #[test]
835 fn test_jwt_with_expired_token() {
836 let (private_key, public_key) = create_test_rsa_keys();
837 let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
838
839 let metadata = HashMap::new();
840
841 let jwt = create_jwt_token(
842 &private_key,
843 "test-kid",
844 Algorithm::RS256,
845 "https://test-issuer.com",
846 Some("urn:risingwave:cluster:test-cluster-id"),
847 get_past_timestamp(), HashMap::new(),
849 );
850
851 let result = validate_jwt_with_jwks(
852 &jwt,
853 &jwks,
854 "https://test-issuer.com",
855 "test-cluster-id",
856 &metadata,
857 );
858
859 let error = result.unwrap_err();
860 assert!(error.to_string().contains("ExpiredSignature"));
861 }
862
863 #[test]
864 fn test_jwt_with_invalid_signature() {
865 let (_, public_key) = create_test_rsa_keys();
866 let (wrong_private_key, _) = create_test_rsa_keys(); let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
868
869 let metadata = HashMap::new();
870
871 let jwt = create_jwt_token(
873 &wrong_private_key,
874 "test-kid",
875 Algorithm::RS256,
876 "https://test-issuer.com",
877 Some("urn:risingwave:cluster:test-cluster-id"),
878 get_future_timestamp(),
879 HashMap::new(),
880 );
881
882 let result = validate_jwt_with_jwks(
883 &jwt,
884 &jwks,
885 "https://test-issuer.com",
886 "test-cluster-id",
887 &metadata,
888 );
889
890 let error = result.unwrap_err();
891 assert!(error.to_string().contains("InvalidSignature"));
892 }
893
894 #[test]
895 fn test_metadata_validation_success() {
896 let (private_key, public_key) = create_test_rsa_keys();
897 let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
898
899 let mut metadata = HashMap::new();
900 metadata.insert("role".to_owned(), "admin".to_owned());
901 metadata.insert("department".to_owned(), "security".to_owned());
902
903 let mut claims = HashMap::new();
904 claims.insert("role".to_owned(), json!("admin"));
905 claims.insert("department".to_owned(), json!("security"));
906 claims.insert("extra_claim".to_owned(), json!("ignored")); let jwt = create_jwt_token(
909 &private_key,
910 "test-kid",
911 Algorithm::RS256,
912 "https://test-issuer.com",
913 Some("urn:risingwave:cluster:test-cluster-id"),
914 get_future_timestamp(),
915 claims,
916 );
917
918 let result = validate_jwt_with_jwks(
919 &jwt,
920 &jwks,
921 "https://test-issuer.com",
922 "test-cluster-id",
923 &metadata,
924 );
925
926 assert!(result.unwrap());
927 }
928
929 #[test]
930 fn test_metadata_validation_failure() {
931 let (private_key, public_key) = create_test_rsa_keys();
932 let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
933
934 let mut metadata = HashMap::new();
935 metadata.insert("role".to_owned(), "admin".to_owned());
936 metadata.insert("department".to_owned(), "security".to_owned());
937
938 let mut claims = HashMap::new();
939 claims.insert("role".to_owned(), json!("user")); claims.insert("department".to_owned(), json!("security"));
941
942 let jwt = create_jwt_token(
943 &private_key,
944 "test-kid",
945 Algorithm::RS256,
946 "https://test-issuer.com",
947 Some("urn:risingwave:cluster:test-cluster-id"),
948 get_future_timestamp(),
949 claims,
950 );
951
952 let result = validate_jwt_with_jwks(
953 &jwt,
954 &jwks,
955 "https://test-issuer.com",
956 "test-cluster-id",
957 &metadata,
958 );
959
960 let error = result.unwrap_err();
961 assert_eq!(
962 error.to_string(),
963 "metadata in jwt does not match with metadata declared with user"
964 );
965 }
966 }
967}