pgwire/
pg_server.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::future::Future;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::time::Instant;
20
21use bytes::Bytes;
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
23use parking_lot::Mutex;
24use risingwave_common::config::HbaEntry;
25use risingwave_common::id::DatabaseId;
26use risingwave_common::types::DataType;
27use risingwave_common::util::runtime::BackgroundShutdownRuntime;
28use risingwave_common::util::tokio_util::sync::CancellationToken;
29use risingwave_sqlparser::ast::Statement;
30use serde::Deserialize;
31use thiserror_ext::AsReport;
32
33use crate::error::{PsqlError, PsqlResult};
34use crate::ldap_auth::LdapAuthenticator;
35use crate::net::{AddressRef, Listener, TcpKeepalive};
36use crate::pg_field_descriptor::PgFieldDescriptor;
37use crate::pg_message::TransactionStatus;
38use crate::pg_protocol::{ConnectionContext, PgByteStream, PgProtocol};
39use crate::pg_response::{PgResponse, ValuesStream};
40use crate::types::Format;
41
/// Boxed, thread-safe error type used at the boundaries of the pgwire server traits.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;
type ProcessId = i32;
type SecretKey = i32;
/// Identifies a session as a (process id, secret key) pair.
/// NOTE(review): presumably the pair sent to the client in `BackendKeyData` and echoed
/// back in cancel requests — confirm against `PgProtocol`.
pub type SessionId = (ProcessId, SecretKey);
46
/// The interface for a database system behind the pgwire protocol.
/// It can be mocked for testing purposes.
pub trait SessionManager: Send + Sync + 'static {
    /// Error returned when a session cannot be created; must convert into [`BoxedError`].
    type Error: Into<BoxedError>;
    /// Session type handed out by this manager; it shares the manager's error type.
    type Session: Session<Error = Self::Error>;

    /// In the process of auto schema change, we need a dummy session to access
    /// catalog information in frontend and build a replace plan for the table.
    fn create_dummy_session(
        &self,
        database_id: DatabaseId,
        user_id: u32,
    ) -> Result<Arc<Self::Session>, Self::Error>;

    /// Creates a session for a client that connects to `database` as `user_name`
    /// from `peer_addr`.
    fn connect(
        &self,
        database: &str,
        user_name: &str,
        peer_addr: AddressRef,
    ) -> Result<Arc<Self::Session>, Self::Error>;

    /// Cancels the queries running in the session identified by `session_id`.
    /// NOTE(review): semantics are implementation-defined; the test mock leaves it `todo!()`.
    fn cancel_queries_in_session(&self, session_id: SessionId);

    /// Cancels the job-creating statements started by the session identified by `session_id`.
    /// NOTE(review): semantics are implementation-defined; the test mock leaves it `todo!()`.
    fn cancel_creating_jobs_in_session(&self, session_id: SessionId);

    /// Notifies the manager that `session` has ended.
    /// NOTE(review): presumably called when the client connection closes — confirm in
    /// `PgProtocol`.
    fn end_session(&self, session: &Self::Session);

    /// Run some cleanup tasks before the server shutdown.
    ///
    /// `pg_serve` awaits this after it stops accepting new connections. The default
    /// implementation does nothing.
    fn shutdown(&self) -> impl Future<Output = ()> + Send {
        async {}
    }
}
79
80/// A psql connection. Each connection binds with a database. Switching database will need to
81/// recreate another connection.
82pub trait Session: Send + Sync {
83    type Error: Into<BoxedError>;
84    type ValuesStream: ValuesStream;
85    type PreparedStatement: Send + Clone + 'static;
86    type Portal: Send + Clone + std::fmt::Display + 'static;
87
88    /// The str sql can not use the unparse from AST: There is some problem when dealing with create
89    /// view, see <https://github.com/risingwavelabs/risingwave/issues/6801>.
90    fn run_one_query(
91        self: Arc<Self>,
92        stmt: Statement,
93        format: Format,
94    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, Self::Error>> + Send;
95
96    fn parse(
97        self: Arc<Self>,
98        sql: Option<Statement>,
99        params_types: Vec<Option<DataType>>,
100    ) -> impl Future<Output = Result<Self::PreparedStatement, Self::Error>> + Send;
101
102    /// Receive the next notice message to send to the client.
103    ///
104    /// This function should be cancellation-safe.
105    fn next_notice(self: &Arc<Self>) -> impl Future<Output = String> + Send;
106
107    fn bind(
108        self: Arc<Self>,
109        prepare_statement: Self::PreparedStatement,
110        params: Vec<Option<Bytes>>,
111        param_formats: Vec<Format>,
112        result_formats: Vec<Format>,
113    ) -> Result<Self::Portal, Self::Error>;
114
115    fn execute(
116        self: Arc<Self>,
117        portal: Self::Portal,
118    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, Self::Error>> + Send;
119
120    fn describe_statement(
121        self: Arc<Self>,
122        prepare_statement: Self::PreparedStatement,
123    ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), Self::Error>;
124
125    fn describe_portal(
126        self: Arc<Self>,
127        portal: Self::Portal,
128    ) -> Result<Vec<PgFieldDescriptor>, Self::Error>;
129
130    fn user_authenticator(&self) -> &UserAuthenticator;
131
132    fn id(&self) -> SessionId;
133
134    fn get_config(&self, key: &str) -> Result<String, Self::Error>;
135
136    fn set_config(&self, key: &str, value: String) -> Result<String, Self::Error>;
137
138    fn transaction_status(&self) -> TransactionStatus;
139
140    fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard;
141
142    fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()>;
143}
144
/// Each session can run different SQL statements multiple times.
/// `ExecContext` represents the lifetime of one running SQL statement in the current session.
///
/// Dropping it records the current instant into `last_idle_instant` (see its `Drop` impl).
pub struct ExecContext {
    /// The SQL text of the statement being executed.
    pub running_sql: Arc<str>,
    /// The instant associated with the running SQL.
    /// NOTE(review): presumably its start time; the test mock sets it to `Instant::now()`
    /// on creation.
    pub last_instant: Instant,
    /// Shared slot that is set to `Some(Instant::now())` when this `ExecContext` is dropped,
    /// i.e. when the session stops running this statement.
    pub last_idle_instant: Arc<Mutex<Option<Instant>>>,
}
154
/// `ExecContextGuard` holds an `Arc` pointer. Once `ExecContextGuard` is dropped,
/// the inner `Arc<ExecContext>` should not be referred to anymore. That way its `Weak`
/// reference (used in `SessionImpl`) has the same lifecycle as the running SQL
/// execution context.
// The field is never read; the guard exists only so that dropping it releases this strong
// reference, which is why `dead_code` is allowed here.
pub struct ExecContextGuard(#[allow(dead_code)] Arc<ExecContext>);
158
159impl ExecContextGuard {
160    pub fn new(exec_context: Arc<ExecContext>) -> Self {
161        Self(exec_context)
162    }
163}
164
165impl Drop for ExecContext {
166    fn drop(&mut self) {
167        *self.last_idle_instant.lock() = Some(Instant::now());
168    }
169}
170
/// How a user must prove its identity; checked by [`UserAuthenticator::authenticate`].
#[derive(Debug, Clone)]
pub enum UserAuthenticator {
    /// No need to authenticate.
    None,
    /// Raw password in clear-text form; compared byte-for-byte with what the client sends.
    ClearText(Vec<u8>),
    /// Password encrypted with a random salt. `encrypted_password` is compared byte-for-byte
    /// with what the client sends.
    /// NOTE(review): `salt` is not read by `authenticate`; presumably it is sent to the
    /// client in the MD5 authentication request.
    Md5WithSalt {
        encrypted_password: Vec<u8>,
        salt: [u8; 4],
    },
    /// JWT-based authentication. `metadata` must contain `jwks_url` and `issuer`, and every
    /// other entry must appear as an equal string claim in the token. `cluster_id`
    /// determines the required `aud` claim.
    OAuth {
        metadata: HashMap<String, String>,
        cluster_id: String,
    },
    /// LDAP authentication: the user name and the HBA entry passed to `LdapAuthenticator::new`.
    Ldap(String, HbaEntry),
}
188
/// A JWK Set is a JSON object that represents a set of JWKs.
/// The JSON object MUST have a "keys" member, with its value being an array of JWKs.
/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-5> for more details.
///
/// Members other than `keys` are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct Jwks {
    /// Keys in the set; the JWT header's `kid` selects which one verifies the token.
    keys: Vec<Jwk>,
}
196
/// A JSON Web Key (JWK) is a JSON object that represents a cryptographic key.
/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-4> for more details.
///
/// Only the members needed to verify RSA-signed tokens are read. All four are mandatory
/// here, so a single key without one of them makes the whole JWK Set fail to deserialize.
/// NOTE(review): RFC 7517 makes `kid` and `alg` optional, and non-RSA keys have no `n`/`e`.
/// Consider tolerating or skipping such keys.
#[derive(Debug, Deserialize)]
struct Jwk {
    /// Key ID, matched against the `kid` of the JWT header.
    kid: String,
    /// Algorithm the key is meant for; it must equal the `alg` of the JWT header.
    alg: String,
    /// RSA modulus, base64url-encoded.
    n: String,
    /// RSA public exponent, base64url-encoded.
    e: String,
}
206
207async fn validate_jwt(
208    jwt: &str,
209    jwks_url: &str,
210    issuer: &str,
211    cluster_id: &str,
212    metadata: &HashMap<String, String>,
213) -> Result<bool, BoxedError> {
214    let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;
215    validate_jwt_with_jwks(jwt, &jwks, issuer, cluster_id, metadata)
216}
217
/// Builds the `aud` claim value a JWT must carry to be accepted by the cluster
/// identified by `cluster_id`.
fn audience_from_cluster_id(cluster_id: &str) -> String {
    const AUDIENCE_PREFIX: &str = "urn:risingwave:cluster:";
    [AUDIENCE_PREFIX, cluster_id].concat()
}
221
222fn validate_jwt_with_jwks(
223    jwt: &str,
224    jwks: &Jwks,
225    issuer: &str,
226    cluster_id: &str,
227    metadata: &HashMap<String, String>,
228) -> Result<bool, BoxedError> {
229    let header = decode_header(jwt)?;
230
231    // 1. Retrieve the kid from the header to find the right JWK in the JWK Set.
232    let kid = header.kid.ok_or("JWT header missing 'kid' field")?;
233    let jwk = jwks
234        .keys
235        .iter()
236        .find(|k| k.kid == kid)
237        .ok_or(format!("No matching key found in JWKS for kid: '{}'", kid))?;
238
239    // 2. Check if the algorithms are matched.
240    if Algorithm::from_str(&jwk.alg)? != header.alg {
241        return Err("alg in jwt header does not match with alg in jwk".into());
242    }
243
244    // 3. Decode the JWT and validate the claims.
245    let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
246    let mut validation = Validation::new(header.alg);
247    validation.set_issuer(&[issuer]);
248    validation.set_audience(&[audience_from_cluster_id(cluster_id)]); // JWT 'aud' claim must match cluster_id
249    validation.set_required_spec_claims(&["exp", "iss", "aud"]);
250    let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;
251
252    // 4. Check if the metadata in the token matches.
253    if !metadata.iter().all(
254        |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
255    ) {
256        return Err("metadata in jwt does not match with metadata declared with user".into());
257    }
258    Ok(true)
259}
260
261impl UserAuthenticator {
262    pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> {
263        let success = match self {
264            UserAuthenticator::None => true,
265            UserAuthenticator::ClearText(text) => password == text,
266            UserAuthenticator::Md5WithSalt {
267                encrypted_password, ..
268            } => encrypted_password == password,
269            UserAuthenticator::OAuth {
270                metadata,
271                cluster_id,
272            } => {
273                let mut metadata = metadata.clone();
274                let jwks_url = metadata.remove("jwks_url").unwrap();
275                let issuer = metadata.remove("issuer").unwrap();
276                validate_jwt(
277                    &String::from_utf8_lossy(password),
278                    &jwks_url,
279                    &issuer,
280                    cluster_id,
281                    &metadata,
282                )
283                .await
284                .map_err(PsqlError::StartupError)?
285            }
286            UserAuthenticator::Ldap(user_name, hba_entry) => {
287                let ldap_auth = LdapAuthenticator::new(hba_entry)?;
288                // Convert password to string, defaulting to empty if not valid UTF-8
289                let password_str = String::from_utf8_lossy(password).into_owned();
290                ldap_auth.authenticate(user_name, &password_str).await?
291            }
292        };
293        if !success {
294            return Err(PsqlError::PasswordError);
295        }
296        Ok(())
297    }
298}
299
300/// Binds a Tcp or Unix listener at `addr`. Spawn a coroutine to serve every new connection.
301///
302/// Returns when the `shutdown` token is triggered.
303pub async fn pg_serve(
304    addr: &str,
305    tcp_keepalive: TcpKeepalive,
306    session_mgr: Arc<impl SessionManager>,
307    context: ConnectionContext,
308    shutdown: CancellationToken,
309) -> Result<(), BoxedError> {
310    let listener = Listener::bind(addr).await?;
311    tracing::info!(addr, "server started");
312
313    let acceptor_runtime = BackgroundShutdownRuntime::from({
314        let mut builder = tokio::runtime::Builder::new_multi_thread();
315        builder.worker_threads(1);
316        builder
317            .thread_name("rw-acceptor")
318            .enable_all()
319            .build()
320            .unwrap()
321    });
322
323    #[cfg(not(madsim))]
324    let worker_runtime = tokio::runtime::Handle::current();
325    #[cfg(madsim)]
326    let worker_runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();
327    let session_mgr_clone = session_mgr.clone();
328    let f = async move {
329        loop {
330            let conn_ret = listener.accept(&tcp_keepalive).await;
331            match conn_ret {
332                Ok((stream, peer_addr)) => {
333                    tracing::info!(%peer_addr, "accept connection");
334                    worker_runtime.spawn(handle_connection(
335                        stream,
336                        session_mgr_clone.clone(),
337                        Arc::new(peer_addr),
338                        context.clone(),
339                    ));
340                }
341
342                Err(e) => {
343                    tracing::error!(error = %e.as_report(), "failed to accept connection",);
344                }
345            }
346        }
347    };
348    acceptor_runtime.spawn(f);
349
350    // Wait for the shutdown signal.
351    shutdown.cancelled().await;
352
353    // Stop accepting new connections.
354    drop(acceptor_runtime);
355    // Shutdown session manager, typically close all existing sessions.
356    session_mgr.shutdown().await;
357
358    Ok(())
359}
360
361pub async fn handle_connection<S, SM>(
362    stream: S,
363    session_mgr: Arc<SM>,
364    peer_addr: AddressRef,
365    context: ConnectionContext,
366) where
367    S: PgByteStream,
368    SM: SessionManager,
369{
370    PgProtocol::new(stream, session_mgr, peer_addr, context)
371        .run()
372        .await;
373}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Instant;

    use bytes::Bytes;
    use futures::StreamExt;
    use futures::stream::BoxStream;
    use risingwave_common::id::DatabaseId;
    use risingwave_common::types::DataType;
    use risingwave_common::util::tokio_util::sync::CancellationToken;
    use risingwave_sqlparser::ast::Statement;
    use tokio_postgres::NoTls;

    use crate::error::PsqlResult;
    use crate::memory_manager::MessageMemoryManager;
    use crate::pg_field_descriptor::PgFieldDescriptor;
    use crate::pg_message::TransactionStatus;
    use crate::pg_protocol::ConnectionContext;
    use crate::pg_response::{PgResponse, RowSetResult, StatementType};
    use crate::pg_server::{
        BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
        UserAuthenticator, pg_serve,
    };
    use crate::types;
    use crate::types::Row;

    /// Session manager whose sessions answer every query with one empty varchar row.
    struct MockSessionManager {}
    /// Session handed out by [`MockSessionManager`].
    struct MockSession {}

    /// The canned response shared by the simple and extended query paths of the mock:
    /// a single row holding one empty varchar column.
    fn mock_select_response() -> PgResponse<BoxStream<'static, RowSetResult>> {
        PgResponse::builder(StatementType::SELECT)
            .values(
                futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
                    .boxed(),
                vec![
                    // 1043 is the oid of varchar type.
                    // -1 is the type len of varchar type.
                    PgFieldDescriptor::new("".to_owned(), 1043, -1);
                    1
                ],
            )
            .into()
    }

    impl SessionManager for MockSessionManager {
        type Error = BoxedError;
        type Session = MockSession;

        fn create_dummy_session(
            &self,
            _database_id: DatabaseId,
            _user_id: u32,
        ) -> Result<Arc<Self::Session>, Self::Error> {
            unimplemented!()
        }

        fn connect(
            &self,
            _database: &str,
            _user_name: &str,
            _peer_addr: crate::net::AddressRef,
        ) -> Result<Arc<Self::Session>, Self::Error> {
            Ok(Arc::new(MockSession {}))
        }

        fn cancel_queries_in_session(&self, _session_id: SessionId) {
            todo!()
        }

        fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
            todo!()
        }

        fn end_session(&self, _session: &Self::Session) {}
    }

    impl Session for MockSession {
        type Error = BoxedError;
        type Portal = String;
        type PreparedStatement = String;
        type ValuesStream = BoxStream<'static, RowSetResult>;

        async fn run_one_query(
            self: Arc<Self>,
            _stmt: Statement,
            _format: types::Format,
        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, Self::Error> {
            Ok(mock_select_response())
        }

        async fn parse(
            self: Arc<Self>,
            _sql: Option<Statement>,
            _params_types: Vec<Option<DataType>>,
        ) -> Result<String, Self::Error> {
            Ok(String::new())
        }

        fn bind(
            self: Arc<Self>,
            _prepare_statement: String,
            _params: Vec<Option<Bytes>>,
            _param_formats: Vec<types::Format>,
            _result_formats: Vec<types::Format>,
        ) -> Result<String, Self::Error> {
            Ok(String::new())
        }

        async fn execute(
            self: Arc<Self>,
            _portal: String,
        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, Self::Error> {
            Ok(mock_select_response())
        }

        fn describe_statement(
            self: Arc<Self>,
            _statement: String,
        ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), Self::Error> {
            Ok((
                vec![],
                vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)],
            ))
        }

        fn describe_portal(
            self: Arc<Self>,
            _portal: String,
        ) -> Result<Vec<PgFieldDescriptor>, Self::Error> {
            Ok(vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)])
        }

        fn user_authenticator(&self) -> &UserAuthenticator {
            &UserAuthenticator::None
        }

        fn id(&self) -> SessionId {
            (0, 0)
        }

        fn get_config(&self, key: &str) -> Result<String, Self::Error> {
            match key {
                "timezone" => Ok("UTC".to_owned()),
                _ => Err(format!("Unknown config key: {key}").into()),
            }
        }

        fn set_config(&self, _key: &str, _value: String) -> Result<String, Self::Error> {
            Ok("".to_owned())
        }

        async fn next_notice(self: &Arc<Self>) -> String {
            std::future::pending().await
        }

        fn transaction_status(&self) -> TransactionStatus {
            TransactionStatus::Idle
        }

        fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
            let exec_context = Arc::new(ExecContext {
                running_sql: sql,
                last_instant: Instant::now(),
                last_idle_instant: Default::default(),
            });
            ExecContextGuard::new(exec_context)
        }

        fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
            Ok(())
        }
    }

    /// Starts a mock server on `bind_addr`, connects with `pg_config`, and runs one query
    /// over the simple protocol and one over the extended protocol.
    async fn do_test_query(bind_addr: impl Into<String>, pg_config: impl Into<String>) {
        let bind_addr = bind_addr.into();
        let pg_config = pg_config.into();

        let session_mgr = MockSessionManager {};
        tokio::spawn(async move {
            pg_serve(
                &bind_addr,
                socket2::TcpKeepalive::new(),
                Arc::new(session_mgr),
                ConnectionContext {
                    tls_config: None,
                    redact_sql_option_keywords: None,
                    message_memory_manager: MessageMemoryManager::new(u64::MAX, u64::MAX, u64::MAX)
                        .into(),
                },
                CancellationToken::new(), // dummy
            )
            .await
        });
        // wait for server to start
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        // Connect to the database.
        let (client, connection) = tokio_postgres::connect(&pg_config, NoTls).await.unwrap();

        // The connection object performs the actual communication with the database,
        // so spawn it off to run on its own.
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                eprintln!("connection error: {}", e);
            }
        });

        let rows = client
            .simple_query("SELECT ''")
            .await
            .expect("Error executing query");
        // Row + CommandComplete
        assert_eq!(rows.len(), 2);

        let rows = client
            .query("SELECT ''", &[])
            .await
            .expect("Error executing query");
        assert_eq!(rows.len(), 1);
    }

    #[tokio::test]
    async fn test_query_tcp() {
        do_test_query("127.0.0.1:10000", "host=localhost port=10000").await;
    }

    #[cfg(not(madsim))]
    #[tokio::test]
    async fn test_query_unix() {
        let port: i16 = 10000;
        let dir = tempfile::TempDir::new().unwrap();
        let sock = dir.path().join(format!(".s.PGSQL.{port}"));

        do_test_query(
            format!("unix:{}", sock.to_str().unwrap()),
            format!("host={} port={}", dir.path().to_str().unwrap(), port),
        )
        .await;
    }

    mod jwt_validation_tests {
        use std::collections::HashMap;
        use std::time::{SystemTime, UNIX_EPOCH};

        use base64::Engine;
        use jsonwebtoken::{Algorithm, EncodingKey, Header};
        use rsa::pkcs1::EncodeRsaPrivateKey;
        use rsa::traits::PublicKeyParts;
        use rsa::{RsaPrivateKey, RsaPublicKey};
        use serde_json::json;

        use crate::pg_server::{Jwk, Jwks, validate_jwt_with_jwks};

        /// Generates a fresh 2048-bit RSA key pair.
        fn create_test_rsa_keys() -> (RsaPrivateKey, RsaPublicKey) {
            let mut rng = rand::thread_rng();
            let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("failed to generate a key");
            let public_key = RsaPublicKey::from(&private_key);
            (private_key, public_key)
        }

        /// Builds a single-key JWK Set exposing `public_key` under `kid` with algorithm `alg`.
        fn create_test_jwks(public_key: &RsaPublicKey, kid: &str, alg: &str) -> Jwks {
            let n = base64::engine::general_purpose::URL_SAFE_NO_PAD
                .encode(public_key.n().to_bytes_be());
            let e = base64::engine::general_purpose::URL_SAFE_NO_PAD
                .encode(public_key.e().to_bytes_be());

            Jwks {
                keys: vec![Jwk {
                    kid: kid.to_owned(),
                    alg: alg.to_owned(),
                    n,
                    e,
                }],
            }
        }

        /// Signs a token with `private_key` carrying `iss`, `exp`, an optional `aud`, and
        /// `additional_claims`.
        fn create_jwt_token(
            private_key: &RsaPrivateKey,
            kid: &str,
            algorithm: Algorithm,
            issuer: &str,
            audience: Option<&str>,
            exp: u64,
            additional_claims: HashMap<String, serde_json::Value>,
        ) -> String {
            let mut header = Header::new(algorithm);
            header.kid = Some(kid.to_owned());

            let mut claims = json!({
                "iss": issuer,
                "exp": exp,
            });

            if let Some(aud) = audience {
                claims["aud"] = json!(aud);
            }

            for (key, value) in additional_claims {
                claims[key] = value;
            }

            let encoding_key = EncodingKey::from_rsa_pem(
                private_key
                    .to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
                    .unwrap()
                    .as_bytes(),
            )
            .unwrap();

            jsonwebtoken::encode(&header, &claims, &encoding_key).unwrap()
        }

        fn get_future_timestamp() -> u64 {
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs()
                + 3600 // 1 hour from now
        }

        fn get_past_timestamp() -> u64 {
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs()
                - 3600 // 1 hour ago
        }

        #[test]
        fn test_jwt_with_invalid_audience() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let metadata = HashMap::new();

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:wrong-cluster-id"),
                get_future_timestamp(),
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(error.to_string().contains("InvalidAudience"));
        }

        #[test]
        fn test_jwt_with_missing_audience() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let metadata = HashMap::new();

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                None, // No audience claim
                get_future_timestamp(),
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(error.to_string().contains("Missing required claim: aud"));
        }

        #[test]
        fn test_jwt_with_invalid_issuer() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let metadata = HashMap::new();

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://wrong-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_future_timestamp(),
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(error.to_string().contains("InvalidIssuer"));
        }

        #[test]
        fn test_jwt_with_kid_not_found_in_jwks() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "different-kid", "RS256");

            let metadata = HashMap::new();

            let jwt = create_jwt_token(
                &private_key,
                "missing-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_future_timestamp(),
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(
                error
                    .to_string()
                    .contains("No matching key found in JWKS for kid: 'missing-kid'")
            );
        }

        #[test]
        fn test_jwt_with_expired_token() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let metadata = HashMap::new();

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_past_timestamp(), // Expired token
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(error.to_string().contains("ExpiredSignature"));
        }

        #[test]
        fn test_jwt_with_invalid_signature() {
            let (_, public_key) = create_test_rsa_keys();
            let (wrong_private_key, _) = create_test_rsa_keys(); // Different key pair
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let metadata = HashMap::new();

            // Sign with wrong private key
            let jwt = create_jwt_token(
                &wrong_private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_future_timestamp(),
                HashMap::new(),
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert!(error.to_string().contains("InvalidSignature"));
        }

        #[test]
        fn test_metadata_validation_success() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let mut metadata = HashMap::new();
            metadata.insert("role".to_owned(), "admin".to_owned());
            metadata.insert("department".to_owned(), "security".to_owned());

            let mut claims = HashMap::new();
            claims.insert("role".to_owned(), json!("admin"));
            claims.insert("department".to_owned(), json!("security"));
            claims.insert("extra_claim".to_owned(), json!("ignored")); // Extra claims are fine

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_future_timestamp(),
                claims,
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            assert!(result.unwrap());
        }

        #[test]
        fn test_metadata_validation_failure() {
            let (private_key, public_key) = create_test_rsa_keys();
            let jwks = create_test_jwks(&public_key, "test-kid", "RS256");

            let mut metadata = HashMap::new();
            metadata.insert("role".to_owned(), "admin".to_owned());
            metadata.insert("department".to_owned(), "security".to_owned());

            let mut claims = HashMap::new();
            claims.insert("role".to_owned(), json!("user")); // Wrong role
            claims.insert("department".to_owned(), json!("security"));

            let jwt = create_jwt_token(
                &private_key,
                "test-kid",
                Algorithm::RS256,
                "https://test-issuer.com",
                Some("urn:risingwave:cluster:test-cluster-id"),
                get_future_timestamp(),
                claims,
            );

            let result = validate_jwt_with_jwks(
                &jwt,
                &jwks,
                "https://test-issuer.com",
                "test-cluster-id",
                &metadata,
            );

            let error = result.unwrap_err();
            assert_eq!(
                error.to_string(),
                "metadata in jwt does not match with metadata declared with user"
            );
        }
    }
}