1use std::collections::HashMap;
16use std::future::Future;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::time::Instant;
20
21use bytes::Bytes;
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
23use parking_lot::Mutex;
24use risingwave_common::config::HbaEntry;
25use risingwave_common::id::DatabaseId;
26use risingwave_common::types::DataType;
27use risingwave_common::util::runtime::BackgroundShutdownRuntime;
28use risingwave_common::util::tokio_util::sync::CancellationToken;
29use risingwave_sqlparser::ast::Statement;
30use serde::Deserialize;
31use thiserror_ext::AsReport;
32
33use crate::error::{PsqlError, PsqlResult};
34use crate::ldap_auth::LdapAuthenticator;
35use crate::net::{AddressRef, Listener, TcpKeepalive};
36use crate::pg_field_descriptor::PgFieldDescriptor;
37use crate::pg_message::TransactionStatus;
38use crate::pg_protocol::{ConnectionContext, PgByteStream, PgProtocol};
39use crate::pg_response::{PgResponse, ValuesStream};
40use crate::types::Format;
41
/// Boxed, thread-safe error type returned by the session and server APIs in this module.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;
/// First half of a [`SessionId`].
/// NOTE(review): presumably the process id reported to the client for cancel requests — confirm.
type ProcessId = i32;
/// Second half of a [`SessionId`].
/// NOTE(review): presumably the secret key reported to the client for cancel requests — confirm.
type SecretKey = i32;
/// Identifies a session as a `(process id, secret key)` pair.
pub type SessionId = (ProcessId, SecretKey);
46
/// Creates, cancels and ends the sessions served over the PostgreSQL wire protocol.
///
/// `pg_serve` shares one `Arc` of the manager with every connection task, hence the
/// `Send + Sync + 'static` bound.
pub trait SessionManager: Send + Sync + 'static {
    /// Session type produced by this manager.
    type Session: Session;

    /// Creates a "dummy" session for `user_id` in the database `database_id`.
    ///
    /// NOTE(review): presumably a session not backed by a client connection — confirm with
    /// the implementors.
    fn create_dummy_session(
        &self,
        database_id: DatabaseId,
        user_id: u32,
    ) -> Result<Arc<Self::Session>, BoxedError>;

    /// Opens a session for `user_name` on `database` on behalf of the client at `peer_addr`.
    fn connect(
        &self,
        database: &str,
        user_name: &str,
        peer_addr: AddressRef,
    ) -> Result<Arc<Self::Session>, BoxedError>;

    /// Hook to cancel the queries running in the session identified by `session_id`.
    fn cancel_queries_in_session(&self, session_id: SessionId);

    /// Hook to cancel the creating jobs of the session identified by `session_id`.
    fn cancel_creating_jobs_in_session(&self, session_id: SessionId);

    /// Hook invoked to end `session`.
    fn end_session(&self, session: &Self::Session);

    /// Shuts the manager down. `pg_serve` awaits this after it stops accepting connections.
    ///
    /// The default implementation does nothing.
    fn shutdown(&self) -> impl Future<Output = ()> + Send {
        async {}
    }
}
78
/// One client session: executes statements and exposes the session state the protocol needs.
pub trait Session: Send + Sync {
    /// Stream of result rows carried by a [`PgResponse`].
    type ValuesStream: ValuesStream;
    /// Handle returned by [`Session::parse`] and consumed by `bind`/`describe_statement`.
    type PreparedStatement: Send + Clone + 'static;
    /// Handle returned by [`Session::bind`] and consumed by `execute`/`describe_portal`.
    type Portal: Send + Clone + std::fmt::Display + 'static;

    /// Runs one parsed statement and returns its response.
    ///
    /// `format` selects the result encoding (see `crate::types::Format`).
    fn run_one_query(
        self: Arc<Self>,
        stmt: Statement,
        format: Format,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;

    /// Prepares a statement.
    ///
    /// `params_types` holds one entry per parameter, `None` where no type was given.
    /// NOTE(review): `sql` is presumably `None` for an empty query string — confirm.
    fn parse(
        self: Arc<Self>,
        sql: Option<Statement>,
        params_types: Vec<Option<DataType>>,
    ) -> impl Future<Output = Result<Self::PreparedStatement, BoxedError>> + Send;

    /// Resolves to the next notice message for this session.
    fn next_notice(self: &Arc<Self>) -> impl Future<Output = String> + Send;

    /// Binds `params` (encoded per `param_formats`) to a prepared statement, yielding a portal
    /// whose results are encoded per `result_formats`.
    fn bind(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
        params: Vec<Option<Bytes>>,
        param_formats: Vec<Format>,
        result_formats: Vec<Format>,
    ) -> Result<Self::Portal, BoxedError>;

    /// Executes a bound portal.
    fn execute(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;

    /// Describes a prepared statement.
    ///
    /// NOTE(review): presumably returns (parameter types, result columns) — confirm.
    fn describe_statement(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
    ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError>;

    /// Describes the result columns of a portal.
    fn describe_portal(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> Result<Vec<PgFieldDescriptor>, BoxedError>;

    /// Authenticator used to check this session's startup credential.
    fn user_authenticator(&self) -> &UserAuthenticator;

    /// The `(process id, secret key)` pair identifying this session.
    fn id(&self) -> SessionId;

    /// Reads the session configuration variable `key`.
    fn get_config(&self, key: &str) -> Result<String, BoxedError>;

    /// Sets the session configuration variable `key` to `value`.
    ///
    /// NOTE(review): the meaning of the returned `String` is not evident here — confirm.
    fn set_config(&self, key: &str, value: String) -> Result<String, BoxedError>;

    /// Current transaction status of the session.
    fn transaction_status(&self) -> TransactionStatus;

    /// Creates the execution context for running `sql` and returns a guard holding it.
    fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard;

    /// Returns an error when the session's idle-in-transaction timeout has been exceeded
    /// (as the name suggests; the policy is implementation-defined).
    fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()>;
}
142
/// State of one statement execution, created via [`Session::init_exec_context`].
///
/// When dropped it writes the current time into `last_idle_instant` (see its `Drop` impl).
pub struct ExecContext {
    /// SQL text of the statement being executed.
    pub running_sql: Arc<str>,
    /// NOTE(review): presumably the instant the execution started — confirm with callers.
    pub last_instant: Instant,
    /// Shared slot that receives `Some(now)` when this context is dropped.
    pub last_idle_instant: Arc<Mutex<Option<Instant>>>,
}
152
153pub struct ExecContextGuard(#[allow(dead_code)] Arc<ExecContext>);
156
157impl ExecContextGuard {
158 pub fn new(exec_context: Arc<ExecContext>) -> Self {
159 Self(exec_context)
160 }
161}
162
163impl Drop for ExecContext {
164 fn drop(&mut self) {
165 *self.last_idle_instant.lock() = Some(Instant::now());
166 }
167}
168
/// How a user's startup credential is checked; see [`UserAuthenticator::authenticate`].
#[derive(Debug, Clone)]
pub enum UserAuthenticator {
    /// No check: authentication always succeeds.
    None,
    /// The client's password bytes must equal these bytes exactly.
    ClearText(Vec<u8>),
    /// The client's response must equal `encrypted_password` byte-for-byte.
    Md5WithSalt {
        /// Expected client response. NOTE(review): presumably the salted MD5 digest.
        encrypted_password: Vec<u8>,
        /// Not read by `authenticate`. NOTE(review): presumably sent to the client in the
        /// MD5 password request — confirm.
        salt: [u8; 4],
    },
    /// The client's password is a JWT, validated against the issuer's JWKS.
    OAuth {
        /// Expected to contain `jwks_url` and `issuer`; every other entry must match a string
        /// claim of the same name in the token.
        metadata: HashMap<String, String>,
        /// Determines the required audience `urn:risingwave:cluster:<cluster_id>`.
        cluster_id: String,
    },
    /// LDAP authentication of the given user name, configured by the HBA entry.
    Ldap(String, HbaEntry),
}
186
/// A JSON Web Key Set, as fetched from the OAuth `jwks_url`.
///
/// NOTE(review): every key must provide all fields of [`Jwk`]. A set containing any other
/// kind of key (e.g. one without `alg`, or a non-RSA key without `n`/`e`) fails to
/// deserialize as a whole.
#[derive(Debug, Deserialize)]
struct Jwks {
    /// Keys of the set, searched by `kid`.
    keys: Vec<Jwk>,
}
194
/// A single RSA JSON Web Key. JSON fields not listed here are ignored when deserializing.
#[derive(Debug, Deserialize)]
struct Jwk {
    /// Key id, matched against the JWT header's `kid`.
    kid: String,
    /// Algorithm name; parsed with `Algorithm::from_str` and required to equal the header `alg`.
    alg: String,
    /// RSA modulus, passed to `DecodingKey::from_rsa_components`.
    n: String,
    /// RSA public exponent, passed to `DecodingKey::from_rsa_components`.
    e: String,
}
204
205async fn validate_jwt(
206 jwt: &str,
207 jwks_url: &str,
208 issuer: &str,
209 cluster_id: &str,
210 metadata: &HashMap<String, String>,
211) -> Result<bool, BoxedError> {
212 let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;
213 validate_jwt_with_jwks(jwt, &jwks, issuer, cluster_id, metadata)
214}
215
/// Returns the JWT audience (`aud`) a token must carry to be accepted by `cluster_id`.
fn audience_from_cluster_id(cluster_id: &str) -> String {
    format!("urn:risingwave:cluster:{cluster_id}")
}
219
220fn validate_jwt_with_jwks(
221 jwt: &str,
222 jwks: &Jwks,
223 issuer: &str,
224 cluster_id: &str,
225 metadata: &HashMap<String, String>,
226) -> Result<bool, BoxedError> {
227 let header = decode_header(jwt)?;
228
229 let kid = header.kid.ok_or("JWT header missing 'kid' field")?;
231 let jwk = jwks
232 .keys
233 .iter()
234 .find(|k| k.kid == kid)
235 .ok_or(format!("No matching key found in JWKS for kid: '{}'", kid))?;
236
237 if Algorithm::from_str(&jwk.alg)? != header.alg {
239 return Err("alg in jwt header does not match with alg in jwk".into());
240 }
241
242 let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
244 let mut validation = Validation::new(header.alg);
245 validation.set_issuer(&[issuer]);
246 validation.set_audience(&[audience_from_cluster_id(cluster_id)]); validation.set_required_spec_claims(&["exp", "iss", "aud"]);
248 let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;
249
250 if !metadata.iter().all(
252 |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
253 ) {
254 return Err("metadata in jwt does not match with metadata declared with user".into());
255 }
256 Ok(true)
257}
258
259impl UserAuthenticator {
260 pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> {
261 let success = match self {
262 UserAuthenticator::None => true,
263 UserAuthenticator::ClearText(text) => password == text,
264 UserAuthenticator::Md5WithSalt {
265 encrypted_password, ..
266 } => encrypted_password == password,
267 UserAuthenticator::OAuth {
268 metadata,
269 cluster_id,
270 } => {
271 let mut metadata = metadata.clone();
272 let jwks_url = metadata.remove("jwks_url").unwrap();
273 let issuer = metadata.remove("issuer").unwrap();
274 validate_jwt(
275 &String::from_utf8_lossy(password),
276 &jwks_url,
277 &issuer,
278 cluster_id,
279 &metadata,
280 )
281 .await
282 .map_err(PsqlError::StartupError)?
283 }
284 UserAuthenticator::Ldap(user_name, hba_entry) => {
285 let ldap_auth = LdapAuthenticator::new(hba_entry)?;
286 let password_str = String::from_utf8_lossy(password).into_owned();
288 ldap_auth.authenticate(user_name, &password_str).await?
289 }
290 };
291 if !success {
292 return Err(PsqlError::PasswordError);
293 }
294 Ok(())
295 }
296}
297
298pub async fn pg_serve(
302 addr: &str,
303 tcp_keepalive: TcpKeepalive,
304 session_mgr: Arc<impl SessionManager>,
305 context: ConnectionContext,
306 shutdown: CancellationToken,
307) -> Result<(), BoxedError> {
308 let listener = Listener::bind(addr).await?;
309 tracing::info!(addr, "server started");
310
311 let acceptor_runtime = BackgroundShutdownRuntime::from({
312 let mut builder = tokio::runtime::Builder::new_multi_thread();
313 builder.worker_threads(1);
314 builder
315 .thread_name("rw-acceptor")
316 .enable_all()
317 .build()
318 .unwrap()
319 });
320
321 #[cfg(not(madsim))]
322 let worker_runtime = tokio::runtime::Handle::current();
323 #[cfg(madsim)]
324 let worker_runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();
325 let session_mgr_clone = session_mgr.clone();
326 let f = async move {
327 loop {
328 let conn_ret = listener.accept(&tcp_keepalive).await;
329 match conn_ret {
330 Ok((stream, peer_addr)) => {
331 tracing::info!(%peer_addr, "accept connection");
332 worker_runtime.spawn(handle_connection(
333 stream,
334 session_mgr_clone.clone(),
335 Arc::new(peer_addr),
336 context.clone(),
337 ));
338 }
339
340 Err(e) => {
341 tracing::error!(error = %e.as_report(), "failed to accept connection",);
342 }
343 }
344 }
345 };
346 acceptor_runtime.spawn(f);
347
348 shutdown.cancelled().await;
350
351 drop(acceptor_runtime);
353 session_mgr.shutdown().await;
355
356 Ok(())
357}
358
/// Serves one client connection by running [`PgProtocol`] over `stream` until it returns.
///
/// `peer_addr` is the client's address. `context` carries connection-wide settings such as
/// the TLS configuration.
pub async fn handle_connection<S, SM>(
    stream: S,
    session_mgr: Arc<SM>,
    peer_addr: AddressRef,
    context: ConnectionContext,
) where
    S: PgByteStream,
    SM: SessionManager,
{
    PgProtocol::new(stream, session_mgr, peer_addr, context)
        .run()
        .await;
}
#[cfg(test)]
mod tests {
    use std::error::Error;
    use std::sync::Arc;
    use std::time::Instant;

    use bytes::Bytes;
    use futures::StreamExt;
    use futures::stream::BoxStream;
    use risingwave_common::id::DatabaseId;
    use risingwave_common::types::DataType;
    use risingwave_common::util::tokio_util::sync::CancellationToken;
    use risingwave_sqlparser::ast::Statement;
    use tokio_postgres::NoTls;

    use crate::error::PsqlResult;
    use crate::memory_manager::MessageMemoryManager;
    use crate::pg_field_descriptor::PgFieldDescriptor;
    use crate::pg_message::TransactionStatus;
    use crate::pg_protocol::ConnectionContext;
    use crate::pg_response::{PgResponse, RowSetResult, StatementType};
    use crate::pg_server::{
        BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
        UserAuthenticator, pg_serve,
    };
    use crate::types;
    use crate::types::Row;

    /// Session manager whose `connect` always succeeds with a new [`MockSession`].
    struct MockSessionManager {}

    /// Session that answers every query with one row holding a single empty value.
    struct MockSession {}

    /// Response shared by [`MockSession`]'s query paths: one `SELECT` row with a single,
    /// empty column value.
    fn single_empty_row_response() -> PgResponse<BoxStream<'static, RowSetResult>> {
        PgResponse::builder(StatementType::SELECT)
            .values(
                futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])]).boxed(),
                // One column; 1043 is presumably the PostgreSQL `varchar` type OID.
                vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)],
            )
            .into()
    }

    impl SessionManager for MockSessionManager {
        type Session = MockSession;

        fn create_dummy_session(
            &self,
            _database_id: DatabaseId,
            _user_id: u32,
        ) -> Result<Arc<Self::Session>, BoxedError> {
            unimplemented!()
        }

        fn connect(
            &self,
            _database: &str,
            _user_name: &str,
            _peer_addr: crate::net::AddressRef,
        ) -> Result<Arc<Self::Session>, Box<dyn Error + Send + Sync>> {
            Ok(Arc::new(MockSession {}))
        }

        fn cancel_queries_in_session(&self, _session_id: SessionId) {
            todo!()
        }

        fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
            todo!()
        }

        fn end_session(&self, _session: &Self::Session) {}
    }

    impl Session for MockSession {
        type Portal = String;
        type PreparedStatement = String;
        type ValuesStream = BoxStream<'static, RowSetResult>;

        async fn run_one_query(
            self: Arc<Self>,
            _stmt: Statement,
            _format: types::Format,
        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, BoxedError> {
            Ok(single_empty_row_response())
        }

        async fn parse(
            self: Arc<Self>,
            _sql: Option<Statement>,
            _params_types: Vec<Option<DataType>>,
        ) -> Result<String, BoxedError> {
            Ok(String::new())
        }

        fn bind(
            self: Arc<Self>,
            _prepare_statement: String,
            _params: Vec<Option<Bytes>>,
            _param_formats: Vec<types::Format>,
            _result_formats: Vec<types::Format>,
        ) -> Result<String, BoxedError> {
            Ok(String::new())
        }

        async fn execute(
            self: Arc<Self>,
            _portal: String,
        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, BoxedError> {
            Ok(single_empty_row_response())
        }

        fn describe_statement(
            self: Arc<Self>,
            _statement: String,
        ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError> {
            Ok((
                vec![],
                vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)],
            ))
        }

        fn describe_portal(
            self: Arc<Self>,
            _portal: String,
        ) -> Result<Vec<PgFieldDescriptor>, BoxedError> {
            Ok(vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)])
        }

        fn user_authenticator(&self) -> &UserAuthenticator {
            &UserAuthenticator::None
        }

        fn id(&self) -> SessionId {
            (0, 0)
        }

        fn get_config(&self, key: &str) -> Result<String, BoxedError> {
            match key {
                "timezone" => Ok("UTC".to_owned()),
                _ => Err(format!("Unknown config key: {key}").into()),
            }
        }

        fn set_config(&self, _key: &str, _value: String) -> Result<String, BoxedError> {
            Ok("".to_owned())
        }

        async fn next_notice(self: &Arc<Self>) -> String {
            // The mock never emits notices.
            std::future::pending().await
        }

        fn transaction_status(&self) -> TransactionStatus {
            TransactionStatus::Idle
        }

        fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
            let exec_context = Arc::new(ExecContext {
                running_sql: sql,
                last_instant: Instant::now(),
                last_idle_instant: Default::default(),
            });
            ExecContextGuard::new(exec_context)
        }

        fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
            Ok(())
        }
    }

    /// Starts `pg_serve` on `bind_addr`, connects with `tokio_postgres` using `pg_config`, and
    /// runs one simple-protocol and one extended-protocol query against [`MockSession`].
    async fn do_test_query(bind_addr: impl Into<String>, pg_config: impl Into<String>) {
        let bind_addr = bind_addr.into();
        let pg_config = pg_config.into();

        let session_mgr = MockSessionManager {};
        tokio::spawn(async move {
            pg_serve(
                &bind_addr,
                socket2::TcpKeepalive::new(),
                Arc::new(session_mgr),
                ConnectionContext {
                    tls_config: None,
                    redact_sql_option_keywords: None,
                    message_memory_manager: MessageMemoryManager::new(u64::MAX, u64::MAX, u64::MAX)
                        .into(),
                },
                CancellationToken::new(),
            )
            .await
        });
        // Give the server a moment to bind before connecting.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        let (client, connection) = tokio_postgres::connect(&pg_config, NoTls).await.unwrap();

        // The connection object performs the actual socket I/O; it must be polled for the
        // client to make progress.
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                eprintln!("connection error: {}", e);
            }
        });

        // Simple query protocol: presumably one row message plus the command-complete message.
        let rows = client
            .simple_query("SELECT ''")
            .await
            .expect("Error executing query");
        assert_eq!(rows.len(), 2);

        // Extended query protocol: only the data row is returned.
        let rows = client
            .query("SELECT ''", &[])
            .await
            .expect("Error executing query");
        assert_eq!(rows.len(), 1);
    }

    #[tokio::test]
    async fn test_query_tcp() {
        do_test_query("127.0.0.1:10000", "host=localhost port=10000").await;
    }

    #[cfg(not(madsim))]
    #[tokio::test]
    async fn test_query_unix() {
        let port: i16 = 10000;
        let dir = tempfile::TempDir::new().unwrap();
        let sock = dir.path().join(format!(".s.PGSQL.{port}"));

        do_test_query(
            format!("unix:{}", sock.to_str().unwrap()),
            format!("host={} port={}", dir.path().to_str().unwrap(), port),
        )
        .await;
    }

    mod jwt_validation_tests {
        use std::collections::HashMap;
        use std::sync::OnceLock;
        use std::time::{SystemTime, UNIX_EPOCH};

        use base64::Engine;
        use jsonwebtoken::{Algorithm, EncodingKey, Header};
        use rsa::pkcs1::EncodeRsaPrivateKey;
        use rsa::traits::PublicKeyParts;
        use rsa::{RsaPrivateKey, RsaPublicKey};
        use serde_json::json;

        use crate::pg_server::{Jwk, Jwks, validate_jwt_with_jwks};

        /// Generates a fresh 2048-bit RSA key pair.
        fn generate_rsa_keys() -> (RsaPrivateKey, RsaPublicKey) {
            let mut rng = rand::thread_rng();
            let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("failed to generate a key");
            let public_key = RsaPublicKey::from(&private_key);
            (private_key, public_key)
        }

        /// Returns a copy of a key pair generated once and shared by the tests in this module.
        ///
        /// RSA key generation is expensive (especially in unoptimized test builds), so tests
        /// that merely need *a* valid key pair reuse this one. Tests needing a second, distinct
        /// key must call [`generate_rsa_keys`] instead.
        fn create_test_rsa_keys() -> (RsaPrivateKey, RsaPublicKey) {
            static KEYS: OnceLock<(RsaPrivateKey, RsaPublicKey)> = OnceLock::new();
            KEYS.get_or_init(generate_rsa_keys).clone()
        }

        /// Builds a single-key JWKS exposing `public_key` under `kid` with the given `alg`.
        fn create_test_jwks(public_key: &RsaPublicKey, kid: &str, alg: &str) -> Jwks {
            let n = base64::engine::general_purpose::URL_SAFE_NO_PAD
                .encode(public_key.n().to_bytes_be());
            let e = base64::engine::general_purpose::URL_SAFE_NO_PAD
                .encode(public_key.e().to_bytes_be());

            Jwks {
                keys: vec![Jwk {
                    kid: kid.to_owned(),
                    alg: alg.to_owned(),
                    n,
                    e,
                }],
            }
        }

        /// Signs a token with `private_key`.
        ///
        /// The token carries `iss`, `exp`, an optional `aud`, and `additional_claims`.
        fn create_jwt_token(
            private_key: &RsaPrivateKey,
            kid: &str,
            algorithm: Algorithm,
            issuer: &str,
            audience: Option<&str>,
            exp: u64,
            additional_claims: HashMap<String, serde_json::Value>,
        ) -> String {
            let mut header = Header::new(algorithm);
            header.kid = Some(kid.to_owned());

            let mut claims = json!({
                "iss": issuer,
                "exp": exp,
            });

            if let Some(aud) = audience {
                claims["aud"] = json!(aud);
            }

            for (key, value) in additional_claims {
                claims[key] = value;
            }

            let encoding_key = EncodingKey::from_rsa_pem(
                private_key
                    .to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
                    .unwrap()
                    .as_bytes(),
            )
            .unwrap();

            jsonwebtoken::encode(&header, &claims, &encoding_key).unwrap()
        }

        /// Unix timestamp one hour in the future.
        fn get_future_timestamp() -> u64 {
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs()
                + 3600
        }

        /// Unix timestamp one hour in the past.
        fn get_past_timestamp() -> u64 {
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs()
                - 3600
        }

        #[test]
        fn test_jwt_with_invalid_audience() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let metadata = HashMap::new();

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:wrong-cluster-id"),
                get_future_timestamp(),
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(error.to_string().contains("InvalidAudience"));
        }

        #[test]
        fn test_jwt_with_missing_audience() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let metadata = HashMap::new();

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                None, // no `aud` claim
                get_future_timestamp(),
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(error.to_string().contains("Missing required claim: aud"));
        }

        #[test]
        fn test_jwt_with_invalid_issuer() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let metadata = HashMap::new();

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://wrong-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_future_timestamp(),
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(error.to_string().contains("InvalidIssuer"));
        }

        #[test]
        fn test_jwt_with_kid_not_found_in_jwks() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "different-kid", "RS256");

            let metadata = HashMap::new();

            let jwt = create_jwt_token(
                &private_key,
                "missing-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_future_timestamp(),
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(
                error
                    .to_string()
                    .contains("No matching key found in JWKS for kid: 'missing-kid'")
            );
        }

        #[test]
        fn test_jwt_with_expired_token() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let metadata = HashMap::new();

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_past_timestamp(), // already expired
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(error.to_string().contains("ExpiredSignature"));
        }

        #[test]
        fn test_jwt_with_invalid_signature() {
            let (_, public_key) = create_test_rsa_keys();
            // Must be an independently generated key: the cached pair would sign validly.
            let (wrong_private_key, _) = generate_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let metadata = HashMap::new();

            // Signed with a key that does not match the public key in the JWKS.
            let jwt = create_jwt_token(
                &wrong_private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_future_timestamp(),
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(error.to_string().contains("InvalidSignature"));
        }

        #[test]
        fn test_metadata_validation_success() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let mut metadata = HashMap::new();
            metadata.insert("role".to_owned(), "admin".to_owned());
            metadata.insert("department".to_owned(), "security".to_owned());

            let mut claims = HashMap::new();
            claims.insert("role".to_owned(), json!("admin"));
            claims.insert("department".to_owned(), json!("security"));
            // Claims not listed in the metadata are ignored.
            claims.insert("extra_claim".to_owned(), json!("ignored"));

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_future_timestamp(),
                claims,
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            assert!(result.unwrap());
        }

        #[test]
        fn test_metadata_validation_failure() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let mut metadata = HashMap::new();
            metadata.insert("role".to_owned(), "admin".to_owned());
            metadata.insert("department".to_owned(), "security".to_owned());

            let mut claims = HashMap::new();
            // `role` differs from the metadata.
            claims.insert("role".to_owned(), json!("user"));
            claims.insert("department".to_owned(), json!("security"));

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_future_timestamp(),
                claims,
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert_eq!(
                error.to_string(),
                "metadata in jwt does not match with metadata declared with user"
            );
        }
    }
}