1use std::collections::HashMap;
16use std::future::Future;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::time::Instant;
20
21use bytes::Bytes;
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
23use parking_lot::Mutex;
24use risingwave_common::config::HbaEntry;
25use risingwave_common::id::DatabaseId;
26use risingwave_common::types::DataType;
27use risingwave_common::util::runtime::BackgroundShutdownRuntime;
28use risingwave_common::util::tokio_util::sync::CancellationToken;
29use risingwave_sqlparser::ast::Statement;
30use serde::Deserialize;
31use thiserror_ext::AsReport;
32
33use crate::error::{PsqlError, PsqlResult};
34use crate::ldap_auth::LdapAuthenticator;
35use crate::net::{AddressRef, Listener, TcpKeepalive};
36use crate::pg_field_descriptor::PgFieldDescriptor;
37use crate::pg_message::TransactionStatus;
38use crate::pg_protocol::{ConnectionContext, PgByteStream, PgProtocol};
39use crate::pg_response::{PgResponse, ValuesStream};
40use crate::types::Format;
41
/// Type-erased error that can be sent and shared across threads.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;
/// First component of a [`SessionId`].
type ProcessId = i32;
/// Second component of a [`SessionId`].
type SecretKey = i32;
/// Identifier of a session, expressed as a `(ProcessId, SecretKey)` pair.
pub type SessionId = (ProcessId, SecretKey);
46
47pub trait SessionManager: Send + Sync + 'static {
50 type Error: Into<BoxedError>;
51 type Session: Session<Error = Self::Error>;
52
53 fn create_dummy_session(
56 &self,
57 database_id: DatabaseId,
58 user_id: u32,
59 ) -> Result<Arc<Self::Session>, Self::Error>;
60
61 fn connect(
62 &self,
63 database: &str,
64 user_name: &str,
65 peer_addr: AddressRef,
66 ) -> Result<Arc<Self::Session>, Self::Error>;
67
68 fn cancel_queries_in_session(&self, session_id: SessionId);
69
70 fn cancel_creating_jobs_in_session(&self, session_id: SessionId);
71
72 fn end_session(&self, session: &Self::Session);
73
74 fn shutdown(&self) -> impl Future<Output = ()> + Send {
76 async {}
77 }
78}
79
80pub trait Session: Send + Sync {
83 type Error: Into<BoxedError>;
84 type ValuesStream: ValuesStream;
85 type PreparedStatement: Send + Clone + 'static;
86 type Portal: Send + Clone + std::fmt::Display + 'static;
87
88 fn run_one_query(
91 self: Arc<Self>,
92 stmt: Statement,
93 format: Format,
94 ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, Self::Error>> + Send;
95
96 fn parse(
97 self: Arc<Self>,
98 sql: Option<Statement>,
99 params_types: Vec<Option<DataType>>,
100 ) -> impl Future<Output = Result<Self::PreparedStatement, Self::Error>> + Send;
101
102 fn next_notice(self: &Arc<Self>) -> impl Future<Output = String> + Send;
106
107 fn bind(
108 self: Arc<Self>,
109 prepare_statement: Self::PreparedStatement,
110 params: Vec<Option<Bytes>>,
111 param_formats: Vec<Format>,
112 result_formats: Vec<Format>,
113 ) -> Result<Self::Portal, Self::Error>;
114
115 fn execute(
116 self: Arc<Self>,
117 portal: Self::Portal,
118 ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, Self::Error>> + Send;
119
120 fn describe_statement(
121 self: Arc<Self>,
122 prepare_statement: Self::PreparedStatement,
123 ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), Self::Error>;
124
125 fn describe_portal(
126 self: Arc<Self>,
127 portal: Self::Portal,
128 ) -> Result<Vec<PgFieldDescriptor>, Self::Error>;
129
130 fn user_authenticator(&self) -> &UserAuthenticator;
131
132 fn id(&self) -> SessionId;
133
134 fn get_config(&self, key: &str) -> Result<String, Self::Error>;
135
136 fn set_config(&self, key: &str, value: String) -> Result<String, Self::Error>;
137
138 fn transaction_status(&self) -> TransactionStatus;
139
140 fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard;
141
142 fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()>;
143}
144
145pub struct ExecContext {
148 pub running_sql: Arc<str>,
149 pub last_instant: Instant,
151 pub last_idle_instant: Arc<Mutex<Option<Instant>>>,
153}
154
155pub struct ExecContextGuard(#[allow(dead_code)] Arc<ExecContext>);
158
159impl ExecContextGuard {
160 pub fn new(exec_context: Arc<ExecContext>) -> Self {
161 Self(exec_context)
162 }
163}
164
165impl Drop for ExecContext {
166 fn drop(&mut self) {
167 *self.last_idle_instant.lock() = Some(Instant::now());
168 }
169}
170
171#[derive(Debug, Clone)]
172pub enum UserAuthenticator {
173 None,
175 ClearText(Vec<u8>),
177 Md5WithSalt {
179 encrypted_password: Vec<u8>,
180 salt: [u8; 4],
181 },
182 OAuth {
183 metadata: HashMap<String, String>,
184 cluster_id: String,
185 },
186 Ldap(String, HbaEntry),
187}
188
189#[derive(Debug, Deserialize)]
193struct Jwks {
194 keys: Vec<Jwk>,
195}
196
197#[derive(Debug, Deserialize)]
200struct Jwk {
201 kid: String, alg: String, n: String, e: String, }
206
207async fn validate_jwt(
208 jwt: &str,
209 jwks_url: &str,
210 issuer: &str,
211 cluster_id: &str,
212 metadata: &HashMap<String, String>,
213) -> Result<bool, BoxedError> {
214 let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;
215 validate_jwt_with_jwks(jwt, &jwks, issuer, cluster_id, metadata)
216}
217
/// Builds the audience string expected in a token for the cluster `cluster_id`,
/// i.e. `urn:risingwave:cluster:` followed by the id.
fn audience_from_cluster_id(cluster_id: &str) -> String {
    ["urn:risingwave:cluster:", cluster_id].concat()
}
221
222fn validate_jwt_with_jwks(
223 jwt: &str,
224 jwks: &Jwks,
225 issuer: &str,
226 cluster_id: &str,
227 metadata: &HashMap<String, String>,
228) -> Result<bool, BoxedError> {
229 let header = decode_header(jwt)?;
230
231 let kid = header.kid.ok_or("JWT header missing 'kid' field")?;
233 let jwk = jwks
234 .keys
235 .iter()
236 .find(|k| k.kid == kid)
237 .ok_or(format!("No matching key found in JWKS for kid: '{}'", kid))?;
238
239 if Algorithm::from_str(&jwk.alg)? != header.alg {
241 return Err("alg in jwt header does not match with alg in jwk".into());
242 }
243
244 let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
246 let mut validation = Validation::new(header.alg);
247 validation.set_issuer(&[issuer]);
248 validation.set_audience(&[audience_from_cluster_id(cluster_id)]); validation.set_required_spec_claims(&["exp", "iss", "aud"]);
250 let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;
251
252 if !metadata.iter().all(
254 |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
255 ) {
256 return Err("metadata in jwt does not match with metadata declared with user".into());
257 }
258 Ok(true)
259}
260
261impl UserAuthenticator {
262 pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> {
263 let success = match self {
264 UserAuthenticator::None => true,
265 UserAuthenticator::ClearText(text) => password == text,
266 UserAuthenticator::Md5WithSalt {
267 encrypted_password, ..
268 } => encrypted_password == password,
269 UserAuthenticator::OAuth {
270 metadata,
271 cluster_id,
272 } => {
273 let mut metadata = metadata.clone();
274 let jwks_url = metadata.remove("jwks_url").unwrap();
275 let issuer = metadata.remove("issuer").unwrap();
276 validate_jwt(
277 &String::from_utf8_lossy(password),
278 &jwks_url,
279 &issuer,
280 cluster_id,
281 &metadata,
282 )
283 .await
284 .map_err(PsqlError::StartupError)?
285 }
286 UserAuthenticator::Ldap(user_name, hba_entry) => {
287 let ldap_auth = LdapAuthenticator::new(hba_entry)?;
288 let password_str = String::from_utf8_lossy(password).into_owned();
290 ldap_auth.authenticate(user_name, &password_str).await?
291 }
292 };
293 if !success {
294 return Err(PsqlError::PasswordError);
295 }
296 Ok(())
297 }
298}
299
300pub async fn pg_serve(
304 addr: &str,
305 tcp_keepalive: TcpKeepalive,
306 session_mgr: Arc<impl SessionManager>,
307 context: ConnectionContext,
308 shutdown: CancellationToken,
309) -> Result<(), BoxedError> {
310 let listener = Listener::bind(addr).await?;
311 tracing::info!(addr, "server started");
312
313 let acceptor_runtime = BackgroundShutdownRuntime::from({
314 let mut builder = tokio::runtime::Builder::new_multi_thread();
315 builder.worker_threads(1);
316 builder
317 .thread_name("rw-acceptor")
318 .enable_all()
319 .build()
320 .unwrap()
321 });
322
323 #[cfg(not(madsim))]
324 let worker_runtime = tokio::runtime::Handle::current();
325 #[cfg(madsim)]
326 let worker_runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();
327 let session_mgr_clone = session_mgr.clone();
328 let f = async move {
329 loop {
330 let conn_ret = listener.accept(&tcp_keepalive).await;
331 match conn_ret {
332 Ok((stream, peer_addr)) => {
333 tracing::info!(%peer_addr, "accept connection");
334 worker_runtime.spawn(handle_connection(
335 stream,
336 session_mgr_clone.clone(),
337 Arc::new(peer_addr),
338 context.clone(),
339 ));
340 }
341
342 Err(e) => {
343 tracing::error!(error = %e.as_report(), "failed to accept connection",);
344 }
345 }
346 }
347 };
348 acceptor_runtime.spawn(f);
349
350 shutdown.cancelled().await;
352
353 drop(acceptor_runtime);
355 session_mgr.shutdown().await;
357
358 Ok(())
359}
360
361pub async fn handle_connection<S, SM>(
362 stream: S,
363 session_mgr: Arc<SM>,
364 peer_addr: AddressRef,
365 context: ConnectionContext,
366) where
367 S: PgByteStream,
368 SM: SessionManager,
369{
370 PgProtocol::new(stream, session_mgr, peer_addr, context)
371 .run()
372 .await;
373}
374#[cfg(test)]
375mod tests {
376 use std::sync::Arc;
377 use std::time::Instant;
378
379 use bytes::Bytes;
380 use futures::StreamExt;
381 use futures::stream::BoxStream;
382 use risingwave_common::id::DatabaseId;
383 use risingwave_common::types::DataType;
384 use risingwave_common::util::tokio_util::sync::CancellationToken;
385 use risingwave_sqlparser::ast::Statement;
386 use tokio_postgres::NoTls;
387
388 use crate::error::PsqlResult;
389 use crate::memory_manager::MessageMemoryManager;
390 use crate::pg_field_descriptor::PgFieldDescriptor;
391 use crate::pg_message::TransactionStatus;
392 use crate::pg_protocol::ConnectionContext;
393 use crate::pg_response::{PgResponse, RowSetResult, StatementType};
394 use crate::pg_server::{
395 BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
396 UserAuthenticator, pg_serve,
397 };
398 use crate::types;
399 use crate::types::Row;
400
/// Session manager stub used by the server tests; it carries no state.
struct MockSessionManager {}

/// Session stub handed out by [`MockSessionManager`]; it carries no state.
struct MockSession {}
403
404 impl SessionManager for MockSessionManager {
405 type Error = BoxedError;
406 type Session = MockSession;
407
408 fn create_dummy_session(
409 &self,
410 _database_id: DatabaseId,
411 _user_name: u32,
412 ) -> Result<Arc<Self::Session>, Self::Error> {
413 unimplemented!()
414 }
415
416 fn connect(
417 &self,
418 _database: &str,
419 _user_name: &str,
420 _peer_addr: crate::net::AddressRef,
421 ) -> Result<Arc<Self::Session>, Self::Error> {
422 Ok(Arc::new(MockSession {}))
423 }
424
425 fn cancel_queries_in_session(&self, _session_id: SessionId) {
426 todo!()
427 }
428
429 fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
430 todo!()
431 }
432
433 fn end_session(&self, _session: &Self::Session) {}
434 }
435
436 impl Session for MockSession {
437 type Error = BoxedError;
438 type Portal = String;
439 type PreparedStatement = String;
440 type ValuesStream = BoxStream<'static, RowSetResult>;
441
442 async fn run_one_query(
443 self: Arc<Self>,
444 _stmt: Statement,
445 _format: types::Format,
446 ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, Self::Error> {
447 Ok(PgResponse::builder(StatementType::SELECT)
448 .values(
449 futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
450 .boxed(),
451 vec![
452 PgFieldDescriptor::new("".to_owned(), 1043, -1);
455 1
456 ],
457 )
458 .into())
459 }
460
461 async fn parse(
462 self: Arc<Self>,
463 _sql: Option<Statement>,
464 _params_types: Vec<Option<DataType>>,
465 ) -> Result<String, Self::Error> {
466 Ok(String::new())
467 }
468
469 fn bind(
470 self: Arc<Self>,
471 _prepare_statement: String,
472 _params: Vec<Option<Bytes>>,
473 _param_formats: Vec<types::Format>,
474 _result_formats: Vec<types::Format>,
475 ) -> Result<String, Self::Error> {
476 Ok(String::new())
477 }
478
479 async fn execute(
480 self: Arc<Self>,
481 _portal: String,
482 ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, Self::Error> {
483 Ok(PgResponse::builder(StatementType::SELECT)
484 .values(
485 futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
486 .boxed(),
487 vec![
488 PgFieldDescriptor::new("".to_owned(), 1043, -1);
491 1
492 ],
493 )
494 .into())
495 }
496
497 fn describe_statement(
498 self: Arc<Self>,
499 _statement: String,
500 ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), Self::Error> {
501 Ok((
502 vec![],
503 vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)],
504 ))
505 }
506
507 fn describe_portal(
508 self: Arc<Self>,
509 _portal: String,
510 ) -> Result<Vec<PgFieldDescriptor>, Self::Error> {
511 Ok(vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)])
512 }
513
514 fn user_authenticator(&self) -> &UserAuthenticator {
515 &UserAuthenticator::None
516 }
517
518 fn id(&self) -> SessionId {
519 (0, 0)
520 }
521
522 fn get_config(&self, key: &str) -> Result<String, Self::Error> {
523 match key {
524 "timezone" => Ok("UTC".to_owned()),
525 _ => Err(format!("Unknown config key: {key}").into()),
526 }
527 }
528
529 fn set_config(&self, _key: &str, _value: String) -> Result<String, Self::Error> {
530 Ok("".to_owned())
531 }
532
533 async fn next_notice(self: &Arc<Self>) -> String {
534 std::future::pending().await
535 }
536
537 fn transaction_status(&self) -> TransactionStatus {
538 TransactionStatus::Idle
539 }
540
541 fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
542 let exec_context = Arc::new(ExecContext {
543 running_sql: sql,
544 last_instant: Instant::now(),
545 last_idle_instant: Default::default(),
546 });
547 ExecContextGuard::new(exec_context)
548 }
549
550 fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
551 Ok(())
552 }
553 }
554
555 async fn do_test_query(bind_addr: impl Into<String>, pg_config: impl Into<String>) {
556 let bind_addr = bind_addr.into();
557 let pg_config = pg_config.into();
558
559 let session_mgr = MockSessionManager {};
560 tokio::spawn(async move {
561 pg_serve(
562 &bind_addr,
563 socket2::TcpKeepalive::new(),
564 Arc::new(session_mgr),
565 ConnectionContext {
566 tls_config: None,
567 redact_sql_option_keywords: None,
568 message_memory_manager: MessageMemoryManager::new(u64::MAX, u64::MAX, u64::MAX)
569 .into(),
570 },
571 CancellationToken::new(), )
573 .await
574 });
575 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
577
578 let (client, connection) = tokio_postgres::connect(&pg_config, NoTls).await.unwrap();
580
581 tokio::spawn(async move {
584 if let Err(e) = connection.await {
585 eprintln!("connection error: {}", e);
586 }
587 });
588
589 let rows = client
590 .simple_query("SELECT ''")
591 .await
592 .expect("Error executing query");
593 assert_eq!(rows.len(), 2);
595
596 let rows = client
597 .query("SELECT ''", &[])
598 .await
599 .expect("Error executing query");
600 assert_eq!(rows.len(), 1);
601 }
602
603 #[tokio::test]
604 async fn test_query_tcp() {
605 do_test_query("127.0.0.1:10000", "host=localhost port=10000").await;
606 }
607
608 #[cfg(not(madsim))]
609 #[tokio::test]
610 async fn test_query_unix() {
611 let port: i16 = 10000;
612 let dir = tempfile::TempDir::new().unwrap();
613 let sock = dir.path().join(format!(".s.PGSQL.{port}"));
614
615 do_test_query(
616 format!("unix:{}", sock.to_str().unwrap()),
617 format!("host={} port={}", dir.path().to_str().unwrap(), port),
618 )
619 .await;
620 }
621
622 mod jwt_validation_tests {
623 use std::collections::HashMap;
624 use std::time::{SystemTime, UNIX_EPOCH};
625
626 use base64::Engine;
627 use jsonwebtoken::{Algorithm, EncodingKey, Header};
628 use rsa::pkcs1::EncodeRsaPrivateKey;
629 use rsa::traits::PublicKeyParts;
630 use rsa::{RsaPrivateKey, RsaPublicKey};
631 use serde_json::json;
632
633 use crate::pg_server::{Jwk, Jwks, validate_jwt_with_jwks};
634
635 fn create_test_rsa_keys() -> (RsaPrivateKey, RsaPublicKey) {
636 let mut rng = rand::thread_rng();
637 let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("failed to generate a key");
638 let public_key = RsaPublicKey::from(&private_key);
639 (private_key, public_key)
640 }
641
642 fn create_test_jwks(public_key: &RsaPublicKey, kid: &str, alg: &str) -> Jwks {
643 let n = base64::engine::general_purpose::URL_SAFE_NO_PAD
644 .encode(public_key.n().to_bytes_be());
645 let e = base64::engine::general_purpose::URL_SAFE_NO_PAD
646 .encode(public_key.e().to_bytes_be());
647
648 Jwks {
649 keys: vec![Jwk {
650 kid: kid.to_owned(),
651 alg: alg.to_owned(),
652 n,
653 e,
654 }],
655 }
656 }
657
658 fn create_jwt_token(
659 private_key: &RsaPrivateKey,
660 kid: &str,
661 algorithm: Algorithm,
662 issuer: &str,
663 audience: Option<&str>,
664 exp: u64,
665 additional_claims: HashMap<String, serde_json::Value>,
666 ) -> String {
667 let mut header = Header::new(algorithm);
668 header.kid = Some(kid.to_owned());
669
670 let mut claims = json!({
671 "iss": issuer,
672 "exp": exp,
673 });
674
675 if let Some(aud) = audience {
676 claims["aud"] = json!(aud);
677 }
678
679 for (key, value) in additional_claims {
680 claims[key] = value;
681 }
682
683 let encoding_key = EncodingKey::from_rsa_pem(
684 private_key
685 .to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
686 .unwrap()
687 .as_bytes(),
688 )
689 .unwrap();
690
691 jsonwebtoken::encode(&header, &claims, &encoding_key).unwrap()
692 }
693
/// Returns the Unix timestamp, in seconds, one hour after the current time.
fn get_future_timestamp() -> u64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    since_epoch.as_secs() + 3600
}
701
/// Returns the Unix timestamp, in seconds, one hour before the current time.
fn get_past_timestamp() -> u64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    since_epoch.as_secs() - 3600
}
709
/// A token issued for another cluster must be rejected with an audience error.
#[test]
fn test_jwt_with_invalid_audience() {
    const KID: &str = "test-kid";
    const ISSUER: &str = "https://test-issuer.com";

    let (signing_key, verifying_key) = create_test_rsa_keys();
    let key_set = create_test_jwks(&verifying_key, KID, "RS256");

    let token = create_jwt_token(
        &signing_key,
        KID,
        Algorithm::RS256,
        ISSUER,
        Some("urn:risingwave:cluster:wrong-cluster-id"),
        get_future_timestamp(),
        HashMap::new(),
    );

    let no_metadata = HashMap::new();
    let outcome = validate_jwt_with_jwks(&token, &key_set, ISSUER, "test-cluster-id", &no_metadata);

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("InvalidAudience"));
}
738
/// A token without any `aud` claim must be rejected as missing a required claim.
#[test]
fn test_jwt_with_missing_audience() {
    const KID: &str = "test-kid";
    const ISSUER: &str = "https://test-issuer.com";

    let (signing_key, verifying_key) = create_test_rsa_keys();
    let key_set = create_test_jwks(&verifying_key, KID, "RS256");

    // No audience is put into the token at all.
    let token = create_jwt_token(
        &signing_key,
        KID,
        Algorithm::RS256,
        ISSUER,
        None,
        get_future_timestamp(),
        HashMap::new(),
    );

    let no_metadata = HashMap::new();
    let outcome = validate_jwt_with_jwks(&token, &key_set, ISSUER, "test-cluster-id", &no_metadata);

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("Missing required claim: aud"));
}
767
/// A token from an unexpected issuer must be rejected with an issuer error.
#[test]
fn test_jwt_with_invalid_issuer() {
    const KID: &str = "test-kid";

    let (signing_key, verifying_key) = create_test_rsa_keys();
    let key_set = create_test_jwks(&verifying_key, KID, "RS256");

    let token = create_jwt_token(
        &signing_key,
        KID,
        Algorithm::RS256,
        "https://wrong-issuer.com",
        Some("urn:risingwave:cluster:test-cluster-id"),
        get_future_timestamp(),
        HashMap::new(),
    );

    let no_metadata = HashMap::new();
    let outcome = validate_jwt_with_jwks(
        &token,
        &key_set,
        "https://test-issuer.com",
        "test-cluster-id",
        &no_metadata,
    );

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("InvalidIssuer"));
}
796
/// A token whose `kid` is absent from the key set must be rejected before any
/// signature check, naming the missing key id.
#[test]
fn test_jwt_with_kid_not_found_in_jwks() {
    const ISSUER: &str = "https://test-issuer.com";

    let (signing_key, verifying_key) = create_test_rsa_keys();
    // The key set only knows a key id that differs from the one in the token.
    let key_set = create_test_jwks(&verifying_key, "different-kid", "RS256");

    let token = create_jwt_token(
        &signing_key,
        "missing-kid",
        Algorithm::RS256,
        ISSUER,
        Some("urn:risingwave:cluster:test-cluster-id"),
        get_future_timestamp(),
        HashMap::new(),
    );

    let no_metadata = HashMap::new();
    let outcome = validate_jwt_with_jwks(&token, &key_set, ISSUER, "test-cluster-id", &no_metadata);

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("No matching key found in JWKS for kid: 'missing-kid'"));
}
829
/// A token whose expiry lies in the past must be rejected as expired.
#[test]
fn test_jwt_with_expired_token() {
    const KID: &str = "test-kid";
    const ISSUER: &str = "https://test-issuer.com";

    let (signing_key, verifying_key) = create_test_rsa_keys();
    let key_set = create_test_jwks(&verifying_key, KID, "RS256");

    // Everything is valid except the expiry time.
    let token = create_jwt_token(
        &signing_key,
        KID,
        Algorithm::RS256,
        ISSUER,
        Some("urn:risingwave:cluster:test-cluster-id"),
        get_past_timestamp(),
        HashMap::new(),
    );

    let no_metadata = HashMap::new();
    let outcome = validate_jwt_with_jwks(&token, &key_set, ISSUER, "test-cluster-id", &no_metadata);

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("ExpiredSignature"));
}
858
/// A token signed with a key other than the published one must fail the signature
/// check.
#[test]
fn test_jwt_with_invalid_signature() {
    const KID: &str = "test-kid";
    const ISSUER: &str = "https://test-issuer.com";

    // Two unrelated key pairs: one is published, the other signs the token.
    let (_, published_key) = create_test_rsa_keys();
    let (unrelated_signing_key, _) = create_test_rsa_keys();
    let key_set = create_test_jwks(&published_key, KID, "RS256");

    let token = create_jwt_token(
        &unrelated_signing_key,
        KID,
        Algorithm::RS256,
        ISSUER,
        Some("urn:risingwave:cluster:test-cluster-id"),
        get_future_timestamp(),
        HashMap::new(),
    );

    let no_metadata = HashMap::new();
    let outcome = validate_jwt_with_jwks(&token, &key_set, ISSUER, "test-cluster-id", &no_metadata);

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("InvalidSignature"));
}
889
/// A token carrying every declared metadata entry is accepted; claims that are not
/// declared as metadata do not matter.
#[test]
fn test_metadata_validation_success() {
    const KID: &str = "test-kid";
    const ISSUER: &str = "https://test-issuer.com";

    let (signing_key, verifying_key) = create_test_rsa_keys();
    let key_set = create_test_jwks(&verifying_key, KID, "RS256");

    let declared_metadata: HashMap<String, String> = HashMap::from([
        ("role".to_owned(), "admin".to_owned()),
        ("department".to_owned(), "security".to_owned()),
    ]);

    // `extra_claim` has no counterpart in the declared metadata.
    let token_claims = HashMap::from([
        ("role".to_owned(), json!("admin")),
        ("department".to_owned(), json!("security")),
        ("extra_claim".to_owned(), json!("ignored")),
    ]);

    let token = create_jwt_token(
        &signing_key,
        KID,
        Algorithm::RS256,
        ISSUER,
        Some("urn:risingwave:cluster:test-cluster-id"),
        get_future_timestamp(),
        token_claims,
    );

    let outcome = validate_jwt_with_jwks(
        &token,
        &key_set,
        ISSUER,
        "test-cluster-id",
        &declared_metadata,
    );

    assert!(outcome.unwrap());
}
924
/// A token in which one declared metadata entry has a different value is rejected
/// with the metadata mismatch message.
#[test]
fn test_metadata_validation_failure() {
    const KID: &str = "test-kid";
    const ISSUER: &str = "https://test-issuer.com";

    let (signing_key, verifying_key) = create_test_rsa_keys();
    let key_set = create_test_jwks(&verifying_key, KID, "RS256");

    let declared_metadata: HashMap<String, String> = HashMap::from([
        ("role".to_owned(), "admin".to_owned()),
        ("department".to_owned(), "security".to_owned()),
    ]);

    // `role` disagrees with the declared metadata; `department` agrees.
    let token_claims = HashMap::from([
        ("role".to_owned(), json!("user")),
        ("department".to_owned(), json!("security")),
    ]);

    let token = create_jwt_token(
        &signing_key,
        KID,
        Algorithm::RS256,
        ISSUER,
        Some("urn:risingwave:cluster:test-cluster-id"),
        get_future_timestamp(),
        token_claims,
    );

    let outcome = validate_jwt_with_jwks(
        &token,
        &key_set,
        ISSUER,
        "test-cluster-id",
        &declared_metadata,
    );

    let error = outcome.unwrap_err();
    assert_eq!(
        error.to_string(),
        "metadata in jwt does not match with metadata declared with user"
    );
}
962 }
963}