1use std::collections::HashMap;
16use std::future::Future;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::time::Instant;
20
21use bytes::Bytes;
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
23use parking_lot::Mutex;
24use risingwave_common::config::HbaEntry;
25use risingwave_common::id::DatabaseId;
26use risingwave_common::types::DataType;
27use risingwave_common::util::runtime::BackgroundShutdownRuntime;
28use risingwave_common::util::tokio_util::sync::CancellationToken;
29use risingwave_sqlparser::ast::Statement;
30use serde::Deserialize;
31use thiserror_ext::AsReport;
32
33use crate::error::{PsqlError, PsqlResult};
34use crate::ldap_auth::LdapAuthenticator;
35use crate::net::{AddressRef, Listener, TcpKeepalive};
36use crate::pg_field_descriptor::PgFieldDescriptor;
37use crate::pg_message::TransactionStatus;
38use crate::pg_protocol::{ConnectionContext, PgByteStream, PgProtocol};
39use crate::pg_response::{PgResponse, ValuesStream};
40use crate::types::Format;
41
/// Boxed, thread-safe error type used at the server boundary: it is the error
/// type of [`pg_serve`] and the target every session error converts into.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;
/// First component of a [`SessionId`].
type ProcessId = i32;
/// Second component of a [`SessionId`].
type SecretKey = i32;
/// Identifies a session as a `(process id, secret key)` pair.
// NOTE(review): presumably mirrors PostgreSQL's BackendKeyData / cancel-request
// pair — confirm against the protocol code.
pub type SessionId = (ProcessId, SecretKey);
46
/// Creates, tracks and tears down [`Session`]s for the PostgreSQL server.
///
/// The manager must be shareable across tasks (`Send + Sync + 'static`):
/// [`pg_serve`] clones an `Arc` of it into every spawned connection task.
pub trait SessionManager: Send + Sync + 'static {
    /// Error returned by the manager; its sessions use the same type.
    type Error: Into<BoxedError>;
    /// Concrete session type produced by this manager.
    type Session: Session<Error = Self::Error>;

    /// Creates a session bound to `database_id` without a client connection.
    // NOTE(review): the intended use (internal/background work?) is not visible
    // here — confirm with implementors.
    fn create_dummy_session(
        &self,
        database_id: DatabaseId,
    ) -> Result<Arc<Self::Session>, Self::Error>;

    /// Opens a session for `user_name` on `database` on behalf of the client
    /// at `peer_addr`.
    fn connect(
        &self,
        database: &str,
        user_name: &str,
        peer_addr: AddressRef,
    ) -> Result<Arc<Self::Session>, Self::Error>;

    /// Requests cancellation of the queries running in session `session_id`.
    fn cancel_queries_in_session(&self, session_id: SessionId);

    /// Requests cancellation of the creating jobs of session `session_id`.
    fn cancel_creating_jobs_in_session(&self, session_id: SessionId);

    /// Notifies the manager that `session` has ended.
    fn end_session(&self, session: &Self::Session);

    /// Graceful shutdown hook. [`pg_serve`] awaits it after its acceptor
    /// runtime has been dropped. The default implementation does nothing.
    fn shutdown(&self) -> impl Future<Output = ()> + Send {
        async {}
    }
}
78
/// A client session: runs statements for one connection and exposes its
/// per-session state (configuration, transaction status, identity).
pub trait Session: Send + Sync {
    /// Error type of session operations.
    type Error: Into<BoxedError>;
    /// Stream of result rows carried by a [`PgResponse`].
    type ValuesStream: ValuesStream;
    /// Handle to a parsed statement (extended query protocol).
    type PreparedStatement: Send + Clone + 'static;
    /// Handle to a bound portal (extended query protocol).
    type Portal: Send + Clone + std::fmt::Display + 'static;

    /// Executes one already-parsed statement and returns its response, with
    /// result values encoded in `format`.
    fn run_one_query(
        self: Arc<Self>,
        stmt: Statement,
        format: Format,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, Self::Error>> + Send;

    /// Prepares `sql` with the declared parameter types.
    // NOTE(review): presumably `sql == None` stands for an empty query and a
    // `None` parameter type means "unspecified" — confirm with implementors.
    fn parse(
        self: Arc<Self>,
        sql: Option<Statement>,
        params_types: Vec<Option<DataType>>,
    ) -> impl Future<Output = Result<Self::PreparedStatement, Self::Error>> + Send;

    /// Waits for the next notice message of this session.
    fn next_notice(self: &Arc<Self>) -> impl Future<Output = String> + Send;

    /// Binds `params` (encoded per `param_formats`) to a prepared statement and
    /// returns a portal whose results are to be encoded per `result_formats`.
    fn bind(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
        params: Vec<Option<Bytes>>,
        param_formats: Vec<Format>,
        result_formats: Vec<Format>,
    ) -> Result<Self::Portal, Self::Error>;

    /// Executes a bound portal and returns its response.
    fn execute(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, Self::Error>> + Send;

    /// Describes a prepared statement.
    // NOTE(review): presumably returns (parameter types, result columns) —
    // confirm with implementors.
    fn describe_statement(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
    ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), Self::Error>;

    /// Describes the result columns of a bound portal.
    fn describe_portal(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> Result<Vec<PgFieldDescriptor>, Self::Error>;

    /// Authenticator used to check this session's credentials.
    fn user_authenticator(&self) -> &UserAuthenticator;

    /// Identifier of this session.
    fn id(&self) -> SessionId;

    /// Returns the value of configuration parameter `key`.
    fn get_config(&self, key: &str) -> Result<String, Self::Error>;

    /// Sets configuration parameter `key` to `value`.
    // NOTE(review): the meaning of the returned `String` is not visible here —
    // confirm with implementors.
    fn set_config(&self, key: &str, value: String) -> Result<String, Self::Error>;

    /// Current transaction status of the session.
    fn transaction_status(&self) -> TransactionStatus;

    /// Records `sql` as the statement being executed. The returned guard holds
    /// an [`ExecContext`]; when that context is dropped it writes the current
    /// instant into its `last_idle_instant` slot.
    fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard;

    /// Returns an error if the session has exceeded its idle-in-transaction
    /// timeout.
    fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()>;

    /// Name of the session's user.
    fn user(&self) -> String;
}
145
/// State of the statement currently executing in a session.
pub struct ExecContext {
    /// SQL text of the running statement.
    pub running_sql: Arc<str>,
    /// Instant recorded for the running statement.
    // NOTE(review): presumably the start of execution (the test mock sets it to
    // `Instant::now()` on creation) — confirm.
    pub last_instant: Instant,
    /// Shared slot that receives `Some(Instant::now())` when this context is
    /// dropped, i.e. the moment the session stopped running this statement.
    pub last_idle_instant: Arc<Mutex<Option<Instant>>>,
}
155
/// Keeps an [`ExecContext`] alive while the guard exists. The inner value is
/// only held, never read; releasing it is what eventually runs the context's
/// `Drop`.
pub struct ExecContextGuard(#[expect(dead_code)] Arc<ExecContext>);
159
160impl ExecContextGuard {
161 pub fn new(exec_context: Arc<ExecContext>) -> Self {
162 Self(exec_context)
163 }
164}
165
166impl Drop for ExecContext {
167 fn drop(&mut self) {
168 *self.last_idle_instant.lock() = Some(Instant::now());
169 }
170}
171
/// How a user's credentials are checked at connection startup; see
/// `UserAuthenticator::authenticate`.
// NOTE(review): `Debug` is derived, so `{:?}` prints the stored password bytes
// of `ClearText` / `Md5WithSalt`; consider a redacting `Debug` impl before this
// value is ever logged.
#[derive(Debug, Clone)]
pub enum UserAuthenticator {
    /// No check: every password is accepted.
    None,
    /// The supplied password must equal these bytes exactly.
    ClearText(Vec<u8>),
    /// The supplied password must equal `encrypted_password`; `salt` is not
    /// consulted by `authenticate`.
    // NOTE(review): presumably the salt is sent to the client for an MD5
    // challenge — confirm in the protocol code.
    Md5WithSalt {
        /// Expected client response.
        encrypted_password: Vec<u8>,
        /// 4-byte salt.
        salt: [u8; 4],
    },
    /// The supplied password is treated as a JWT and validated against a JWKS.
    OAuth {
        /// Must contain `jwks_url` and `issuer`; every other entry must appear
        /// as an equal string claim in the token.
        metadata: HashMap<String, String>,
        /// Used to derive the audience the token must carry.
        cluster_id: String,
    },
    /// Checked by an `LdapAuthenticator` built from the HBA entry, for the
    /// given user name.
    Ldap(String, HbaEntry),
}
189
/// JSON Web Key Set fetched from an OAuth `jwks_url`; only `keys` is read.
#[derive(Debug, Deserialize)]
struct Jwks {
    /// Candidate verification keys, matched against the JWT header's `kid`.
    keys: Vec<Jwk>,
}
197
/// One RSA key from a JWKS. Only the members needed for RSA signature
/// verification are deserialized; other members are ignored by serde.
// NOTE(review): `kid`, `n` and `e` are required, so a key set containing a key
// without them (e.g. an EC key) fails to deserialize as a whole — confirm this
// is acceptable for the supported identity providers.
#[derive(Debug, Deserialize)]
struct Jwk {
    /// Key id, matched against the JWT header's `kid`.
    kid: String,
    /// Optional algorithm pin; when present it must equal the header's `alg`.
    alg: Option<String>,
    /// RSA modulus (base64url, as consumed by `DecodingKey::from_rsa_components`).
    n: String,
    /// RSA public exponent (base64url, as consumed by `DecodingKey::from_rsa_components`).
    e: String,
}
207
/// Signature algorithms accepted for OAuth JWTs.
///
/// Only RSA-based algorithms are listed, matching the RSA-only [`Jwk`] keys.
/// This allow-list is what rejects a token that picks e.g. `HS256` in its
/// header when the matching JWK does not pin an `alg`.
const ALLOWED_JWT_ALGORITHMS: &[Algorithm] = &[
    Algorithm::RS256,
    Algorithm::RS384,
    Algorithm::RS512,
    Algorithm::PS256,
    Algorithm::PS384,
    Algorithm::PS512,
];
224
225async fn validate_jwt(
226 jwt: &str,
227 jwks_url: &str,
228 issuer: &str,
229 cluster_id: &str,
230 metadata: &HashMap<String, String>,
231) -> Result<bool, BoxedError> {
232 let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;
233 validate_jwt_with_jwks(jwt, &jwks, issuer, cluster_id, metadata)
234}
235
/// Builds the audience (`aud`) value a token must carry to be accepted by the
/// cluster identified by `cluster_id`.
fn audience_from_cluster_id(cluster_id: &str) -> String {
    const AUDIENCE_PREFIX: &str = "urn:risingwave:cluster:";
    [AUDIENCE_PREFIX, cluster_id].concat()
}
239
240fn validate_jwt_with_jwks(
241 jwt: &str,
242 jwks: &Jwks,
243 issuer: &str,
244 cluster_id: &str,
245 metadata: &HashMap<String, String>,
246) -> Result<bool, BoxedError> {
247 let header = decode_header(jwt)?;
248
249 let kid = header.kid.ok_or("JWT header missing 'kid' field")?;
251 let jwk = jwks
252 .keys
253 .iter()
254 .find(|k| k.kid == kid)
255 .ok_or(format!("No matching key found in JWKS for kid: '{}'", kid))?;
256
257 let alg = match jwk.alg.as_deref() {
266 Some(jwk_alg) => {
267 let jwk_alg = Algorithm::from_str(jwk_alg)?;
268 if jwk_alg != header.alg {
269 return Err("alg in jwt header does not match with alg in jwk".into());
270 }
271 jwk_alg
272 }
273 None => header.alg,
274 };
275 if !ALLOWED_JWT_ALGORITHMS.contains(&alg) {
276 return Err(format!("JWT alg {:?} is not allowed", alg).into());
277 }
278
279 let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
281 let mut validation = Validation::new(alg);
282 validation.set_issuer(&[issuer]);
283 validation.set_audience(&[audience_from_cluster_id(cluster_id)]); validation.set_required_spec_claims(&["exp", "iss", "aud"]);
285 let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;
286
287 if !metadata.iter().all(
289 |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
290 ) {
291 return Err("metadata in jwt does not match with metadata declared with user".into());
292 }
293 Ok(true)
294}
295
296impl UserAuthenticator {
297 pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> {
298 let success = match self {
299 UserAuthenticator::None => true,
300 UserAuthenticator::ClearText(text) => password == text,
301 UserAuthenticator::Md5WithSalt {
302 encrypted_password, ..
303 } => encrypted_password == password,
304 UserAuthenticator::OAuth {
305 metadata,
306 cluster_id,
307 } => {
308 let mut metadata = metadata.clone();
309 let jwks_url = metadata.remove("jwks_url").unwrap();
310 let issuer = metadata.remove("issuer").unwrap();
311 validate_jwt(
312 &String::from_utf8_lossy(password),
313 &jwks_url,
314 &issuer,
315 cluster_id,
316 &metadata,
317 )
318 .await
319 .map_err(PsqlError::StartupError)?
320 }
321 UserAuthenticator::Ldap(user_name, hba_entry) => {
322 let ldap_auth = LdapAuthenticator::new(hba_entry)?;
323 let password_str = String::from_utf8_lossy(password).into_owned();
325 ldap_auth.authenticate(user_name, &password_str).await?
326 }
327 };
328 if !success {
329 return Err(PsqlError::PasswordError);
330 }
331 Ok(())
332 }
333}
334
335pub async fn pg_serve(
339 addr: &str,
340 tcp_keepalive: TcpKeepalive,
341 session_mgr: Arc<impl SessionManager>,
342 context: ConnectionContext,
343 shutdown: CancellationToken,
344) -> Result<(), BoxedError> {
345 let listener = Listener::bind(addr).await?;
346 tracing::info!(addr, "server started");
347
348 let acceptor_runtime = BackgroundShutdownRuntime::from({
349 let mut builder = tokio::runtime::Builder::new_multi_thread();
350 builder.worker_threads(1);
351 builder
352 .thread_name("rw-acceptor")
353 .enable_all()
354 .build()
355 .unwrap()
356 });
357
358 #[cfg(not(madsim))]
359 let worker_runtime = tokio::runtime::Handle::current();
360 #[cfg(madsim)]
361 let worker_runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();
362 let session_mgr_clone = session_mgr.clone();
363 let f = async move {
364 loop {
365 let conn_ret = listener.accept(&tcp_keepalive).await;
366 match conn_ret {
367 Ok((stream, peer_addr)) => {
368 tracing::info!(%peer_addr, "accept connection");
369 worker_runtime.spawn(handle_connection(
370 stream,
371 session_mgr_clone.clone(),
372 Arc::new(peer_addr),
373 context.clone(),
374 ));
375 }
376
377 Err(e) => {
378 tracing::error!(error = %e.as_report(), "failed to accept connection",);
379 }
380 }
381 }
382 };
383 acceptor_runtime.spawn(f);
384
385 shutdown.cancelled().await;
387
388 drop(acceptor_runtime);
390 session_mgr.shutdown().await;
392
393 Ok(())
394}
395
396pub async fn handle_connection<S, SM>(
397 stream: S,
398 session_mgr: Arc<SM>,
399 peer_addr: AddressRef,
400 context: ConnectionContext,
401) where
402 S: PgByteStream,
403 SM: SessionManager,
404{
405 PgProtocol::new(stream, session_mgr, peer_addr, context)
406 .run()
407 .await;
408}
409#[cfg(test)]
410mod tests {
411 use std::sync::Arc;
412 use std::time::Instant;
413
414 use bytes::Bytes;
415 use futures::StreamExt;
416 use futures::stream::BoxStream;
417 use risingwave_common::id::DatabaseId;
418 use risingwave_common::types::DataType;
419 use risingwave_common::util::tokio_util::sync::CancellationToken;
420 use risingwave_sqlparser::ast::Statement;
421 use tokio_postgres::NoTls;
422
423 use crate::error::PsqlResult;
424 use crate::memory_manager::MessageMemoryManager;
425 use crate::pg_field_descriptor::PgFieldDescriptor;
426 use crate::pg_message::TransactionStatus;
427 use crate::pg_protocol::ConnectionContext;
428 use crate::pg_response::{PgResponse, RowSetResult, StatementType};
429 use crate::pg_server::{
430 BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
431 UserAuthenticator, pg_serve,
432 };
433 use crate::types;
434 use crate::types::Row;
435
    /// Session manager for the query tests: `connect` always succeeds with a
    /// new `MockSession`, while dummy sessions and cancellation are unsupported.
    struct MockSessionManager {}
    /// Session that answers every query and portal with the same canned
    /// single-row result.
    struct MockSession {}
438
439 impl SessionManager for MockSessionManager {
440 type Error = BoxedError;
441 type Session = MockSession;
442
443 fn create_dummy_session(
444 &self,
445 _database_id: DatabaseId,
446 ) -> Result<Arc<Self::Session>, Self::Error> {
447 unimplemented!()
448 }
449
450 fn connect(
451 &self,
452 _database: &str,
453 _user_name: &str,
454 _peer_addr: crate::net::AddressRef,
455 ) -> Result<Arc<Self::Session>, Self::Error> {
456 Ok(Arc::new(MockSession {}))
457 }
458
459 fn cancel_queries_in_session(&self, _session_id: SessionId) {
460 todo!()
461 }
462
463 fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
464 todo!()
465 }
466
467 fn end_session(&self, _session: &Self::Session) {}
468 }
469
470 impl Session for MockSession {
471 type Error = BoxedError;
472 type Portal = String;
473 type PreparedStatement = String;
474 type ValuesStream = BoxStream<'static, RowSetResult>;
475
476 async fn run_one_query(
477 self: Arc<Self>,
478 _stmt: Statement,
479 _format: types::Format,
480 ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, Self::Error> {
481 Ok(PgResponse::builder(StatementType::SELECT)
482 .values(
483 futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
484 .boxed(),
485 vec![
486 PgFieldDescriptor::new("".to_owned(), 1043, -1);
489 1
490 ],
491 )
492 .into())
493 }
494
495 async fn parse(
496 self: Arc<Self>,
497 _sql: Option<Statement>,
498 _params_types: Vec<Option<DataType>>,
499 ) -> Result<String, Self::Error> {
500 Ok(String::new())
501 }
502
503 fn bind(
504 self: Arc<Self>,
505 _prepare_statement: String,
506 _params: Vec<Option<Bytes>>,
507 _param_formats: Vec<types::Format>,
508 _result_formats: Vec<types::Format>,
509 ) -> Result<String, Self::Error> {
510 Ok(String::new())
511 }
512
513 async fn execute(
514 self: Arc<Self>,
515 _portal: String,
516 ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, Self::Error> {
517 Ok(PgResponse::builder(StatementType::SELECT)
518 .values(
519 futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
520 .boxed(),
521 vec![
522 PgFieldDescriptor::new("".to_owned(), 1043, -1);
525 1
526 ],
527 )
528 .into())
529 }
530
531 fn describe_statement(
532 self: Arc<Self>,
533 _statement: String,
534 ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), Self::Error> {
535 Ok((
536 vec![],
537 vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)],
538 ))
539 }
540
541 fn describe_portal(
542 self: Arc<Self>,
543 _portal: String,
544 ) -> Result<Vec<PgFieldDescriptor>, Self::Error> {
545 Ok(vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)])
546 }
547
548 fn user_authenticator(&self) -> &UserAuthenticator {
549 &UserAuthenticator::None
550 }
551
552 fn id(&self) -> SessionId {
553 (0, 0)
554 }
555
556 fn get_config(&self, key: &str) -> Result<String, Self::Error> {
557 match key {
558 "timezone" => Ok("UTC".to_owned()),
559 _ => Err(format!("Unknown config key: {key}").into()),
560 }
561 }
562
563 fn set_config(&self, _key: &str, _value: String) -> Result<String, Self::Error> {
564 Ok("".to_owned())
565 }
566
567 async fn next_notice(self: &Arc<Self>) -> String {
568 std::future::pending().await
569 }
570
571 fn transaction_status(&self) -> TransactionStatus {
572 TransactionStatus::Idle
573 }
574
575 fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
576 let exec_context = Arc::new(ExecContext {
577 running_sql: sql,
578 last_instant: Instant::now(),
579 last_idle_instant: Default::default(),
580 });
581 ExecContextGuard::new(exec_context)
582 }
583
584 fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
585 Ok(())
586 }
587
588 fn user(&self) -> String {
589 "mock".to_owned()
590 }
591 }
592
593 async fn do_test_query(bind_addr: impl Into<String>, pg_config: impl Into<String>) {
594 let bind_addr = bind_addr.into();
595 let pg_config = pg_config.into();
596
597 let session_mgr = MockSessionManager {};
598 tokio::spawn(async move {
599 pg_serve(
600 &bind_addr,
601 socket2::TcpKeepalive::new(),
602 Arc::new(session_mgr),
603 ConnectionContext {
604 tls_config: None,
605 redact_sql_option_keywords: None,
606 message_memory_manager: MessageMemoryManager::new(u64::MAX, u64::MAX, u64::MAX)
607 .into(),
608 },
609 CancellationToken::new(), )
611 .await
612 });
613 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
615
616 let (client, connection) = tokio_postgres::connect(&pg_config, NoTls).await.unwrap();
618
619 tokio::spawn(async move {
622 if let Err(e) = connection.await {
623 eprintln!("connection error: {}", e);
624 }
625 });
626
627 let rows = client
628 .simple_query("SELECT ''")
629 .await
630 .expect("Error executing query");
631 assert_eq!(rows.len(), 2);
633
634 let rows = client
635 .query("SELECT ''", &[])
636 .await
637 .expect("Error executing query");
638 assert_eq!(rows.len(), 1);
639 }
640
641 #[tokio::test]
642 async fn test_query_tcp() {
643 do_test_query("127.0.0.1:10000", "host=localhost port=10000").await;
644 }
645
646 #[cfg(not(madsim))]
647 #[tokio::test]
648 async fn test_query_unix() {
649 let port: i16 = 10000;
650 let dir = tempfile::TempDir::new().unwrap();
651 let sock = dir.path().join(format!(".s.PGSQL.{port}"));
652
653 do_test_query(
654 format!("unix:{}", sock.to_str().unwrap()),
655 format!("host={} port={}", dir.path().to_str().unwrap(), port),
656 )
657 .await;
658 }
659
660 mod jwt_validation_tests {
661 use std::collections::HashMap;
662 use std::time::{SystemTime, UNIX_EPOCH};
663
664 use base64::Engine;
665 use jsonwebtoken::{Algorithm, EncodingKey, Header};
666 use rsa::pkcs1::EncodeRsaPrivateKey;
667 use rsa::traits::PublicKeyParts;
668 use rsa::{RsaPrivateKey, RsaPublicKey};
669 use serde_json::json;
670
671 use crate::pg_server::{Jwk, Jwks, validate_jwt_with_jwks};
672
673 fn create_test_rsa_keys() -> (RsaPrivateKey, RsaPublicKey) {
674 let mut rng = rand::thread_rng();
675 let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("failed to generate a key");
676 let public_key = RsaPublicKey::from(&private_key);
677 (private_key, public_key)
678 }
679
680 fn create_test_jwks(public_key: &RsaPublicKey, kid: &str, alg: Option<&str>) -> Jwks {
681 let n = base64::engine::general_purpose::URL_SAFE_NO_PAD
682 .encode(public_key.n().to_bytes_be());
683 let e = base64::engine::general_purpose::URL_SAFE_NO_PAD
684 .encode(public_key.e().to_bytes_be());
685
686 Jwks {
687 keys: vec![Jwk {
688 kid: kid.to_owned(),
689 alg: alg.map(ToOwned::to_owned),
690 n,
691 e,
692 }],
693 }
694 }
695
696 fn create_jwt_token(
697 private_key: &RsaPrivateKey,
698 kid: &str,
699 algorithm: Algorithm,
700 issuer: &str,
701 audience: Option<&str>,
702 exp: u64,
703 additional_claims: HashMap<String, serde_json::Value>,
704 ) -> String {
705 let mut header = Header::new(algorithm);
706 header.kid = Some(kid.to_owned());
707
708 let mut claims = json!({
709 "iss": issuer,
710 "exp": exp,
711 });
712
713 if let Some(aud) = audience {
714 claims["aud"] = json!(aud);
715 }
716
717 for (key, value) in additional_claims {
718 claims[key] = value;
719 }
720
721 let encoding_key = EncodingKey::from_rsa_pem(
722 private_key
723 .to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
724 .unwrap()
725 .as_bytes(),
726 )
727 .unwrap();
728
729 jsonwebtoken::encode(&header, &claims, &encoding_key).unwrap()
730 }
731
732 fn get_future_timestamp() -> u64 {
733 SystemTime::now()
734 .duration_since(UNIX_EPOCH)
735 .unwrap()
736 .as_secs()
737 + 3600 }
739
740 fn get_past_timestamp() -> u64 {
741 SystemTime::now()
742 .duration_since(UNIX_EPOCH)
743 .unwrap()
744 .as_secs()
745 - 3600 }
747
748 #[test]
749 fn test_jwt_with_invalid_audience() {
750 let (private_key, public_key) = create_test_rsa_keys();
751 let jwks = create_test_jwks(&public_key, "test-kid", Some("RS256"));
752
753 let metadata = HashMap::new();
754
755 let jwt = create_jwt_token(
756 &private_key,
757 "test-kid",
758 Algorithm::RS256,
759 "https://test-issuer.com",
760 Some("urn:risingwave:cluster:wrong-cluster-id"),
761 get_future_timestamp(),
762 HashMap::new(),
763 );
764
765 let result = validate_jwt_with_jwks(
766 &jwt,
767 &jwks,
768 "https://test-issuer.com",
769 "test-cluster-id",
770 &metadata,
771 );
772
773 let error = result.unwrap_err();
774 assert!(error.to_string().contains("InvalidAudience"));
775 }
776
777 #[test]
778 fn test_jwt_with_missing_audience() {
779 let (private_key, public_key) = create_test_rsa_keys();
780 let jwks = create_test_jwks(&public_key, "test-kid", Some("RS256"));
781
782 let metadata = HashMap::new();
783
784 let jwt = create_jwt_token(
785 &private_key,
786 "test-kid",
787 Algorithm::RS256,
788 "https://test-issuer.com",
789 None, get_future_timestamp(),
791 HashMap::new(),
792 );
793
794 let result = validate_jwt_with_jwks(
795 &jwt,
796 &jwks,
797 "https://test-issuer.com",
798 "test-cluster-id",
799 &metadata,
800 );
801
802 let error = result.unwrap_err();
803 assert!(error.to_string().contains("Missing required claim: aud"));
804 }
805
806 #[test]
807 fn test_jwt_with_invalid_issuer() {
808 let (private_key, public_key) = create_test_rsa_keys();
809 let jwks = create_test_jwks(&public_key, "test-kid", Some("RS256"));
810
811 let metadata = HashMap::new();
812
813 let jwt = create_jwt_token(
814 &private_key,
815 "test-kid",
816 Algorithm::RS256,
817 "https://wrong-issuer.com",
818 Some("urn:risingwave:cluster:test-cluster-id"),
819 get_future_timestamp(),
820 HashMap::new(),
821 );
822
823 let result = validate_jwt_with_jwks(
824 &jwt,
825 &jwks,
826 "https://test-issuer.com",
827 "test-cluster-id",
828 &metadata,
829 );
830
831 let error = result.unwrap_err();
832 assert!(error.to_string().contains("InvalidIssuer"));
833 }
834
835 #[test]
836 fn test_jwt_with_kid_not_found_in_jwks() {
837 let (private_key, public_key) = create_test_rsa_keys();
838 let jwks = create_test_jwks(&public_key, "different-kid", Some("RS256"));
839
840 let metadata = HashMap::new();
841
842 let jwt = create_jwt_token(
843 &private_key,
844 "missing-kid",
845 Algorithm::RS256,
846 "https://test-issuer.com",
847 Some("urn:risingwave:cluster:test-cluster-id"),
848 get_future_timestamp(),
849 HashMap::new(),
850 );
851
852 let result = validate_jwt_with_jwks(
853 &jwt,
854 &jwks,
855 "https://test-issuer.com",
856 "test-cluster-id",
857 &metadata,
858 );
859
860 let error = result.unwrap_err();
861 assert!(
862 error
863 .to_string()
864 .contains("No matching key found in JWKS for kid: 'missing-kid'")
865 );
866 }
867
868 #[test]
869 fn test_jwt_with_expired_token() {
870 let (private_key, public_key) = create_test_rsa_keys();
871 let jwks = create_test_jwks(&public_key, "test-kid", Some("RS256"));
872
873 let metadata = HashMap::new();
874
875 let jwt = create_jwt_token(
876 &private_key,
877 "test-kid",
878 Algorithm::RS256,
879 "https://test-issuer.com",
880 Some("urn:risingwave:cluster:test-cluster-id"),
881 get_past_timestamp(), HashMap::new(),
883 );
884
885 let result = validate_jwt_with_jwks(
886 &jwt,
887 &jwks,
888 "https://test-issuer.com",
889 "test-cluster-id",
890 &metadata,
891 );
892
893 let error = result.unwrap_err();
894 assert!(error.to_string().contains("ExpiredSignature"));
895 }
896
897 #[test]
898 fn test_jwt_with_invalid_signature() {
899 let (_, public_key) = create_test_rsa_keys();
900 let (wrong_private_key, _) = create_test_rsa_keys(); let jwks = create_test_jwks(&public_key, "test-kid", Some("RS256"));
902
903 let metadata = HashMap::new();
904
905 let jwt = create_jwt_token(
907 &wrong_private_key,
908 "test-kid",
909 Algorithm::RS256,
910 "https://test-issuer.com",
911 Some("urn:risingwave:cluster:test-cluster-id"),
912 get_future_timestamp(),
913 HashMap::new(),
914 );
915
916 let result = validate_jwt_with_jwks(
917 &jwt,
918 &jwks,
919 "https://test-issuer.com",
920 "test-cluster-id",
921 &metadata,
922 );
923
924 let error = result.unwrap_err();
925 assert!(error.to_string().contains("InvalidSignature"));
926 }
927
928 #[test]
929 fn test_metadata_validation_success() {
930 let (private_key, public_key) = create_test_rsa_keys();
931 let jwks = create_test_jwks(&public_key, "test-kid", Some("RS256"));
932
933 let mut metadata = HashMap::new();
934 metadata.insert("role".to_owned(), "admin".to_owned());
935 metadata.insert("department".to_owned(), "security".to_owned());
936
937 let mut claims = HashMap::new();
938 claims.insert("role".to_owned(), json!("admin"));
939 claims.insert("department".to_owned(), json!("security"));
940 claims.insert("extra_claim".to_owned(), json!("ignored")); let jwt = create_jwt_token(
943 &private_key,
944 "test-kid",
945 Algorithm::RS256,
946 "https://test-issuer.com",
947 Some("urn:risingwave:cluster:test-cluster-id"),
948 get_future_timestamp(),
949 claims,
950 );
951
952 let result = validate_jwt_with_jwks(
953 &jwt,
954 &jwks,
955 "https://test-issuer.com",
956 "test-cluster-id",
957 &metadata,
958 );
959
960 assert!(result.unwrap());
961 }
962
963 #[test]
964 fn test_metadata_validation_failure() {
965 let (private_key, public_key) = create_test_rsa_keys();
966 let jwks = create_test_jwks(&public_key, "test-kid", Some("RS256"));
967
968 let mut metadata = HashMap::new();
969 metadata.insert("role".to_owned(), "admin".to_owned());
970 metadata.insert("department".to_owned(), "security".to_owned());
971
972 let mut claims = HashMap::new();
973 claims.insert("role".to_owned(), json!("user")); claims.insert("department".to_owned(), json!("security"));
975
976 let jwt = create_jwt_token(
977 &private_key,
978 "test-kid",
979 Algorithm::RS256,
980 "https://test-issuer.com",
981 Some("urn:risingwave:cluster:test-cluster-id"),
982 get_future_timestamp(),
983 claims,
984 );
985
986 let result = validate_jwt_with_jwks(
987 &jwt,
988 &jwks,
989 "https://test-issuer.com",
990 "test-cluster-id",
991 &metadata,
992 );
993
994 let error = result.unwrap_err();
995 assert_eq!(
996 error.to_string(),
997 "metadata in jwt does not match with metadata declared with user"
998 );
999 }
1000
1001 #[test]
1002 fn test_jwt_with_jwk_missing_alg_succeeds() {
1003 let (private_key, public_key) = create_test_rsa_keys();
1004 let jwks = create_test_jwks(&public_key, "test-kid", None);
1005
1006 let jwt = create_jwt_token(
1007 &private_key,
1008 "test-kid",
1009 Algorithm::RS256,
1010 "https://test-issuer.com",
1011 Some("urn:risingwave:cluster:test-cluster-id"),
1012 get_future_timestamp(),
1013 HashMap::new(),
1014 );
1015
1016 let result = validate_jwt_with_jwks(
1017 &jwt,
1018 &jwks,
1019 "https://test-issuer.com",
1020 "test-cluster-id",
1021 &HashMap::new(),
1022 );
1023
1024 assert!(result.unwrap());
1025 }
1026
1027 #[test]
1028 fn test_jwt_with_jwk_missing_alg_rejects_disallowed_header_alg() {
1029 let (_, public_key) = create_test_rsa_keys();
1030 let jwks = create_test_jwks(&public_key, "test-kid", None);
1031
1032 let mut header = Header::new(Algorithm::HS256);
1037 header.kid = Some("test-kid".to_owned());
1038 let claims = json!({
1039 "iss": "https://test-issuer.com",
1040 "aud": "urn:risingwave:cluster:test-cluster-id",
1041 "exp": get_future_timestamp(),
1042 });
1043 let jwt = jsonwebtoken::encode(
1044 &header,
1045 &claims,
1046 &EncodingKey::from_secret(b"attacker-chosen"),
1047 )
1048 .unwrap();
1049
1050 let result = validate_jwt_with_jwks(
1051 &jwt,
1052 &jwks,
1053 "https://test-issuer.com",
1054 "test-cluster-id",
1055 &HashMap::new(),
1056 );
1057
1058 let error = result.unwrap_err();
1059 assert!(
1060 error.to_string().contains("is not allowed"),
1061 "unexpected error: {}",
1062 error
1063 );
1064 }
1065
1066 #[test]
1067 fn test_jwt_alg_mismatch_between_header_and_jwk() {
1068 let (private_key, public_key) = create_test_rsa_keys();
1069 let jwks = create_test_jwks(&public_key, "test-kid", Some("RS384"));
1071
1072 let jwt = create_jwt_token(
1073 &private_key,
1074 "test-kid",
1075 Algorithm::RS256,
1076 "https://test-issuer.com",
1077 Some("urn:risingwave:cluster:test-cluster-id"),
1078 get_future_timestamp(),
1079 HashMap::new(),
1080 );
1081
1082 let result = validate_jwt_with_jwks(
1083 &jwt,
1084 &jwks,
1085 "https://test-issuer.com",
1086 "test-cluster-id",
1087 &HashMap::new(),
1088 );
1089
1090 let error = result.unwrap_err();
1091 assert_eq!(
1092 error.to_string(),
1093 "alg in jwt header does not match with alg in jwk"
1094 );
1095 }
1096 }
1097}