pgwire/
pg_server.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::future::Future;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::time::Instant;
20
21use bytes::Bytes;
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
23use parking_lot::Mutex;
24use risingwave_common::config::HbaEntry;
25use risingwave_common::id::DatabaseId;
26use risingwave_common::types::DataType;
27use risingwave_common::util::runtime::BackgroundShutdownRuntime;
28use risingwave_common::util::tokio_util::sync::CancellationToken;
29use risingwave_sqlparser::ast::Statement;
30use serde::Deserialize;
31use thiserror_ext::AsReport;
32
33use crate::error::{PsqlError, PsqlResult};
34use crate::ldap_auth::LdapAuthenticator;
35use crate::net::{AddressRef, Listener, TcpKeepalive};
36use crate::pg_field_descriptor::PgFieldDescriptor;
37use crate::pg_message::TransactionStatus;
38use crate::pg_protocol::{ConnectionContext, PgByteStream, PgProtocol};
39use crate::pg_response::{PgResponse, ValuesStream};
40use crate::types::Format;
41
/// Boxed, thread-safe error type used throughout the pgwire server interfaces
/// (e.g. the bound on [`SessionManager::Error`] and the return type of [`pg_serve`]).
pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;
/// First component of a [`SessionId`].
type ProcessId = i32;
/// Second component of a [`SessionId`].
type SecretKey = i32;
/// Identifies a session as a `(process id, secret key)` pair; used, for example, to target
/// [`SessionManager::cancel_queries_in_session`].
pub type SessionId = (ProcessId, SecretKey);
46
/// The interface for a database system behind the pgwire protocol.
///
/// [`pg_serve`] hands one implementation to every connection handler and awaits
/// [`SessionManager::shutdown`] once the server stops. It can be mocked for testing purposes.
pub trait SessionManager: Send + Sync + 'static {
    /// Error returned by the fallible methods; must be convertible into a [`BoxedError`].
    type Error: Into<BoxedError>;
    /// Session type produced by this manager; it shares the manager's error type.
    type Session: Session<Error = Self::Error>;

    /// In the process of auto schema change, we need a dummy session to access
    /// catalog information in frontend and build a replace plan for the table.
    ///
    /// `database_id` selects the database the dummy session is bound to.
    fn create_dummy_session(
        &self,
        database_id: DatabaseId,
    ) -> Result<Arc<Self::Session>, Self::Error>;

    /// Opens a session for `user_name` on `database`, for a client connecting from
    /// `peer_addr`.
    fn connect(
        &self,
        database: &str,
        user_name: &str,
        peer_addr: AddressRef,
    ) -> Result<Arc<Self::Session>, Self::Error>;

    /// Requests cancellation of the queries running in the session identified by
    /// `session_id`.
    fn cancel_queries_in_session(&self, session_id: SessionId);

    /// Requests cancellation of the jobs still being created by the session identified by
    /// `session_id`.
    fn cancel_creating_jobs_in_session(&self, session_id: SessionId);

    /// Notifies the manager that `session` has ended (presumably so it can release any state
    /// kept for it — implementation-defined).
    fn end_session(&self, session: &Self::Session);

    /// Run some cleanup tasks before the server shutdown.
    ///
    /// [`pg_serve`] awaits this after it has stopped accepting new connections. The default
    /// implementation does nothing.
    fn shutdown(&self) -> impl Future<Output = ()> + Send {
        async {}
    }
}
78
/// A psql connection. Each connection binds with a database. Switching database requires
/// creating another connection.
///
/// Besides the simple query path ([`Session::run_one_query`]), a session supports the
/// extended query flow: [`Session::parse`] produces a prepared statement, [`Session::bind`]
/// turns it into a portal, and [`Session::execute`] runs the portal.
pub trait Session: Send + Sync {
    /// Error returned by the fallible methods; must be convertible into a [`BoxedError`].
    type Error: Into<BoxedError>;
    /// Stream of result rows carried by a [`PgResponse`].
    type ValuesStream: ValuesStream;
    /// Result of [`Session::parse`]; consumed by [`Session::bind`] and
    /// [`Session::describe_statement`].
    type PreparedStatement: Send + Clone + 'static;
    /// Result of [`Session::bind`]; consumed by [`Session::execute`] and
    /// [`Session::describe_portal`].
    type Portal: Send + Clone + std::fmt::Display + 'static;

    /// Runs a single parsed statement and returns its response; `format` presumably selects
    /// the result encoding (TODO confirm).
    ///
    /// Note: the SQL text cannot be obtained by unparsing the AST — there are problems when
    /// dealing with `CREATE VIEW`, see <https://github.com/risingwavelabs/risingwave/issues/6801>.
    fn run_one_query(
        self: Arc<Self>,
        stmt: Statement,
        format: Format,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, Self::Error>> + Send;

    /// Prepares `sql` with the given parameter types, producing a prepared statement.
    ///
    /// NOTE(review): `sql` is presumably `None` for an empty query string and `None` entries
    /// in `params_types` presumably mean "type not specified by the client" — confirm against
    /// the caller in `PgProtocol`.
    fn parse(
        self: Arc<Self>,
        sql: Option<Statement>,
        params_types: Vec<Option<DataType>>,
    ) -> impl Future<Output = Result<Self::PreparedStatement, Self::Error>> + Send;

    /// Receive the next notice message to send to the client.
    ///
    /// This function should be cancellation-safe.
    fn next_notice(self: &Arc<Self>) -> impl Future<Output = String> + Send;

    /// Binds parameter values to `prepare_statement`, producing a portal.
    ///
    /// `param_formats` describe the encoding of `params`; `result_formats` describe the
    /// requested encoding of result columns.
    fn bind(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
        params: Vec<Option<Bytes>>,
        param_formats: Vec<Format>,
        result_formats: Vec<Format>,
    ) -> Result<Self::Portal, Self::Error>;

    /// Executes a bound portal and returns its response.
    fn execute(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, Self::Error>> + Send;

    /// Describes a prepared statement.
    ///
    /// The first element of the pair is presumably the parameter types and the second the
    /// result column descriptors (TODO confirm).
    fn describe_statement(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
    ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), Self::Error>;

    /// Describes the result columns a bound portal will produce.
    fn describe_portal(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> Result<Vec<PgFieldDescriptor>, Self::Error>;

    /// Returns the authenticator used to check this session's user credentials.
    fn user_authenticator(&self) -> &UserAuthenticator;

    /// Returns the identifier of this session.
    fn id(&self) -> SessionId;

    /// Reads the session configuration variable `key`.
    fn get_config(&self, key: &str) -> Result<String, Self::Error>;

    /// Sets the session configuration variable `key` to `value`.
    ///
    /// NOTE(review): the meaning of the returned string is implementation-defined here —
    /// document it.
    fn set_config(&self, key: &str, value: String) -> Result<String, Self::Error>;

    /// Reports the session's current transaction status.
    fn transaction_status(&self) -> TransactionStatus;

    /// Starts tracking the execution of `sql`; the execution context lives as long as the
    /// returned guard.
    fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard;

    /// Checks the idle-in-transaction timeout, presumably returning an error once it has
    /// been exceeded.
    fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()>;

    /// Returns the name of the session's user.
    fn user(&self) -> String;
}
145
/// Each session could run different SQLs multiple times.
/// `ExecContext` represents the lifetime of a running SQL in the current session.
///
/// Dropping an `ExecContext` stores the drop time into `last_idle_instant` (see its `Drop`
/// impl), presumably marking the session as idle from that moment.
pub struct ExecContext {
    /// The SQL text being executed.
    pub running_sql: Arc<str>,
    /// The instant of the running SQL (presumably its start; the test mock sets it to
    /// `Instant::now()` when the context is created).
    pub last_instant: Instant,
    /// Shared slot that receives the current instant when this `ExecContext` is dropped.
    pub last_idle_instant: Arc<Mutex<Option<Instant>>>,
}
155
156/// `ExecContextGuard` holds a `Arc` pointer. Once `ExecContextGuard` is dropped,
157/// the inner `Arc<ExecContext>` should not be referred anymore, so that its `Weak` reference (used in `SessionImpl`) will be the same lifecycle of the running sql execution context.
158pub struct ExecContextGuard(#[allow(dead_code)] Arc<ExecContext>);
159
160impl ExecContextGuard {
161    pub fn new(exec_context: Arc<ExecContext>) -> Self {
162        Self(exec_context)
163    }
164}
165
166impl Drop for ExecContext {
167    fn drop(&mut self) {
168        *self.last_idle_instant.lock() = Some(Instant::now());
169    }
170}
171
/// How a user's credential is checked; see [`UserAuthenticator::authenticate`].
#[derive(Debug, Clone)]
pub enum UserAuthenticator {
    /// No need to authenticate.
    None,
    /// Raw password in clear-text form.
    ClearText(Vec<u8>),
    /// Password encrypted with random salt.
    Md5WithSalt {
        /// Value the client's response is compared against byte-for-byte.
        encrypted_password: Vec<u8>,
        /// The salt; not consulted by `authenticate` itself.
        salt: [u8; 4],
    },
    /// Token-based authentication: the client's password is a JWT.
    OAuth {
        /// Must contain `jwks_url` and `issuer`; every other entry must appear in the token
        /// as a string claim with the same value.
        metadata: HashMap<String, String>,
        /// Determines the required `aud` claim (`urn:risingwave:cluster:<cluster_id>`).
        cluster_id: String,
    },
    /// LDAP authentication of the given user name, configured by an HBA entry.
    Ldap(String, HbaEntry),
}
189
/// A JWK Set is a JSON object that represents a set of JWKs.
/// The JSON object MUST have a "keys" member, with its value being an array of JWKs.
/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-5> for more details.
#[derive(Debug, Deserialize)]
struct Jwks {
    /// The keys of the set; a token's `kid` header selects one of them.
    keys: Vec<Jwk>,
}
197
/// A JSON Web Key (JWK) is a JSON object that represents a cryptographic key.
/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-4> for more details.
///
/// Only the members needed to build an RSA verification key are deserialized; other members
/// are ignored.
///
/// NOTE(review): all four fields are required here, so a JWK Set containing any key without
/// `kid`, `alg`, `n` or `e` (e.g. a non-RSA key) fails to deserialize as a whole — confirm
/// that this is acceptable for the configured identity providers.
#[derive(Debug, Deserialize)]
struct Jwk {
    kid: String, // Key ID
    alg: String, // Algorithm
    n: String,   // Modulus
    e: String,   // Exponent
}
207
208async fn validate_jwt(
209    jwt: &str,
210    jwks_url: &str,
211    issuer: &str,
212    cluster_id: &str,
213    metadata: &HashMap<String, String>,
214) -> Result<bool, BoxedError> {
215    let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;
216    validate_jwt_with_jwks(jwt, &jwks, issuer, cluster_id, metadata)
217}
218
/// Builds the value the JWT `aud` claim must carry for the cluster `cluster_id`.
fn audience_from_cluster_id(cluster_id: &str) -> String {
    const AUDIENCE_PREFIX: &str = "urn:risingwave:cluster:";
    let mut audience = String::with_capacity(AUDIENCE_PREFIX.len() + cluster_id.len());
    audience.push_str(AUDIENCE_PREFIX);
    audience.push_str(cluster_id);
    audience
}
222
223fn validate_jwt_with_jwks(
224    jwt: &str,
225    jwks: &Jwks,
226    issuer: &str,
227    cluster_id: &str,
228    metadata: &HashMap<String, String>,
229) -> Result<bool, BoxedError> {
230    let header = decode_header(jwt)?;
231
232    // 1. Retrieve the kid from the header to find the right JWK in the JWK Set.
233    let kid = header.kid.ok_or("JWT header missing 'kid' field")?;
234    let jwk = jwks
235        .keys
236        .iter()
237        .find(|k| k.kid == kid)
238        .ok_or(format!("No matching key found in JWKS for kid: '{}'", kid))?;
239
240    // 2. Check if the algorithms are matched.
241    if Algorithm::from_str(&jwk.alg)? != header.alg {
242        return Err("alg in jwt header does not match with alg in jwk".into());
243    }
244
245    // 3. Decode the JWT and validate the claims.
246    let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
247    let mut validation = Validation::new(header.alg);
248    validation.set_issuer(&[issuer]);
249    validation.set_audience(&[audience_from_cluster_id(cluster_id)]); // JWT 'aud' claim must match cluster_id
250    validation.set_required_spec_claims(&["exp", "iss", "aud"]);
251    let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;
252
253    // 4. Check if the metadata in the token matches.
254    if !metadata.iter().all(
255        |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
256    ) {
257        return Err("metadata in jwt does not match with metadata declared with user".into());
258    }
259    Ok(true)
260}
261
262impl UserAuthenticator {
263    pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> {
264        let success = match self {
265            UserAuthenticator::None => true,
266            UserAuthenticator::ClearText(text) => password == text,
267            UserAuthenticator::Md5WithSalt {
268                encrypted_password, ..
269            } => encrypted_password == password,
270            UserAuthenticator::OAuth {
271                metadata,
272                cluster_id,
273            } => {
274                let mut metadata = metadata.clone();
275                let jwks_url = metadata.remove("jwks_url").unwrap();
276                let issuer = metadata.remove("issuer").unwrap();
277                validate_jwt(
278                    &String::from_utf8_lossy(password),
279                    &jwks_url,
280                    &issuer,
281                    cluster_id,
282                    &metadata,
283                )
284                .await
285                .map_err(PsqlError::StartupError)?
286            }
287            UserAuthenticator::Ldap(user_name, hba_entry) => {
288                let ldap_auth = LdapAuthenticator::new(hba_entry)?;
289                // Convert password to string, defaulting to empty if not valid UTF-8
290                let password_str = String::from_utf8_lossy(password).into_owned();
291                ldap_auth.authenticate(user_name, &password_str).await?
292            }
293        };
294        if !success {
295            return Err(PsqlError::PasswordError);
296        }
297        Ok(())
298    }
299}
300
/// Binds a TCP or Unix listener at `addr`. Spawns a task to serve every new connection.
///
/// Accepting runs on a dedicated single-worker runtime (threads named `rw-acceptor`); each
/// accepted connection is served by [`handle_connection`] on the worker runtime, which is the
/// caller's current Tokio runtime (or, under `madsim`, a separately built runtime).
///
/// Returns when the `shutdown` token is triggered: the acceptor runtime is dropped first to
/// stop accepting new connections, then [`SessionManager::shutdown`] is awaited.
///
/// # Errors
/// Returns an error if the listener cannot be bound.
///
/// # Panics
/// Panics if building a Tokio runtime fails.
pub async fn pg_serve(
    addr: &str,
    tcp_keepalive: TcpKeepalive,
    session_mgr: Arc<impl SessionManager>,
    context: ConnectionContext,
    shutdown: CancellationToken,
) -> Result<(), BoxedError> {
    let listener = Listener::bind(addr).await?;
    tracing::info!(addr, "server started");

    // Accepting happens on its own runtime so it can be dropped on shutdown while the
    // connection tasks (spawned on `worker_runtime`) keep running.
    let acceptor_runtime = BackgroundShutdownRuntime::from({
        let mut builder = tokio::runtime::Builder::new_multi_thread();
        builder.worker_threads(1);
        builder
            .thread_name("rw-acceptor")
            .enable_all()
            .build()
            .unwrap()
    });

    // Runtime on which connection handlers are spawned.
    #[cfg(not(madsim))]
    let worker_runtime = tokio::runtime::Handle::current();
    #[cfg(madsim)]
    let worker_runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();
    let session_mgr_clone = session_mgr.clone();
    let f = async move {
        loop {
            let conn_ret = listener.accept(&tcp_keepalive).await;
            match conn_ret {
                Ok((stream, peer_addr)) => {
                    tracing::info!(%peer_addr, "accept connection");
                    // The JoinHandle is dropped, so the connection task runs detached.
                    worker_runtime.spawn(handle_connection(
                        stream,
                        session_mgr_clone.clone(),
                        Arc::new(peer_addr),
                        context.clone(),
                    ));
                }

                Err(e) => {
                    // NOTE(review): the loop retries immediately after an accept error; a
                    // persistent failure (e.g. file-descriptor exhaustion) would spin and log
                    // continuously — consider backing off before retrying.
                    tracing::error!(error = %e.as_report(), "failed to accept connection",);
                }
            }
        }
    };
    acceptor_runtime.spawn(f);

    // Wait for the shutdown signal.
    shutdown.cancelled().await;

    // Stop accepting new connections.
    drop(acceptor_runtime);
    // Shutdown session manager, typically close all existing sessions.
    session_mgr.shutdown().await;

    Ok(())
}
361
362pub async fn handle_connection<S, SM>(
363    stream: S,
364    session_mgr: Arc<SM>,
365    peer_addr: AddressRef,
366    context: ConnectionContext,
367) where
368    S: PgByteStream,
369    SM: SessionManager,
370{
371    PgProtocol::new(stream, session_mgr, peer_addr, context)
372        .run()
373        .await;
374}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Instant;

    use bytes::Bytes;
    use futures::StreamExt;
    use futures::stream::BoxStream;
    use risingwave_common::id::DatabaseId;
    use risingwave_common::types::DataType;
    use risingwave_common::util::tokio_util::sync::CancellationToken;
    use risingwave_sqlparser::ast::Statement;
    use tokio_postgres::NoTls;

    use crate::error::PsqlResult;
    use crate::memory_manager::MessageMemoryManager;
    use crate::pg_field_descriptor::PgFieldDescriptor;
    use crate::pg_message::TransactionStatus;
    use crate::pg_protocol::ConnectionContext;
    use crate::pg_response::{PgResponse, RowSetResult, StatementType};
    use crate::pg_server::{
        BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
        UserAuthenticator, pg_serve,
    };
    use crate::types;
    use crate::types::Row;

    /// Session manager that hands out a fresh [`MockSession`] for every connection.
    struct MockSessionManager {}
    /// Session that answers every query with a single row holding one empty `varchar` value.
    struct MockSession {}

    impl SessionManager for MockSessionManager {
        type Error = BoxedError;
        type Session = MockSession;

        fn create_dummy_session(
            &self,
            _database_id: DatabaseId,
        ) -> Result<Arc<Self::Session>, Self::Error> {
            unimplemented!()
        }

        fn connect(
            &self,
            _database: &str,
            _user_name: &str,
            _peer_addr: crate::net::AddressRef,
        ) -> Result<Arc<Self::Session>, Self::Error> {
            Ok(Arc::new(MockSession {}))
        }

        fn cancel_queries_in_session(&self, _session_id: SessionId) {
            todo!()
        }

        fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
            todo!()
        }

        fn end_session(&self, _session: &Self::Session) {}
    }

    impl Session for MockSession {
        type Error = BoxedError;
        type Portal = String;
        type PreparedStatement = String;
        type ValuesStream = BoxStream<'static, RowSetResult>;

        async fn run_one_query(
            self: Arc<Self>,
            _stmt: Statement,
            _format: types::Format,
        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, Self::Error> {
            Ok(PgResponse::builder(StatementType::SELECT)
                .values(
                    futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
                        .boxed(),
                    vec![
                        // 1043 is the oid of varchar type.
                        // -1 is the type len of varchar type.
                        PgFieldDescriptor::new("".to_owned(), 1043, -1);
                        1
                    ],
                )
                .into())
        }

        async fn parse(
            self: Arc<Self>,
            _sql: Option<Statement>,
            _params_types: Vec<Option<DataType>>,
        ) -> Result<String, Self::Error> {
            Ok(String::new())
        }

        fn bind(
            self: Arc<Self>,
            _prepare_statement: String,
            _params: Vec<Option<Bytes>>,
            _param_formats: Vec<types::Format>,
            _result_formats: Vec<types::Format>,
        ) -> Result<String, Self::Error> {
            Ok(String::new())
        }

        async fn execute(
            self: Arc<Self>,
            _portal: String,
        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, Self::Error> {
            Ok(PgResponse::builder(StatementType::SELECT)
                .values(
                    futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
                        .boxed(),
                    vec![
                        // 1043 is the oid of varchar type.
                        // -1 is the type len of varchar type.
                        PgFieldDescriptor::new("".to_owned(), 1043, -1);
                        1
                    ],
                )
                .into())
        }

        fn describe_statement(
            self: Arc<Self>,
            _statement: String,
        ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), Self::Error> {
            Ok((
                vec![],
                vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)],
            ))
        }

        fn describe_portal(
            self: Arc<Self>,
            _portal: String,
        ) -> Result<Vec<PgFieldDescriptor>, Self::Error> {
            Ok(vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)])
        }

        fn user_authenticator(&self) -> &UserAuthenticator {
            &UserAuthenticator::None
        }

        fn id(&self) -> SessionId {
            (0, 0)
        }

        fn get_config(&self, key: &str) -> Result<String, Self::Error> {
            match key {
                "timezone" => Ok("UTC".to_owned()),
                _ => Err(format!("Unknown config key: {key}").into()),
            }
        }

        fn set_config(&self, _key: &str, _value: String) -> Result<String, Self::Error> {
            Ok("".to_owned())
        }

        async fn next_notice(self: &Arc<Self>) -> String {
            // The mock never emits notices.
            std::future::pending().await
        }

        fn transaction_status(&self) -> TransactionStatus {
            TransactionStatus::Idle
        }

        fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
            let exec_context = Arc::new(ExecContext {
                running_sql: sql,
                last_instant: Instant::now(),
                last_idle_instant: Default::default(),
            });
            ExecContextGuard::new(exec_context)
        }

        fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
            Ok(())
        }

        fn user(&self) -> String {
            "mock".to_owned()
        }
    }

    /// Starts a mock server on `bind_addr`, connects to it with `pg_config`, and checks that
    /// both the simple and the extended query protocol return the mock's single row.
    async fn do_test_query(bind_addr: impl Into<String>, pg_config: impl Into<String>) {
        let bind_addr = bind_addr.into();
        let pg_config = pg_config.into();

        let session_mgr = MockSessionManager {};
        tokio::spawn(async move {
            pg_serve(
                &bind_addr,
                socket2::TcpKeepalive::new(),
                Arc::new(session_mgr),
                ConnectionContext {
                    tls_config: None,
                    redact_sql_option_keywords: None,
                    message_memory_manager: MessageMemoryManager::new(u64::MAX, u64::MAX, u64::MAX)
                        .into(),
                },
                // Never cancelled: the server runs until the test runtime shuts down.
                CancellationToken::new(),
            )
            .await
        });
        // Wait for the server to start. A fixed delay is simple but may be flaky on a heavily
        // loaded machine.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        // Connect to the database.
        let (client, connection) = tokio_postgres::connect(&pg_config, NoTls).await.unwrap();

        // The connection object performs the actual communication with the database,
        // so spawn it off to run on its own.
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                eprintln!("connection error: {e}");
            }
        });

        let rows = client
            .simple_query("SELECT ''")
            .await
            .expect("Error executing query");
        // Row + CommandComplete
        assert_eq!(rows.len(), 2);

        let rows = client
            .query("SELECT ''", &[])
            .await
            .expect("Error executing query");
        assert_eq!(rows.len(), 1);
    }

    #[tokio::test]
    async fn test_query_tcp() {
        do_test_query("127.0.0.1:10000", "host=localhost port=10000").await;
    }

    #[cfg(not(madsim))]
    #[tokio::test]
    async fn test_query_unix() {
        // Ports are unsigned 16-bit values; `i16` could not represent ports above 32767.
        let port: u16 = 10000;
        let dir = tempfile::TempDir::new().unwrap();
        let sock = dir.path().join(format!(".s.PGSQL.{port}"));

        do_test_query(
            format!("unix:{}", sock.to_str().unwrap()),
            format!("host={} port={}", dir.path().to_str().unwrap(), port),
        )
        .await;
    }

    mod jwt_validation_tests {
        use std::collections::HashMap;
        use std::time::{SystemTime, UNIX_EPOCH};

        use base64::Engine;
        use jsonwebtoken::{Algorithm, EncodingKey, Header};
        use rsa::pkcs1::EncodeRsaPrivateKey;
        use rsa::traits::PublicKeyParts;
        use rsa::{RsaPrivateKey, RsaPublicKey};
        use serde_json::json;

        use crate::pg_server::{Jwk, Jwks, validate_jwt_with_jwks};

        /// Generates a fresh 2048-bit RSA key pair.
        fn create_test_rsa_keys() -> (RsaPrivateKey, RsaPublicKey) {
            let mut rng = rand::thread_rng();
            let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("failed to generate a key");
            let public_key = RsaPublicKey::from(&private_key);
            (private_key, public_key)
        }

        /// Builds a JWK Set containing only `public_key`, published under `kid` with `alg`.
        fn create_test_jwks(public_key: &RsaPublicKey, kid: &str, alg: &str) -> Jwks {
            let n = base64::engine::general_purpose::URL_SAFE_NO_PAD
                .encode(public_key.n().to_bytes_be());
            let e = base64::engine::general_purpose::URL_SAFE_NO_PAD
                .encode(public_key.e().to_bytes_be());

            Jwks {
                keys: vec![Jwk {
                    kid: kid.to_owned(),
                    alg: alg.to_owned(),
                    n,
                    e,
                }],
            }
        }

        /// Signs a token with `iss`, `exp`, the optional `aud`, and `additional_claims`.
        fn create_jwt_token(
            private_key: &RsaPrivateKey,
            kid: &str,
            algorithm: Algorithm,
            issuer: &str,
            audience: Option<&str>,
            exp: u64,
            additional_claims: HashMap<String, serde_json::Value>,
        ) -> String {
            let mut header = Header::new(algorithm);
            header.kid = Some(kid.to_owned());

            let mut claims = json!({
                "iss": issuer,
                "exp": exp,
            });

            if let Some(aud) = audience {
                claims["aud"] = json!(aud);
            }

            for (key, value) in additional_claims {
                claims[key] = value;
            }

            let encoding_key = EncodingKey::from_rsa_pem(
                private_key
                    .to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
                    .unwrap()
                    .as_bytes(),
            )
            .unwrap();

            jsonwebtoken::encode(&header, &claims, &encoding_key).unwrap()
        }

        fn get_future_timestamp() -> u64 {
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs()
                + 3600 // 1 hour from now
        }

        fn get_past_timestamp() -> u64 {
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs()
                - 3600 // 1 hour ago
        }

        #[test]
        fn test_jwt_with_invalid_audience() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let metadata = HashMap::new();

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:wrong-cluster-id"),
                get_future_timestamp(),
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(error.to_string().contains("InvalidAudience"));
        }

        #[test]
        fn test_jwt_with_missing_audience() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let metadata = HashMap::new();

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                None, // No audience claim
                get_future_timestamp(),
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(error.to_string().contains("Missing required claim: aud"));
        }

        #[test]
        fn test_jwt_with_invalid_issuer() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let metadata = HashMap::new();

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://wrong-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_future_timestamp(),
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(error.to_string().contains("InvalidIssuer"));
        }

        #[test]
        fn test_jwt_with_kid_not_found_in_jwks() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "different-kid", "RS256");

            let metadata = HashMap::new();

            let jwt = create_jwt_token(
                &private_key,
                "missing-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_future_timestamp(),
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(
                error
                    .to_string()
                    .contains("No matching key found in JWKS for kid: 'missing-kid'")
            );
        }

        #[test]
        fn test_jwt_with_mismatched_algorithm() {
            let (private_key, public_key) = create_test_rsa_keys();
            // The key is published for RS384, but the token claims RS256.
            let jwks = create_test_jwks(&public_key, "test-kid", "RS384");

            let metadata = HashMap::new();

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_future_timestamp(),
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert_eq!(
                error.to_string(),
                "alg in jwt header does not match with alg in jwk"
            );
        }

        #[test]
        fn test_jwt_with_expired_token() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let metadata = HashMap::new();

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_past_timestamp(), // Expired token
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(error.to_string().contains("ExpiredSignature"));
        }

        #[test]
        fn test_jwt_with_invalid_signature() {
            let (_, public_key) = create_test_rsa_keys();
            let (wrong_private_key, _) = create_test_rsa_keys(); // Different key pair
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let metadata = HashMap::new();

            // Sign with wrong private key
            let jwt = create_jwt_token(
                &wrong_private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_future_timestamp(),
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(error.to_string().contains("InvalidSignature"));
        }

        #[test]
        fn test_metadata_validation_success() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let mut metadata = HashMap::new();
            metadata.insert("role".to_owned(), "admin".to_owned());
            metadata.insert("department".to_owned(), "security".to_owned());

            let mut claims = HashMap::new();
            claims.insert("role".to_owned(), json!("admin"));
            claims.insert("department".to_owned(), json!("security"));
            claims.insert("extra_claim".to_owned(), json!("ignored")); // Extra claims are fine

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_future_timestamp(),
                claims,
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            assert!(result.unwrap());
        }

        #[test]
        fn test_metadata_validation_failure() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let mut metadata = HashMap::new();
            metadata.insert("role".to_owned(), "admin".to_owned());
            metadata.insert("department".to_owned(), "security".to_owned());

            let mut claims = HashMap::new();
            claims.insert("role".to_owned(), json!("user")); // Wrong role
            claims.insert("department".to_owned(), json!("security"));

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_future_timestamp(),
                claims,
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert_eq!(
                error.to_string(),
                "metadata in jwt does not match with metadata declared with user"
            );
        }
    }
}