pgwire/
pg_server.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::future::Future;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::time::Instant;
20
21use bytes::Bytes;
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
23use parking_lot::Mutex;
24use risingwave_common::config::HbaEntry;
25use risingwave_common::id::DatabaseId;
26use risingwave_common::types::DataType;
27use risingwave_common::util::runtime::BackgroundShutdownRuntime;
28use risingwave_common::util::tokio_util::sync::CancellationToken;
29use risingwave_sqlparser::ast::Statement;
30use serde::Deserialize;
31use thiserror_ext::AsReport;
32
33use crate::error::{PsqlError, PsqlResult};
34use crate::ldap_auth::LdapAuthenticator;
35use crate::net::{AddressRef, Listener, TcpKeepalive};
36use crate::pg_field_descriptor::PgFieldDescriptor;
37use crate::pg_message::TransactionStatus;
38use crate::pg_protocol::{ConnectionContext, PgByteStream, PgProtocol};
39use crate::pg_response::{PgResponse, ValuesStream};
40use crate::types::Format;
41
/// Type-erased, thread-safe error used by the pgwire traits and helpers in this module.
pub type BoxedError = Box<dyn std::error::Error + Sync + Send>;
/// Process id half of a [`SessionId`].
type ProcessId = i32;
/// Secret key half of a [`SessionId`].
type SecretKey = i32;
/// Identifies a session as a `(process id, secret key)` pair; this is the value handed to
/// `SessionManager::cancel_queries_in_session` and `cancel_creating_jobs_in_session`.
pub type SessionId = (ProcessId, SecretKey);
46
47/// The interface for a database system behind pgwire protocol.
48/// We can mock it for testing purpose.
49pub trait SessionManager: Send + Sync + 'static {
50    type Error: Into<BoxedError>;
51    type Session: Session<Error = Self::Error>;
52
53    /// In the process of auto schema change, we need a dummy session to access
54    /// catalog information in frontend and build a replace plan for the table.
55    fn create_dummy_session(
56        &self,
57        database_id: DatabaseId,
58    ) -> Result<Arc<Self::Session>, Self::Error>;
59
60    fn connect(
61        &self,
62        database: &str,
63        user_name: &str,
64        peer_addr: AddressRef,
65    ) -> Result<Arc<Self::Session>, Self::Error>;
66
67    fn cancel_queries_in_session(&self, session_id: SessionId);
68
69    fn cancel_creating_jobs_in_session(&self, session_id: SessionId);
70
71    fn end_session(&self, session: &Self::Session);
72
73    /// Run some cleanup tasks before the server shutdown.
74    fn shutdown(&self) -> impl Future<Output = ()> + Send {
75        async {}
76    }
77}
78
79/// A psql connection. Each connection binds with a database. Switching database will need to
80/// recreate another connection.
81pub trait Session: Send + Sync {
82    type Error: Into<BoxedError>;
83    type ValuesStream: ValuesStream;
84    type PreparedStatement: Send + Clone + 'static;
85    type Portal: Send + Clone + std::fmt::Display + 'static;
86
87    /// The str sql can not use the unparse from AST: There is some problem when dealing with create
88    /// view, see <https://github.com/risingwavelabs/risingwave/issues/6801>.
89    fn run_one_query(
90        self: Arc<Self>,
91        stmt: Statement,
92        format: Format,
93    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, Self::Error>> + Send;
94
95    fn parse(
96        self: Arc<Self>,
97        sql: Option<Statement>,
98        params_types: Vec<Option<DataType>>,
99    ) -> impl Future<Output = Result<Self::PreparedStatement, Self::Error>> + Send;
100
101    /// Receive the next notice message to send to the client.
102    ///
103    /// This function should be cancellation-safe.
104    fn next_notice(self: &Arc<Self>) -> impl Future<Output = String> + Send;
105
106    fn bind(
107        self: Arc<Self>,
108        prepare_statement: Self::PreparedStatement,
109        params: Vec<Option<Bytes>>,
110        param_formats: Vec<Format>,
111        result_formats: Vec<Format>,
112    ) -> Result<Self::Portal, Self::Error>;
113
114    fn execute(
115        self: Arc<Self>,
116        portal: Self::Portal,
117    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, Self::Error>> + Send;
118
119    fn describe_statement(
120        self: Arc<Self>,
121        prepare_statement: Self::PreparedStatement,
122    ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), Self::Error>;
123
124    fn describe_portal(
125        self: Arc<Self>,
126        portal: Self::Portal,
127    ) -> Result<Vec<PgFieldDescriptor>, Self::Error>;
128
129    fn user_authenticator(&self) -> &UserAuthenticator;
130
131    fn id(&self) -> SessionId;
132
133    fn get_config(&self, key: &str) -> Result<String, Self::Error>;
134
135    fn set_config(&self, key: &str, value: String) -> Result<String, Self::Error>;
136
137    fn transaction_status(&self) -> TransactionStatus;
138
139    fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard;
140
141    fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()>;
142}
143
144/// Each session could run different SQLs multiple times.
145/// `ExecContext` represents the lifetime of a running SQL in the current session.
146pub struct ExecContext {
147    pub running_sql: Arc<str>,
148    /// The instant of the running sql
149    pub last_instant: Instant,
150    /// A reference used to update when `ExecContext` is dropped
151    pub last_idle_instant: Arc<Mutex<Option<Instant>>>,
152}
153
154/// `ExecContextGuard` holds a `Arc` pointer. Once `ExecContextGuard` is dropped,
155/// the inner `Arc<ExecContext>` should not be referred anymore, so that its `Weak` reference (used in `SessionImpl`) will be the same lifecycle of the running sql execution context.
156pub struct ExecContextGuard(#[allow(dead_code)] Arc<ExecContext>);
157
158impl ExecContextGuard {
159    pub fn new(exec_context: Arc<ExecContext>) -> Self {
160        Self(exec_context)
161    }
162}
163
164impl Drop for ExecContext {
165    fn drop(&mut self) {
166        *self.last_idle_instant.lock() = Some(Instant::now());
167    }
168}
169
170#[derive(Debug, Clone)]
171pub enum UserAuthenticator {
172    // No need to authenticate.
173    None,
174    // raw password in clear-text form.
175    ClearText(Vec<u8>),
176    // password encrypted with random salt.
177    Md5WithSalt {
178        encrypted_password: Vec<u8>,
179        salt: [u8; 4],
180    },
181    OAuth {
182        metadata: HashMap<String, String>,
183        cluster_id: String,
184    },
185    Ldap(String, HbaEntry),
186}
187
188/// A JWK Set is a JSON object that represents a set of JWKs.
189/// The JSON object MUST have a "keys" member, with its value being an array of JWKs.
190/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-5> for more details.
191#[derive(Debug, Deserialize)]
192struct Jwks {
193    keys: Vec<Jwk>,
194}
195
196/// A JSON Web Key (JWK) is a JSON object that represents a cryptographic key.
197/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-4> for more details.
198#[derive(Debug, Deserialize)]
199struct Jwk {
200    kid: String, // Key ID
201    alg: String, // Algorithm
202    n: String,   // Modulus
203    e: String,   // Exponent
204}
205
206async fn validate_jwt(
207    jwt: &str,
208    jwks_url: &str,
209    issuer: &str,
210    cluster_id: &str,
211    metadata: &HashMap<String, String>,
212) -> Result<bool, BoxedError> {
213    let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;
214    validate_jwt_with_jwks(jwt, &jwks, issuer, cluster_id, metadata)
215}
216
/// Builds the JWT audience (`aud` claim) expected for the cluster identified by `cluster_id`.
fn audience_from_cluster_id(cluster_id: &str) -> String {
    ["urn:risingwave:cluster:", cluster_id].concat()
}
220
221fn validate_jwt_with_jwks(
222    jwt: &str,
223    jwks: &Jwks,
224    issuer: &str,
225    cluster_id: &str,
226    metadata: &HashMap<String, String>,
227) -> Result<bool, BoxedError> {
228    let header = decode_header(jwt)?;
229
230    // 1. Retrieve the kid from the header to find the right JWK in the JWK Set.
231    let kid = header.kid.ok_or("JWT header missing 'kid' field")?;
232    let jwk = jwks
233        .keys
234        .iter()
235        .find(|k| k.kid == kid)
236        .ok_or(format!("No matching key found in JWKS for kid: '{}'", kid))?;
237
238    // 2. Check if the algorithms are matched.
239    if Algorithm::from_str(&jwk.alg)? != header.alg {
240        return Err("alg in jwt header does not match with alg in jwk".into());
241    }
242
243    // 3. Decode the JWT and validate the claims.
244    let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
245    let mut validation = Validation::new(header.alg);
246    validation.set_issuer(&[issuer]);
247    validation.set_audience(&[audience_from_cluster_id(cluster_id)]); // JWT 'aud' claim must match cluster_id
248    validation.set_required_spec_claims(&["exp", "iss", "aud"]);
249    let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;
250
251    // 4. Check if the metadata in the token matches.
252    if !metadata.iter().all(
253        |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
254    ) {
255        return Err("metadata in jwt does not match with metadata declared with user".into());
256    }
257    Ok(true)
258}
259
260impl UserAuthenticator {
261    pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> {
262        let success = match self {
263            UserAuthenticator::None => true,
264            UserAuthenticator::ClearText(text) => password == text,
265            UserAuthenticator::Md5WithSalt {
266                encrypted_password, ..
267            } => encrypted_password == password,
268            UserAuthenticator::OAuth {
269                metadata,
270                cluster_id,
271            } => {
272                let mut metadata = metadata.clone();
273                let jwks_url = metadata.remove("jwks_url").unwrap();
274                let issuer = metadata.remove("issuer").unwrap();
275                validate_jwt(
276                    &String::from_utf8_lossy(password),
277                    &jwks_url,
278                    &issuer,
279                    cluster_id,
280                    &metadata,
281                )
282                .await
283                .map_err(PsqlError::StartupError)?
284            }
285            UserAuthenticator::Ldap(user_name, hba_entry) => {
286                let ldap_auth = LdapAuthenticator::new(hba_entry)?;
287                // Convert password to string, defaulting to empty if not valid UTF-8
288                let password_str = String::from_utf8_lossy(password).into_owned();
289                ldap_auth.authenticate(user_name, &password_str).await?
290            }
291        };
292        if !success {
293            return Err(PsqlError::PasswordError);
294        }
295        Ok(())
296    }
297}
298
299/// Binds a Tcp or Unix listener at `addr`. Spawn a coroutine to serve every new connection.
300///
301/// Returns when the `shutdown` token is triggered.
302pub async fn pg_serve(
303    addr: &str,
304    tcp_keepalive: TcpKeepalive,
305    session_mgr: Arc<impl SessionManager>,
306    context: ConnectionContext,
307    shutdown: CancellationToken,
308) -> Result<(), BoxedError> {
309    let listener = Listener::bind(addr).await?;
310    tracing::info!(addr, "server started");
311
312    let acceptor_runtime = BackgroundShutdownRuntime::from({
313        let mut builder = tokio::runtime::Builder::new_multi_thread();
314        builder.worker_threads(1);
315        builder
316            .thread_name("rw-acceptor")
317            .enable_all()
318            .build()
319            .unwrap()
320    });
321
322    #[cfg(not(madsim))]
323    let worker_runtime = tokio::runtime::Handle::current();
324    #[cfg(madsim)]
325    let worker_runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();
326    let session_mgr_clone = session_mgr.clone();
327    let f = async move {
328        loop {
329            let conn_ret = listener.accept(&tcp_keepalive).await;
330            match conn_ret {
331                Ok((stream, peer_addr)) => {
332                    tracing::info!(%peer_addr, "accept connection");
333                    worker_runtime.spawn(handle_connection(
334                        stream,
335                        session_mgr_clone.clone(),
336                        Arc::new(peer_addr),
337                        context.clone(),
338                    ));
339                }
340
341                Err(e) => {
342                    tracing::error!(error = %e.as_report(), "failed to accept connection",);
343                }
344            }
345        }
346    };
347    acceptor_runtime.spawn(f);
348
349    // Wait for the shutdown signal.
350    shutdown.cancelled().await;
351
352    // Stop accepting new connections.
353    drop(acceptor_runtime);
354    // Shutdown session manager, typically close all existing sessions.
355    session_mgr.shutdown().await;
356
357    Ok(())
358}
359
360pub async fn handle_connection<S, SM>(
361    stream: S,
362    session_mgr: Arc<SM>,
363    peer_addr: AddressRef,
364    context: ConnectionContext,
365) where
366    S: PgByteStream,
367    SM: SessionManager,
368{
369    PgProtocol::new(stream, session_mgr, peer_addr, context)
370        .run()
371        .await;
372}
373#[cfg(test)]
374mod tests {
375    use std::sync::Arc;
376    use std::time::Instant;
377
378    use bytes::Bytes;
379    use futures::StreamExt;
380    use futures::stream::BoxStream;
381    use risingwave_common::id::DatabaseId;
382    use risingwave_common::types::DataType;
383    use risingwave_common::util::tokio_util::sync::CancellationToken;
384    use risingwave_sqlparser::ast::Statement;
385    use tokio_postgres::NoTls;
386
387    use crate::error::PsqlResult;
388    use crate::memory_manager::MessageMemoryManager;
389    use crate::pg_field_descriptor::PgFieldDescriptor;
390    use crate::pg_message::TransactionStatus;
391    use crate::pg_protocol::ConnectionContext;
392    use crate::pg_response::{PgResponse, RowSetResult, StatementType};
393    use crate::pg_server::{
394        BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
395        UserAuthenticator, pg_serve,
396    };
397    use crate::types;
398    use crate::types::Row;
399
/// Session manager used by the tests of this module; it hands out `MockSession`s.
struct MockSessionManager;
/// Session that answers queries with a single row holding one empty varchar value.
struct MockSession;
402
403    impl SessionManager for MockSessionManager {
404        type Error = BoxedError;
405        type Session = MockSession;
406
407        fn create_dummy_session(
408            &self,
409            _database_id: DatabaseId,
410        ) -> Result<Arc<Self::Session>, Self::Error> {
411            unimplemented!()
412        }
413
414        fn connect(
415            &self,
416            _database: &str,
417            _user_name: &str,
418            _peer_addr: crate::net::AddressRef,
419        ) -> Result<Arc<Self::Session>, Self::Error> {
420            Ok(Arc::new(MockSession {}))
421        }
422
423        fn cancel_queries_in_session(&self, _session_id: SessionId) {
424            todo!()
425        }
426
427        fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
428            todo!()
429        }
430
431        fn end_session(&self, _session: &Self::Session) {}
432    }
433
434    impl Session for MockSession {
435        type Error = BoxedError;
436        type Portal = String;
437        type PreparedStatement = String;
438        type ValuesStream = BoxStream<'static, RowSetResult>;
439
440        async fn run_one_query(
441            self: Arc<Self>,
442            _stmt: Statement,
443            _format: types::Format,
444        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, Self::Error> {
445            Ok(PgResponse::builder(StatementType::SELECT)
446                .values(
447                    futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
448                        .boxed(),
449                    vec![
450                        // 1043 is the oid of varchar type.
451                        // -1 is the type len of varchar type.
452                        PgFieldDescriptor::new("".to_owned(), 1043, -1);
453                        1
454                    ],
455                )
456                .into())
457        }
458
459        async fn parse(
460            self: Arc<Self>,
461            _sql: Option<Statement>,
462            _params_types: Vec<Option<DataType>>,
463        ) -> Result<String, Self::Error> {
464            Ok(String::new())
465        }
466
467        fn bind(
468            self: Arc<Self>,
469            _prepare_statement: String,
470            _params: Vec<Option<Bytes>>,
471            _param_formats: Vec<types::Format>,
472            _result_formats: Vec<types::Format>,
473        ) -> Result<String, Self::Error> {
474            Ok(String::new())
475        }
476
477        async fn execute(
478            self: Arc<Self>,
479            _portal: String,
480        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, Self::Error> {
481            Ok(PgResponse::builder(StatementType::SELECT)
482                .values(
483                    futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
484                        .boxed(),
485                    vec![
486                    // 1043 is the oid of varchar type.
487                    // -1 is the type len of varchar type.
488                    PgFieldDescriptor::new("".to_owned(), 1043, -1);
489                    1
490                ],
491                )
492                .into())
493        }
494
495        fn describe_statement(
496            self: Arc<Self>,
497            _statement: String,
498        ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), Self::Error> {
499            Ok((
500                vec![],
501                vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)],
502            ))
503        }
504
505        fn describe_portal(
506            self: Arc<Self>,
507            _portal: String,
508        ) -> Result<Vec<PgFieldDescriptor>, Self::Error> {
509            Ok(vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)])
510        }
511
512        fn user_authenticator(&self) -> &UserAuthenticator {
513            &UserAuthenticator::None
514        }
515
516        fn id(&self) -> SessionId {
517            (0, 0)
518        }
519
520        fn get_config(&self, key: &str) -> Result<String, Self::Error> {
521            match key {
522                "timezone" => Ok("UTC".to_owned()),
523                _ => Err(format!("Unknown config key: {key}").into()),
524            }
525        }
526
527        fn set_config(&self, _key: &str, _value: String) -> Result<String, Self::Error> {
528            Ok("".to_owned())
529        }
530
531        async fn next_notice(self: &Arc<Self>) -> String {
532            std::future::pending().await
533        }
534
535        fn transaction_status(&self) -> TransactionStatus {
536            TransactionStatus::Idle
537        }
538
539        fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
540            let exec_context = Arc::new(ExecContext {
541                running_sql: sql,
542                last_instant: Instant::now(),
543                last_idle_instant: Default::default(),
544            });
545            ExecContextGuard::new(exec_context)
546        }
547
548        fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
549            Ok(())
550        }
551    }
552
553    async fn do_test_query(bind_addr: impl Into<String>, pg_config: impl Into<String>) {
554        let bind_addr = bind_addr.into();
555        let pg_config = pg_config.into();
556
557        let session_mgr = MockSessionManager {};
558        tokio::spawn(async move {
559            pg_serve(
560                &bind_addr,
561                socket2::TcpKeepalive::new(),
562                Arc::new(session_mgr),
563                ConnectionContext {
564                    tls_config: None,
565                    redact_sql_option_keywords: None,
566                    message_memory_manager: MessageMemoryManager::new(u64::MAX, u64::MAX, u64::MAX)
567                        .into(),
568                },
569                CancellationToken::new(), // dummy
570            )
571            .await
572        });
573        // wait for server to start
574        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
575
576        // Connect to the database.
577        let (client, connection) = tokio_postgres::connect(&pg_config, NoTls).await.unwrap();
578
579        // The connection object performs the actual communication with the database,
580        // so spawn it off to run on its own.
581        tokio::spawn(async move {
582            if let Err(e) = connection.await {
583                eprintln!("connection error: {}", e);
584            }
585        });
586
587        let rows = client
588            .simple_query("SELECT ''")
589            .await
590            .expect("Error executing query");
591        // Row + CommandComplete
592        assert_eq!(rows.len(), 2);
593
594        let rows = client
595            .query("SELECT ''", &[])
596            .await
597            .expect("Error executing query");
598        assert_eq!(rows.len(), 1);
599    }
600
601    #[tokio::test]
602    async fn test_query_tcp() {
603        do_test_query("127.0.0.1:10000", "host=localhost port=10000").await;
604    }
605
606    #[cfg(not(madsim))]
607    #[tokio::test]
608    async fn test_query_unix() {
609        let port: i16 = 10000;
610        let dir = tempfile::TempDir::new().unwrap();
611        let sock = dir.path().join(format!(".s.PGSQL.{port}"));
612
613        do_test_query(
614            format!("unix:{}", sock.to_str().unwrap()),
615            format!("host={} port={}", dir.path().to_str().unwrap(), port),
616        )
617        .await;
618    }
619
620    mod jwt_validation_tests {
621        use std::collections::HashMap;
622        use std::time::{SystemTime, UNIX_EPOCH};
623
624        use base64::Engine;
625        use jsonwebtoken::{Algorithm, EncodingKey, Header};
626        use rsa::pkcs1::EncodeRsaPrivateKey;
627        use rsa::traits::PublicKeyParts;
628        use rsa::{RsaPrivateKey, RsaPublicKey};
629        use serde_json::json;
630
631        use crate::pg_server::{Jwk, Jwks, validate_jwt_with_jwks};
632
633        fn create_test_rsa_keys() -> (RsaPrivateKey, RsaPublicKey) {
634            let mut rng = rand::thread_rng();
635            let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("failed to generate a key");
636            let public_key = RsaPublicKey::from(&private_key);
637            (private_key, public_key)
638        }
639
640        fn create_test_jwks(public_key: &RsaPublicKey, kid: &str, alg: &str) -> Jwks {
641            let n = base64::engine::general_purpose::URL_SAFE_NO_PAD
642                .encode(public_key.n().to_bytes_be());
643            let e = base64::engine::general_purpose::URL_SAFE_NO_PAD
644                .encode(public_key.e().to_bytes_be());
645
646            Jwks {
647                keys: vec![Jwk {
648                    kid: kid.to_owned(),
649                    alg: alg.to_owned(),
650                    n,
651                    e,
652                }],
653            }
654        }
655
656        fn create_jwt_token(
657            private_key: &RsaPrivateKey,
658            kid: &str,
659            algorithm: Algorithm,
660            issuer: &str,
661            audience: Option<&str>,
662            exp: u64,
663            additional_claims: HashMap<String, serde_json::Value>,
664        ) -> String {
665            let mut header = Header::new(algorithm);
666            header.kid = Some(kid.to_owned());
667
668            let mut claims = json!({
669                "iss": issuer,
670                "exp": exp,
671            });
672
673            if let Some(aud) = audience {
674                claims["aud"] = json!(aud);
675            }
676
677            for (key, value) in additional_claims {
678                claims[key] = value;
679            }
680
681            let encoding_key = EncodingKey::from_rsa_pem(
682                private_key
683                    .to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
684                    .unwrap()
685                    .as_bytes(),
686            )
687            .unwrap();
688
689            jsonwebtoken::encode(&header, &claims, &encoding_key).unwrap()
690        }
691
692        fn get_future_timestamp() -> u64 {
693            SystemTime::now()
694                .duration_since(UNIX_EPOCH)
695                .unwrap()
696                .as_secs()
697                + 3600 // 1 hour from now
698        }
699
700        fn get_past_timestamp() -> u64 {
701            SystemTime::now()
702                .duration_since(UNIX_EPOCH)
703                .unwrap()
704                .as_secs()
705                - 3600 // 1 hour ago
706        }
707
708        #[test]
709        fn test_jwt_with_invalid_audience() {
710            let (private_key, public_key) = create_test_rsa_keys();
711            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
712
713            let metadata = HashMap::new();
714
715            let jwt = create_jwt_token(
716                &private_key,
717                "test-kid",
718                Algorithm::RS256,
719                "https://test-issuer.com",
720                Some("urn:risingwave:cluster:wrong-cluster-id"),
721                get_future_timestamp(),
722                HashMap::new(),
723            );
724
725            let result = validate_jwt_with_jwks(
726                &jwt,
727                &jwks,
728                "https://test-issuer.com",
729                "test-cluster-id",
730                &metadata,
731            );
732
733            let error = result.unwrap_err();
734            assert!(error.to_string().contains("InvalidAudience"));
735        }
736
737        #[test]
738        fn test_jwt_with_missing_audience() {
739            let (private_key, public_key) = create_test_rsa_keys();
740            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
741
742            let metadata = HashMap::new();
743
744            let jwt = create_jwt_token(
745                &private_key,
746                "test-kid",
747                Algorithm::RS256,
748                "https://test-issuer.com",
749                None, // No audience claim
750                get_future_timestamp(),
751                HashMap::new(),
752            );
753
754            let result = validate_jwt_with_jwks(
755                &jwt,
756                &jwks,
757                "https://test-issuer.com",
758                "test-cluster-id",
759                &metadata,
760            );
761
762            let error = result.unwrap_err();
763            assert!(error.to_string().contains("Missing required claim: aud"));
764        }
765
766        #[test]
767        fn test_jwt_with_invalid_issuer() {
768            let (private_key, public_key) = create_test_rsa_keys();
769            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
770
771            let metadata = HashMap::new();
772
773            let jwt = create_jwt_token(
774                &private_key,
775                "test-kid",
776                Algorithm::RS256,
777                "https://wrong-issuer.com",
778                Some("urn:risingwave:cluster:test-cluster-id"),
779                get_future_timestamp(),
780                HashMap::new(),
781            );
782
783            let result = validate_jwt_with_jwks(
784                &jwt,
785                &jwks,
786                "https://test-issuer.com",
787                "test-cluster-id",
788                &metadata,
789            );
790
791            let error = result.unwrap_err();
792            assert!(error.to_string().contains("InvalidIssuer"));
793        }
794
795        #[test]
796        fn test_jwt_with_kid_not_found_in_jwks() {
797            let (private_key, public_key) = create_test_rsa_keys();
798            let jwks = create_test_jwks(&public_key, "different-kid", "RS256");
799
800            let metadata = HashMap::new();
801
802            let jwt = create_jwt_token(
803                &private_key,
804                "missing-kid",
805                Algorithm::RS256,
806                "https://test-issuer.com",
807                Some("urn:risingwave:cluster:test-cluster-id"),
808                get_future_timestamp(),
809                HashMap::new(),
810            );
811
812            let result = validate_jwt_with_jwks(
813                &jwt,
814                &jwks,
815                "https://test-issuer.com",
816                "test-cluster-id",
817                &metadata,
818            );
819
820            let error = result.unwrap_err();
821            assert!(
822                error
823                    .to_string()
824                    .contains("No matching key found in JWKS for kid: 'missing-kid'")
825            );
826        }
827
828        #[test]
829        fn test_jwt_with_expired_token() {
830            let (private_key, public_key) = create_test_rsa_keys();
831            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
832
833            let metadata = HashMap::new();
834
835            let jwt = create_jwt_token(
836                &private_key,
837                "test-kid",
838                Algorithm::RS256,
839                "https://test-issuer.com",
840                Some("urn:risingwave:cluster:test-cluster-id"),
841                get_past_timestamp(), // Expired token
842                HashMap::new(),
843            );
844
845            let result = validate_jwt_with_jwks(
846                &jwt,
847                &jwks,
848                "https://test-issuer.com",
849                "test-cluster-id",
850                &metadata,
851            );
852
853            let error = result.unwrap_err();
854            assert!(error.to_string().contains("ExpiredSignature"));
855        }
856
857        #[test]
858        fn test_jwt_with_invalid_signature() {
859            let (_, public_key) = create_test_rsa_keys();
860            let (wrong_private_key, _) = create_test_rsa_keys(); // Different key pair
861            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
862
863            let metadata = HashMap::new();
864
865            // Sign with wrong private key
866            let jwt = create_jwt_token(
867                &wrong_private_key,
868                "test-kid",
869                Algorithm::RS256,
870                "https://test-issuer.com",
871                Some("urn:risingwave:cluster:test-cluster-id"),
872                get_future_timestamp(),
873                HashMap::new(),
874            );
875
876            let result = validate_jwt_with_jwks(
877                &jwt,
878                &jwks,
879                "https://test-issuer.com",
880                "test-cluster-id",
881                &metadata,
882            );
883
884            let error = result.unwrap_err();
885            assert!(error.to_string().contains("InvalidSignature"));
886        }
887
888        #[test]
889        fn test_metadata_validation_success() {
890            let (private_key, public_key) = create_test_rsa_keys();
891            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
892
893            let mut metadata = HashMap::new();
894            metadata.insert("role".to_owned(), "admin".to_owned());
895            metadata.insert("department".to_owned(), "security".to_owned());
896
897            let mut claims = HashMap::new();
898            claims.insert("role".to_owned(), json!("admin"));
899            claims.insert("department".to_owned(), json!("security"));
900            claims.insert("extra_claim".to_owned(), json!("ignored")); // Extra claims are fine
901
902            let jwt = create_jwt_token(
903                &private_key,
904                "test-kid",
905                Algorithm::RS256,
906                "https://test-issuer.com",
907                Some("urn:risingwave:cluster:test-cluster-id"),
908                get_future_timestamp(),
909                claims,
910            );
911
912            let result = validate_jwt_with_jwks(
913                &jwt,
914                &jwks,
915                "https://test-issuer.com",
916                "test-cluster-id",
917                &metadata,
918            );
919
920            assert!(result.unwrap());
921        }
922
923        #[test]
924        fn test_metadata_validation_failure() {
925            let (private_key, public_key) = create_test_rsa_keys();
926            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
927
928            let mut metadata = HashMap::new();
929            metadata.insert("role".to_owned(), "admin".to_owned());
930            metadata.insert("department".to_owned(), "security".to_owned());
931
932            let mut claims = HashMap::new();
933            claims.insert("role".to_owned(), json!("user")); // Wrong role
934            claims.insert("department".to_owned(), json!("security"));
935
936            let jwt = create_jwt_token(
937                &private_key,
938                "test-kid",
939                Algorithm::RS256,
940                "https://test-issuer.com",
941                Some("urn:risingwave:cluster:test-cluster-id"),
942                get_future_timestamp(),
943                claims,
944            );
945
946            let result = validate_jwt_with_jwks(
947                &jwt,
948                &jwks,
949                "https://test-issuer.com",
950                "test-cluster-id",
951                &metadata,
952            );
953
954            let error = result.unwrap_err();
955            assert_eq!(
956                error.to_string(),
957                "metadata in jwt does not match with metadata declared with user"
958            );
959        }
960    }
961}