pgwire/
pg_server.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::future::Future;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::time::Instant;
20
21use bytes::Bytes;
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
23use parking_lot::Mutex;
24use risingwave_common::config::HbaEntry;
25use risingwave_common::id::DatabaseId;
26use risingwave_common::types::DataType;
27use risingwave_common::util::runtime::BackgroundShutdownRuntime;
28use risingwave_common::util::tokio_util::sync::CancellationToken;
29use risingwave_sqlparser::ast::Statement;
30use serde::Deserialize;
31use thiserror_ext::AsReport;
32
33use crate::error::{PsqlError, PsqlResult};
34use crate::ldap_auth::LdapAuthenticator;
35use crate::net::{AddressRef, Listener, TcpKeepalive};
36use crate::pg_field_descriptor::PgFieldDescriptor;
37use crate::pg_message::TransactionStatus;
38use crate::pg_protocol::{ConnectionContext, PgByteStream, PgProtocol};
39use crate::pg_response::{PgResponse, ValuesStream};
40use crate::types::Format;
41
/// Boxed, thread-safe error type returned across the pgwire session interfaces.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;

/// Process-id half of a [`SessionId`].
type ProcessId = i32;
/// Secret-key half of a [`SessionId`].
type SecretKey = i32;
/// Identifies a session as a `(process id, secret key)` pair.
pub type SessionId = (ProcessId, SecretKey);
46
47/// The interface for a database system behind pgwire protocol.
48/// We can mock it for testing purpose.
49pub trait SessionManager: Send + Sync + 'static {
50    type Session: Session;
51
52    /// In the process of auto schema change, we need a dummy session to access
53    /// catalog information in frontend and build a replace plan for the table.
54    fn create_dummy_session(
55        &self,
56        database_id: DatabaseId,
57        user_id: u32,
58    ) -> Result<Arc<Self::Session>, BoxedError>;
59
60    fn connect(
61        &self,
62        database: &str,
63        user_name: &str,
64        peer_addr: AddressRef,
65    ) -> Result<Arc<Self::Session>, BoxedError>;
66
67    fn cancel_queries_in_session(&self, session_id: SessionId);
68
69    fn cancel_creating_jobs_in_session(&self, session_id: SessionId);
70
71    fn end_session(&self, session: &Self::Session);
72
73    /// Run some cleanup tasks before the server shutdown.
74    fn shutdown(&self) -> impl Future<Output = ()> + Send {
75        async {}
76    }
77}
78
79/// A psql connection. Each connection binds with a database. Switching database will need to
80/// recreate another connection.
81pub trait Session: Send + Sync {
82    type ValuesStream: ValuesStream;
83    type PreparedStatement: Send + Clone + 'static;
84    type Portal: Send + Clone + std::fmt::Display + 'static;
85
86    /// The str sql can not use the unparse from AST: There is some problem when dealing with create
87    /// view, see <https://github.com/risingwavelabs/risingwave/issues/6801>.
88    fn run_one_query(
89        self: Arc<Self>,
90        stmt: Statement,
91        format: Format,
92    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;
93
94    fn parse(
95        self: Arc<Self>,
96        sql: Option<Statement>,
97        params_types: Vec<Option<DataType>>,
98    ) -> impl Future<Output = Result<Self::PreparedStatement, BoxedError>> + Send;
99
100    /// Receive the next notice message to send to the client.
101    ///
102    /// This function should be cancellation-safe.
103    fn next_notice(self: &Arc<Self>) -> impl Future<Output = String> + Send;
104
105    fn bind(
106        self: Arc<Self>,
107        prepare_statement: Self::PreparedStatement,
108        params: Vec<Option<Bytes>>,
109        param_formats: Vec<Format>,
110        result_formats: Vec<Format>,
111    ) -> Result<Self::Portal, BoxedError>;
112
113    fn execute(
114        self: Arc<Self>,
115        portal: Self::Portal,
116    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;
117
118    fn describe_statement(
119        self: Arc<Self>,
120        prepare_statement: Self::PreparedStatement,
121    ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError>;
122
123    fn describe_portal(
124        self: Arc<Self>,
125        portal: Self::Portal,
126    ) -> Result<Vec<PgFieldDescriptor>, BoxedError>;
127
128    fn user_authenticator(&self) -> &UserAuthenticator;
129
130    fn id(&self) -> SessionId;
131
132    fn get_config(&self, key: &str) -> Result<String, BoxedError>;
133
134    fn set_config(&self, key: &str, value: String) -> Result<String, BoxedError>;
135
136    fn transaction_status(&self) -> TransactionStatus;
137
138    fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard;
139
140    fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()>;
141}
142
143/// Each session could run different SQLs multiple times.
144/// `ExecContext` represents the lifetime of a running SQL in the current session.
145pub struct ExecContext {
146    pub running_sql: Arc<str>,
147    /// The instant of the running sql
148    pub last_instant: Instant,
149    /// A reference used to update when `ExecContext` is dropped
150    pub last_idle_instant: Arc<Mutex<Option<Instant>>>,
151}
152
153/// `ExecContextGuard` holds a `Arc` pointer. Once `ExecContextGuard` is dropped,
154/// the inner `Arc<ExecContext>` should not be referred anymore, so that its `Weak` reference (used in `SessionImpl`) will be the same lifecycle of the running sql execution context.
155pub struct ExecContextGuard(#[allow(dead_code)] Arc<ExecContext>);
156
157impl ExecContextGuard {
158    pub fn new(exec_context: Arc<ExecContext>) -> Self {
159        Self(exec_context)
160    }
161}
162
163impl Drop for ExecContext {
164    fn drop(&mut self) {
165        *self.last_idle_instant.lock() = Some(Instant::now());
166    }
167}
168
169#[derive(Debug, Clone)]
170pub enum UserAuthenticator {
171    // No need to authenticate.
172    None,
173    // raw password in clear-text form.
174    ClearText(Vec<u8>),
175    // password encrypted with random salt.
176    Md5WithSalt {
177        encrypted_password: Vec<u8>,
178        salt: [u8; 4],
179    },
180    OAuth {
181        metadata: HashMap<String, String>,
182        cluster_id: String,
183    },
184    Ldap(String, HbaEntry),
185}
186
187/// A JWK Set is a JSON object that represents a set of JWKs.
188/// The JSON object MUST have a "keys" member, with its value being an array of JWKs.
189/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-5> for more details.
190#[derive(Debug, Deserialize)]
191struct Jwks {
192    keys: Vec<Jwk>,
193}
194
195/// A JSON Web Key (JWK) is a JSON object that represents a cryptographic key.
196/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-4> for more details.
197#[derive(Debug, Deserialize)]
198struct Jwk {
199    kid: String, // Key ID
200    alg: String, // Algorithm
201    n: String,   // Modulus
202    e: String,   // Exponent
203}
204
205async fn validate_jwt(
206    jwt: &str,
207    jwks_url: &str,
208    issuer: &str,
209    cluster_id: &str,
210    metadata: &HashMap<String, String>,
211) -> Result<bool, BoxedError> {
212    let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;
213    validate_jwt_with_jwks(jwt, &jwks, issuer, cluster_id, metadata)
214}
215
/// Builds the JWT audience (`aud`) value expected for the cluster `cluster_id`.
fn audience_from_cluster_id(cluster_id: &str) -> String {
    let prefix = "urn:risingwave:cluster:";
    [prefix, cluster_id].concat()
}
219
220fn validate_jwt_with_jwks(
221    jwt: &str,
222    jwks: &Jwks,
223    issuer: &str,
224    cluster_id: &str,
225    metadata: &HashMap<String, String>,
226) -> Result<bool, BoxedError> {
227    let header = decode_header(jwt)?;
228
229    // 1. Retrieve the kid from the header to find the right JWK in the JWK Set.
230    let kid = header.kid.ok_or("JWT header missing 'kid' field")?;
231    let jwk = jwks
232        .keys
233        .iter()
234        .find(|k| k.kid == kid)
235        .ok_or(format!("No matching key found in JWKS for kid: '{}'", kid))?;
236
237    // 2. Check if the algorithms are matched.
238    if Algorithm::from_str(&jwk.alg)? != header.alg {
239        return Err("alg in jwt header does not match with alg in jwk".into());
240    }
241
242    // 3. Decode the JWT and validate the claims.
243    let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
244    let mut validation = Validation::new(header.alg);
245    validation.set_issuer(&[issuer]);
246    validation.set_audience(&[audience_from_cluster_id(cluster_id)]); // JWT 'aud' claim must match cluster_id
247    validation.set_required_spec_claims(&["exp", "iss", "aud"]);
248    let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;
249
250    // 4. Check if the metadata in the token matches.
251    if !metadata.iter().all(
252        |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
253    ) {
254        return Err("metadata in jwt does not match with metadata declared with user".into());
255    }
256    Ok(true)
257}
258
259impl UserAuthenticator {
260    pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> {
261        let success = match self {
262            UserAuthenticator::None => true,
263            UserAuthenticator::ClearText(text) => password == text,
264            UserAuthenticator::Md5WithSalt {
265                encrypted_password, ..
266            } => encrypted_password == password,
267            UserAuthenticator::OAuth {
268                metadata,
269                cluster_id,
270            } => {
271                let mut metadata = metadata.clone();
272                let jwks_url = metadata.remove("jwks_url").unwrap();
273                let issuer = metadata.remove("issuer").unwrap();
274                validate_jwt(
275                    &String::from_utf8_lossy(password),
276                    &jwks_url,
277                    &issuer,
278                    cluster_id,
279                    &metadata,
280                )
281                .await
282                .map_err(PsqlError::StartupError)?
283            }
284            UserAuthenticator::Ldap(user_name, hba_entry) => {
285                let ldap_auth = LdapAuthenticator::new(hba_entry)?;
286                // Convert password to string, defaulting to empty if not valid UTF-8
287                let password_str = String::from_utf8_lossy(password).into_owned();
288                ldap_auth.authenticate(user_name, &password_str).await?
289            }
290        };
291        if !success {
292            return Err(PsqlError::PasswordError);
293        }
294        Ok(())
295    }
296}
297
298/// Binds a Tcp or Unix listener at `addr`. Spawn a coroutine to serve every new connection.
299///
300/// Returns when the `shutdown` token is triggered.
301pub async fn pg_serve(
302    addr: &str,
303    tcp_keepalive: TcpKeepalive,
304    session_mgr: Arc<impl SessionManager>,
305    context: ConnectionContext,
306    shutdown: CancellationToken,
307) -> Result<(), BoxedError> {
308    let listener = Listener::bind(addr).await?;
309    tracing::info!(addr, "server started");
310
311    let acceptor_runtime = BackgroundShutdownRuntime::from({
312        let mut builder = tokio::runtime::Builder::new_multi_thread();
313        builder.worker_threads(1);
314        builder
315            .thread_name("rw-acceptor")
316            .enable_all()
317            .build()
318            .unwrap()
319    });
320
321    #[cfg(not(madsim))]
322    let worker_runtime = tokio::runtime::Handle::current();
323    #[cfg(madsim)]
324    let worker_runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();
325    let session_mgr_clone = session_mgr.clone();
326    let f = async move {
327        loop {
328            let conn_ret = listener.accept(&tcp_keepalive).await;
329            match conn_ret {
330                Ok((stream, peer_addr)) => {
331                    tracing::info!(%peer_addr, "accept connection");
332                    worker_runtime.spawn(handle_connection(
333                        stream,
334                        session_mgr_clone.clone(),
335                        Arc::new(peer_addr),
336                        context.clone(),
337                    ));
338                }
339
340                Err(e) => {
341                    tracing::error!(error = %e.as_report(), "failed to accept connection",);
342                }
343            }
344        }
345    };
346    acceptor_runtime.spawn(f);
347
348    // Wait for the shutdown signal.
349    shutdown.cancelled().await;
350
351    // Stop accepting new connections.
352    drop(acceptor_runtime);
353    // Shutdown session manager, typically close all existing sessions.
354    session_mgr.shutdown().await;
355
356    Ok(())
357}
358
359pub async fn handle_connection<S, SM>(
360    stream: S,
361    session_mgr: Arc<SM>,
362    peer_addr: AddressRef,
363    context: ConnectionContext,
364) where
365    S: PgByteStream,
366    SM: SessionManager,
367{
368    PgProtocol::new(stream, session_mgr, peer_addr, context)
369        .run()
370        .await;
371}
372#[cfg(test)]
373mod tests {
374    use std::error::Error;
375    use std::sync::Arc;
376    use std::time::Instant;
377
378    use bytes::Bytes;
379    use futures::StreamExt;
380    use futures::stream::BoxStream;
381    use risingwave_common::id::DatabaseId;
382    use risingwave_common::types::DataType;
383    use risingwave_common::util::tokio_util::sync::CancellationToken;
384    use risingwave_sqlparser::ast::Statement;
385    use tokio_postgres::NoTls;
386
387    use crate::error::PsqlResult;
388    use crate::memory_manager::MessageMemoryManager;
389    use crate::pg_field_descriptor::PgFieldDescriptor;
390    use crate::pg_message::TransactionStatus;
391    use crate::pg_protocol::ConnectionContext;
392    use crate::pg_response::{PgResponse, RowSetResult, StatementType};
393    use crate::pg_server::{
394        BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
395        UserAuthenticator, pg_serve,
396    };
397    use crate::types;
398    use crate::types::Row;
399
400    struct MockSessionManager {}
401    struct MockSession {}
402
403    impl SessionManager for MockSessionManager {
404        type Session = MockSession;
405
406        fn create_dummy_session(
407            &self,
408            _database_id: DatabaseId,
409            _user_name: u32,
410        ) -> Result<Arc<Self::Session>, BoxedError> {
411            unimplemented!()
412        }
413
414        fn connect(
415            &self,
416            _database: &str,
417            _user_name: &str,
418            _peer_addr: crate::net::AddressRef,
419        ) -> Result<Arc<Self::Session>, Box<dyn Error + Send + Sync>> {
420            Ok(Arc::new(MockSession {}))
421        }
422
423        fn cancel_queries_in_session(&self, _session_id: SessionId) {
424            todo!()
425        }
426
427        fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
428            todo!()
429        }
430
431        fn end_session(&self, _session: &Self::Session) {}
432    }
433
434    impl Session for MockSession {
435        type Portal = String;
436        type PreparedStatement = String;
437        type ValuesStream = BoxStream<'static, RowSetResult>;
438
439        async fn run_one_query(
440            self: Arc<Self>,
441            _stmt: Statement,
442            _format: types::Format,
443        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, BoxedError> {
444            Ok(PgResponse::builder(StatementType::SELECT)
445                .values(
446                    futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
447                        .boxed(),
448                    vec![
449                        // 1043 is the oid of varchar type.
450                        // -1 is the type len of varchar type.
451                        PgFieldDescriptor::new("".to_owned(), 1043, -1);
452                        1
453                    ],
454                )
455                .into())
456        }
457
458        async fn parse(
459            self: Arc<Self>,
460            _sql: Option<Statement>,
461            _params_types: Vec<Option<DataType>>,
462        ) -> Result<String, BoxedError> {
463            Ok(String::new())
464        }
465
466        fn bind(
467            self: Arc<Self>,
468            _prepare_statement: String,
469            _params: Vec<Option<Bytes>>,
470            _param_formats: Vec<types::Format>,
471            _result_formats: Vec<types::Format>,
472        ) -> Result<String, BoxedError> {
473            Ok(String::new())
474        }
475
476        async fn execute(
477            self: Arc<Self>,
478            _portal: String,
479        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, BoxedError> {
480            Ok(PgResponse::builder(StatementType::SELECT)
481                .values(
482                    futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
483                        .boxed(),
484                    vec![
485                    // 1043 is the oid of varchar type.
486                    // -1 is the type len of varchar type.
487                    PgFieldDescriptor::new("".to_owned(), 1043, -1);
488                    1
489                ],
490                )
491                .into())
492        }
493
494        fn describe_statement(
495            self: Arc<Self>,
496            _statement: String,
497        ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError> {
498            Ok((
499                vec![],
500                vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)],
501            ))
502        }
503
504        fn describe_portal(
505            self: Arc<Self>,
506            _portal: String,
507        ) -> Result<Vec<PgFieldDescriptor>, BoxedError> {
508            Ok(vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)])
509        }
510
511        fn user_authenticator(&self) -> &UserAuthenticator {
512            &UserAuthenticator::None
513        }
514
515        fn id(&self) -> SessionId {
516            (0, 0)
517        }
518
519        fn get_config(&self, key: &str) -> Result<String, BoxedError> {
520            match key {
521                "timezone" => Ok("UTC".to_owned()),
522                _ => Err(format!("Unknown config key: {key}").into()),
523            }
524        }
525
526        fn set_config(&self, _key: &str, _value: String) -> Result<String, BoxedError> {
527            Ok("".to_owned())
528        }
529
530        async fn next_notice(self: &Arc<Self>) -> String {
531            std::future::pending().await
532        }
533
534        fn transaction_status(&self) -> TransactionStatus {
535            TransactionStatus::Idle
536        }
537
538        fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
539            let exec_context = Arc::new(ExecContext {
540                running_sql: sql,
541                last_instant: Instant::now(),
542                last_idle_instant: Default::default(),
543            });
544            ExecContextGuard::new(exec_context)
545        }
546
547        fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
548            Ok(())
549        }
550    }
551
552    async fn do_test_query(bind_addr: impl Into<String>, pg_config: impl Into<String>) {
553        let bind_addr = bind_addr.into();
554        let pg_config = pg_config.into();
555
556        let session_mgr = MockSessionManager {};
557        tokio::spawn(async move {
558            pg_serve(
559                &bind_addr,
560                socket2::TcpKeepalive::new(),
561                Arc::new(session_mgr),
562                ConnectionContext {
563                    tls_config: None,
564                    redact_sql_option_keywords: None,
565                    message_memory_manager: MessageMemoryManager::new(u64::MAX, u64::MAX, u64::MAX)
566                        .into(),
567                },
568                CancellationToken::new(), // dummy
569            )
570            .await
571        });
572        // wait for server to start
573        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
574
575        // Connect to the database.
576        let (client, connection) = tokio_postgres::connect(&pg_config, NoTls).await.unwrap();
577
578        // The connection object performs the actual communication with the database,
579        // so spawn it off to run on its own.
580        tokio::spawn(async move {
581            if let Err(e) = connection.await {
582                eprintln!("connection error: {}", e);
583            }
584        });
585
586        let rows = client
587            .simple_query("SELECT ''")
588            .await
589            .expect("Error executing query");
590        // Row + CommandComplete
591        assert_eq!(rows.len(), 2);
592
593        let rows = client
594            .query("SELECT ''", &[])
595            .await
596            .expect("Error executing query");
597        assert_eq!(rows.len(), 1);
598    }
599
600    #[tokio::test]
601    async fn test_query_tcp() {
602        do_test_query("127.0.0.1:10000", "host=localhost port=10000").await;
603    }
604
605    #[cfg(not(madsim))]
606    #[tokio::test]
607    async fn test_query_unix() {
608        let port: i16 = 10000;
609        let dir = tempfile::TempDir::new().unwrap();
610        let sock = dir.path().join(format!(".s.PGSQL.{port}"));
611
612        do_test_query(
613            format!("unix:{}", sock.to_str().unwrap()),
614            format!("host={} port={}", dir.path().to_str().unwrap(), port),
615        )
616        .await;
617    }
618
619    mod jwt_validation_tests {
620        use std::collections::HashMap;
621        use std::time::{SystemTime, UNIX_EPOCH};
622
623        use base64::Engine;
624        use jsonwebtoken::{Algorithm, EncodingKey, Header};
625        use rsa::pkcs1::EncodeRsaPrivateKey;
626        use rsa::traits::PublicKeyParts;
627        use rsa::{RsaPrivateKey, RsaPublicKey};
628        use serde_json::json;
629
630        use crate::pg_server::{Jwk, Jwks, validate_jwt_with_jwks};
631
632        fn create_test_rsa_keys() -> (RsaPrivateKey, RsaPublicKey) {
633            let mut rng = rand::thread_rng();
634            let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("failed to generate a key");
635            let public_key = RsaPublicKey::from(&private_key);
636            (private_key, public_key)
637        }
638
639        fn create_test_jwks(public_key: &RsaPublicKey, kid: &str, alg: &str) -> Jwks {
640            let n = base64::engine::general_purpose::URL_SAFE_NO_PAD
641                .encode(public_key.n().to_bytes_be());
642            let e = base64::engine::general_purpose::URL_SAFE_NO_PAD
643                .encode(public_key.e().to_bytes_be());
644
645            Jwks {
646                keys: vec![Jwk {
647                    kid: kid.to_owned(),
648                    alg: alg.to_owned(),
649                    n,
650                    e,
651                }],
652            }
653        }
654
655        fn create_jwt_token(
656            private_key: &RsaPrivateKey,
657            kid: &str,
658            algorithm: Algorithm,
659            issuer: &str,
660            audience: Option<&str>,
661            exp: u64,
662            additional_claims: HashMap<String, serde_json::Value>,
663        ) -> String {
664            let mut header = Header::new(algorithm);
665            header.kid = Some(kid.to_owned());
666
667            let mut claims = json!({
668                "iss": issuer,
669                "exp": exp,
670            });
671
672            if let Some(aud) = audience {
673                claims["aud"] = json!(aud);
674            }
675
676            for (key, value) in additional_claims {
677                claims[key] = value;
678            }
679
680            let encoding_key = EncodingKey::from_rsa_pem(
681                private_key
682                    .to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
683                    .unwrap()
684                    .as_bytes(),
685            )
686            .unwrap();
687
688            jsonwebtoken::encode(&header, &claims, &encoding_key).unwrap()
689        }
690
691        fn get_future_timestamp() -> u64 {
692            SystemTime::now()
693                .duration_since(UNIX_EPOCH)
694                .unwrap()
695                .as_secs()
696                + 3600 // 1 hour from now
697        }
698
699        fn get_past_timestamp() -> u64 {
700            SystemTime::now()
701                .duration_since(UNIX_EPOCH)
702                .unwrap()
703                .as_secs()
704                - 3600 // 1 hour ago
705        }
706
707        #[test]
708        fn test_jwt_with_invalid_audience() {
709            let (private_key, public_key) = create_test_rsa_keys();
710            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
711
712            let metadata = HashMap::new();
713
714            let jwt = create_jwt_token(
715                &private_key,
716                "test-kid",
717                Algorithm::RS256,
718                "https://test-issuer.com",
719                Some("urn:risingwave:cluster:wrong-cluster-id"),
720                get_future_timestamp(),
721                HashMap::new(),
722            );
723
724            let result = validate_jwt_with_jwks(
725                &jwt,
726                &jwks,
727                "https://test-issuer.com",
728                "test-cluster-id",
729                &metadata,
730            );
731
732            let error = result.unwrap_err();
733            assert!(error.to_string().contains("InvalidAudience"));
734        }
735
736        #[test]
737        fn test_jwt_with_missing_audience() {
738            let (private_key, public_key) = create_test_rsa_keys();
739            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
740
741            let metadata = HashMap::new();
742
743            let jwt = create_jwt_token(
744                &private_key,
745                "test-kid",
746                Algorithm::RS256,
747                "https://test-issuer.com",
748                None, // No audience claim
749                get_future_timestamp(),
750                HashMap::new(),
751            );
752
753            let result = validate_jwt_with_jwks(
754                &jwt,
755                &jwks,
756                "https://test-issuer.com",
757                "test-cluster-id",
758                &metadata,
759            );
760
761            let error = result.unwrap_err();
762            assert!(error.to_string().contains("Missing required claim: aud"));
763        }
764
765        #[test]
766        fn test_jwt_with_invalid_issuer() {
767            let (private_key, public_key) = create_test_rsa_keys();
768            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
769
770            let metadata = HashMap::new();
771
772            let jwt = create_jwt_token(
773                &private_key,
774                "test-kid",
775                Algorithm::RS256,
776                "https://wrong-issuer.com",
777                Some("urn:risingwave:cluster:test-cluster-id"),
778                get_future_timestamp(),
779                HashMap::new(),
780            );
781
782            let result = validate_jwt_with_jwks(
783                &jwt,
784                &jwks,
785                "https://test-issuer.com",
786                "test-cluster-id",
787                &metadata,
788            );
789
790            let error = result.unwrap_err();
791            assert!(error.to_string().contains("InvalidIssuer"));
792        }
793
794        #[test]
795        fn test_jwt_with_kid_not_found_in_jwks() {
796            let (private_key, public_key) = create_test_rsa_keys();
797            let jwks = create_test_jwks(&public_key, "different-kid", "RS256");
798
799            let metadata = HashMap::new();
800
801            let jwt = create_jwt_token(
802                &private_key,
803                "missing-kid",
804                Algorithm::RS256,
805                "https://test-issuer.com",
806                Some("urn:risingwave:cluster:test-cluster-id"),
807                get_future_timestamp(),
808                HashMap::new(),
809            );
810
811            let result = validate_jwt_with_jwks(
812                &jwt,
813                &jwks,
814                "https://test-issuer.com",
815                "test-cluster-id",
816                &metadata,
817            );
818
819            let error = result.unwrap_err();
820            assert!(
821                error
822                    .to_string()
823                    .contains("No matching key found in JWKS for kid: 'missing-kid'")
824            );
825        }
826
827        #[test]
828        fn test_jwt_with_expired_token() {
829            let (private_key, public_key) = create_test_rsa_keys();
830            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
831
832            let metadata = HashMap::new();
833
834            let jwt = create_jwt_token(
835                &private_key,
836                "test-kid",
837                Algorithm::RS256,
838                "https://test-issuer.com",
839                Some("urn:risingwave:cluster:test-cluster-id"),
840                get_past_timestamp(), // Expired token
841                HashMap::new(),
842            );
843
844            let result = validate_jwt_with_jwks(
845                &jwt,
846                &jwks,
847                "https://test-issuer.com",
848                "test-cluster-id",
849                &metadata,
850            );
851
852            let error = result.unwrap_err();
853            assert!(error.to_string().contains("ExpiredSignature"));
854        }
855
856        #[test]
857        fn test_jwt_with_invalid_signature() {
858            let (_, public_key) = create_test_rsa_keys();
859            let (wrong_private_key, _) = create_test_rsa_keys(); // Different key pair
860            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
861
862            let metadata = HashMap::new();
863
864            // Sign with wrong private key
865            let jwt = create_jwt_token(
866                &wrong_private_key,
867                "test-kid",
868                Algorithm::RS256,
869                "https://test-issuer.com",
870                Some("urn:risingwave:cluster:test-cluster-id"),
871                get_future_timestamp(),
872                HashMap::new(),
873            );
874
875            let result = validate_jwt_with_jwks(
876                &jwt,
877                &jwks,
878                "https://test-issuer.com",
879                "test-cluster-id",
880                &metadata,
881            );
882
883            let error = result.unwrap_err();
884            assert!(error.to_string().contains("InvalidSignature"));
885        }
886
887        #[test]
888        fn test_metadata_validation_success() {
889            let (private_key, public_key) = create_test_rsa_keys();
890            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
891
892            let mut metadata = HashMap::new();
893            metadata.insert("role".to_owned(), "admin".to_owned());
894            metadata.insert("department".to_owned(), "security".to_owned());
895
896            let mut claims = HashMap::new();
897            claims.insert("role".to_owned(), json!("admin"));
898            claims.insert("department".to_owned(), json!("security"));
899            claims.insert("extra_claim".to_owned(), json!("ignored")); // Extra claims are fine
900
901            let jwt = create_jwt_token(
902                &private_key,
903                "test-kid",
904                Algorithm::RS256,
905                "https://test-issuer.com",
906                Some("urn:risingwave:cluster:test-cluster-id"),
907                get_future_timestamp(),
908                claims,
909            );
910
911            let result = validate_jwt_with_jwks(
912                &jwt,
913                &jwks,
914                "https://test-issuer.com",
915                "test-cluster-id",
916                &metadata,
917            );
918
919            assert!(result.unwrap());
920        }
921
922        #[test]
923        fn test_metadata_validation_failure() {
924            let (private_key, public_key) = create_test_rsa_keys();
925            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");
926
927            let mut metadata = HashMap::new();
928            metadata.insert("role".to_owned(), "admin".to_owned());
929            metadata.insert("department".to_owned(), "security".to_owned());
930
931            let mut claims = HashMap::new();
932            claims.insert("role".to_owned(), json!("user")); // Wrong role
933            claims.insert("department".to_owned(), json!("security"));
934
935            let jwt = create_jwt_token(
936                &private_key,
937                "test-kid",
938                Algorithm::RS256,
939                "https://test-issuer.com",
940                Some("urn:risingwave:cluster:test-cluster-id"),
941                get_future_timestamp(),
942                claims,
943            );
944
945            let result = validate_jwt_with_jwks(
946                &jwt,
947                &jwks,
948                "https://test-issuer.com",
949                "test-cluster-id",
950                &metadata,
951            );
952
953            let error = result.unwrap_err();
954            assert_eq!(
955                error.to_string(),
956                "metadata in jwt does not match with metadata declared with user"
957            );
958        }
959    }
960}