pgwire/
pg_server.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::future::Future;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::time::Instant;
20
21use bytes::Bytes;
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
23use parking_lot::Mutex;
24use risingwave_common::config::HbaEntry;
25use risingwave_common::id::DatabaseId;
26use risingwave_common::types::DataType;
27use risingwave_common::util::runtime::BackgroundShutdownRuntime;
28use risingwave_common::util::tokio_util::sync::CancellationToken;
29use risingwave_sqlparser::ast::Statement;
30use serde::Deserialize;
31use thiserror_ext::AsReport;
32
33use crate::error::{PsqlError, PsqlResult};
34use crate::ldap_auth::LdapAuthenticator;
35use crate::net::{AddressRef, Listener, TcpKeepalive};
36use crate::pg_field_descriptor::PgFieldDescriptor;
37use crate::pg_message::TransactionStatus;
38use crate::pg_protocol::{ConnectionContext, PgByteStream, PgProtocol};
39use crate::pg_response::{PgResponse, ValuesStream};
40use crate::types::Format;
41
/// Boxed, thread-safe error type used by this module wherever the concrete error type does not
/// matter to the caller.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;
/// Process-id half of a [`SessionId`].
type ProcessId = i32;
/// Secret-key half of a [`SessionId`].
type SecretKey = i32;
/// Identifies a session as a `(process id, secret key)` pair; this is the key passed to
/// `SessionManager::cancel_queries_in_session` and returned by `Session::id`.
pub type SessionId = (ProcessId, SecretKey);
46
/// The interface for a database system behind pgwire protocol.
/// We can mock it for testing purpose.
///
/// `pg_serve` hands a shared manager to every accepted connection (via `handle_connection`)
/// and awaits [`SessionManager::shutdown`] once it stops accepting connections.
pub trait SessionManager: Send + Sync + 'static {
    /// Error type returned by fallible session-management operations.
    type Error: Into<BoxedError>;
    /// The session type created by this manager; it shares the manager's error type.
    type Session: Session<Error = Self::Error>;

    /// In the process of auto schema change, we need a dummy session to access
    /// catalog information in frontend and build a replace plan for the table.
    fn create_dummy_session(
        &self,
        database_id: DatabaseId,
    ) -> Result<Arc<Self::Session>, Self::Error>;

    /// Create a session bound to `database` for `user_name`, connecting from `peer_addr`.
    fn connect(
        &self,
        database: &str,
        user_name: &str,
        peer_addr: AddressRef,
    ) -> Result<Arc<Self::Session>, Self::Error>;

    /// Cancel the queries running in the session identified by `session_id`.
    fn cancel_queries_in_session(&self, session_id: SessionId);

    /// Cancel the streaming-job creations started by the session identified by `session_id`
    /// (exact semantics are up to the implementation).
    fn cancel_creating_jobs_in_session(&self, session_id: SessionId);

    /// Notify the manager that `session` has ended.
    fn end_session(&self, session: &Self::Session);

    /// Run some cleanup tasks before the server shutdown.
    ///
    /// The default implementation does nothing.
    fn shutdown(&self) -> impl Future<Output = ()> + Send {
        async {}
    }
}
78
/// A psql connection. Each connection binds with a database. Switching database will need to
/// recreate another connection.
pub trait Session: Send + Sync {
    /// Error type returned by fallible session operations.
    type Error: Into<BoxedError>;
    /// Stream of result rows carried inside the [`PgResponse`]s this session returns.
    type ValuesStream: ValuesStream;
    /// Prepared-statement handle: returned by [`Session::parse`] and passed to
    /// [`Session::bind`] and [`Session::describe_statement`].
    type PreparedStatement: Send + Clone + 'static;
    /// Portal handle: returned by [`Session::bind`] and passed to [`Session::execute`] and
    /// [`Session::describe_portal`].
    type Portal: Send + Clone + std::fmt::Display + 'static;

    /// Execute one already-parsed statement and return its response. `format` presumably
    /// selects the encoding of result values — confirm against implementations.
    ///
    /// The str sql can not use the unparse from AST: There is some problem when dealing with create
    /// view, see <https://github.com/risingwavelabs/risingwave/issues/6801>.
    ///
    /// NOTE(review): this method receives only the parsed `Statement`, so the note above about
    /// the raw SQL string may be stale — confirm.
    fn run_one_query(
        self: Arc<Self>,
        stmt: Statement,
        format: Format,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, Self::Error>> + Send;

    /// Prepare `sql` with the client-supplied parameter types and return a prepared-statement
    /// handle.
    ///
    /// NOTE(review): presumably `sql == None` denotes an empty query and `None` entries in
    /// `params_types` denote unspecified types — confirm.
    fn parse(
        self: Arc<Self>,
        sql: Option<Statement>,
        params_types: Vec<Option<DataType>>,
    ) -> impl Future<Output = Result<Self::PreparedStatement, Self::Error>> + Send;

    /// Receive the next notice message to send to the client.
    ///
    /// This function should be cancellation-safe.
    fn next_notice(self: &Arc<Self>) -> impl Future<Output = String> + Send;

    /// Bind parameter values to a prepared statement, producing a portal.
    ///
    /// `param_formats` and `result_formats` give the wire formats of the parameters and of the
    /// result columns respectively.
    fn bind(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
        params: Vec<Option<Bytes>>,
        param_formats: Vec<Format>,
        result_formats: Vec<Format>,
    ) -> Result<Self::Portal, Self::Error>;

    /// Execute a bound portal and return its response.
    fn execute(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, Self::Error>> + Send;

    /// Describe a prepared statement. Presumably returns `(parameter types, result columns)`
    /// — confirm against implementations.
    fn describe_statement(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
    ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), Self::Error>;

    /// Describe the result columns of a bound portal.
    fn describe_portal(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> Result<Vec<PgFieldDescriptor>, Self::Error>;

    /// The authenticator used to verify the credentials supplied for this session's user.
    fn user_authenticator(&self) -> &UserAuthenticator;

    /// The `(process id, secret key)` identifier of this session.
    fn id(&self) -> SessionId;

    /// Read the session configuration variable `key`.
    fn get_config(&self, key: &str) -> Result<String, Self::Error>;

    /// Set the session configuration variable `key` to `value`. The meaning of the returned
    /// string is implementation-defined (NOTE(review): confirm what callers expect).
    fn set_config(&self, key: &str, value: String) -> Result<String, Self::Error>;

    /// The current transaction status of this session.
    fn transaction_status(&self) -> TransactionStatus;

    /// Start tracking the execution of `sql`; the returned guard keeps the [`ExecContext`]
    /// alive until it is dropped.
    fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard;

    /// Presumably returns an error when the session has stayed idle inside a transaction for
    /// longer than allowed — confirm against implementations.
    fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()>;

    /// Name of the user this session belongs to.
    fn user(&self) -> String;
}
145
/// Each session could run different SQLs multiple times.
/// `ExecContext` represents the lifetime of a running SQL in the current session.
///
/// Dropping an `ExecContext` stores the current time into `last_idle_instant`.
pub struct ExecContext {
    /// The SQL text being executed.
    pub running_sql: Arc<str>,
    /// The instant of the running sql (presumably its start time — the test mock sets it to
    /// `Instant::now()` at creation).
    pub last_instant: Instant,
    /// A reference used to update when `ExecContext` is dropped
    pub last_idle_instant: Arc<Mutex<Option<Instant>>>,
}
155
/// `ExecContextGuard` holds an `Arc` pointer. Once `ExecContextGuard` is dropped,
/// the inner `Arc<ExecContext>` should not be referred to anymore, so that its `Weak` reference
/// (used in `SessionImpl`) shares the lifecycle of the running SQL's execution context.
///
/// The field is never read: the guard exists only to keep the `ExecContext` alive. When the
/// last strong reference is released, the context is dropped and records the idle instant
/// (see the `Drop` impl for `ExecContext`).
pub struct ExecContextGuard(#[expect(dead_code)] Arc<ExecContext>);

impl ExecContextGuard {
    /// Wrap `exec_context`, keeping it alive for as long as the guard lives.
    pub fn new(exec_context: Arc<ExecContext>) -> Self {
        Self(exec_context)
    }
}
165
166impl Drop for ExecContext {
167    fn drop(&mut self) {
168        *self.last_idle_instant.lock() = Some(Instant::now());
169    }
170}
171
/// How the credentials supplied by a client are verified (see `UserAuthenticator::authenticate`).
#[derive(Debug, Clone)]
pub enum UserAuthenticator {
    /// No need to authenticate.
    None,
    /// Raw password in clear-text form; the supplied password must equal it byte-for-byte.
    ClearText(Vec<u8>),
    /// Password encrypted with random salt; the supplied password must equal
    /// `encrypted_password` byte-for-byte.
    Md5WithSalt {
        encrypted_password: Vec<u8>,
        /// Presumably the salt sent to the client for the MD5 challenge — not read by
        /// `authenticate`.
        salt: [u8; 4],
    },
    /// JWT-based authentication. `metadata` must contain `jwks_url` and `issuer`; every other
    /// entry must match a string claim in the token. The token's `aud` must be derived from
    /// `cluster_id`.
    OAuth {
        metadata: HashMap<String, String>,
        cluster_id: String,
    },
    /// LDAP authentication of the given user name, configured by the HBA entry.
    Ldap(String, HbaEntry),
}
189
/// A JWK Set is a JSON object that represents a set of JWKs.
/// The JSON object MUST have a "keys" member, with its value being an array of JWKs.
/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-5> for more details.
///
/// Fetched from the user's `jwks_url` by `validate_jwt`. The first key whose `kid` equals
/// the token header's `kid` is used for verification.
#[derive(Debug, Deserialize)]
struct Jwks {
    /// The keys of the set, in document order.
    keys: Vec<Jwk>,
}
197
198/// A JSON Web Key (JWK) is a JSON object that represents a cryptographic key.
199/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-4> for more details.
200#[derive(Debug, Deserialize)]
201struct Jwk {
202    kid: String,         // Key ID
203    alg: Option<String>, // Algorithm (OPTIONAL per RFC 7517 section 4.4)
204    n: String,           // Modulus
205    e: String,           // Exponent
206}
207
/// Algorithms we accept for JWT signature verification.
///
/// Restricted to RSA-family algorithms because the only `DecodingKey` we build
/// is from RSA components (`n`, `e`). Pinning the algorithm to a server-side
/// allow-list also prevents the classic alg-confusion attack: a token with
/// `alg: "none"` (no signature) or `alg: "HS256"` forged using the RSA public
/// key as the HMAC secret cannot select a verification algorithm outside this
/// set.
///
/// Enforced by `validate_jwt_with_jwks` after it reconciles the JWK's `alg` with the token
/// header's `alg`, and before any signature verification.
const ALLOWED_JWT_ALGORITHMS: &[Algorithm] = &[
    Algorithm::RS256,
    Algorithm::RS384,
    Algorithm::RS512,
    Algorithm::PS256,
    Algorithm::PS384,
    Algorithm::PS512,
];
224
225async fn validate_jwt(
226    jwt: &str,
227    jwks_url: &str,
228    issuer: &str,
229    cluster_id: &str,
230    metadata: &HashMap<String, String>,
231) -> Result<bool, BoxedError> {
232    let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;
233    validate_jwt_with_jwks(jwt, &jwks, issuer, cluster_id, metadata)
234}
235
/// Build the JWT `aud` value a token must carry to be accepted by the cluster `cluster_id`.
fn audience_from_cluster_id(cluster_id: &str) -> String {
    let mut audience = String::from("urn:risingwave:cluster:");
    audience.push_str(cluster_id);
    audience
}
239
240fn validate_jwt_with_jwks(
241    jwt: &str,
242    jwks: &Jwks,
243    issuer: &str,
244    cluster_id: &str,
245    metadata: &HashMap<String, String>,
246) -> Result<bool, BoxedError> {
247    let header = decode_header(jwt)?;
248
249    // 1. Retrieve the kid from the header to find the right JWK in the JWK Set.
250    let kid = header.kid.ok_or("JWT header missing 'kid' field")?;
251    let jwk = jwks
252        .keys
253        .iter()
254        .find(|k| k.kid == kid)
255        .ok_or(format!("No matching key found in JWKS for kid: '{}'", kid))?;
256
257    // 2. Decide which algorithm to use.
258    //
259    // Per RFC 7517 §4.4 the JWK `alg` member is OPTIONAL. When the JWK pins an
260    // `alg`, the JWT header MUST match it; when it doesn't, we fall back to
261    // the header's `alg` but only after checking it against a server-side
262    // allow-list. The allow-list is what ultimately blocks alg-confusion: an
263    // attacker-chosen `alg` from the token header alone must never be trusted
264    // to select the verification algorithm.
265    let alg = match jwk.alg.as_deref() {
266        Some(jwk_alg) => {
267            let jwk_alg = Algorithm::from_str(jwk_alg)?;
268            if jwk_alg != header.alg {
269                return Err("alg in jwt header does not match with alg in jwk".into());
270            }
271            jwk_alg
272        }
273        None => header.alg,
274    };
275    if !ALLOWED_JWT_ALGORITHMS.contains(&alg) {
276        return Err(format!("JWT alg {:?} is not allowed", alg).into());
277    }
278
279    // 3. Decode the JWT and validate the claims.
280    let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
281    let mut validation = Validation::new(alg);
282    validation.set_issuer(&[issuer]);
283    validation.set_audience(&[audience_from_cluster_id(cluster_id)]); // JWT 'aud' claim must match cluster_id
284    validation.set_required_spec_claims(&["exp", "iss", "aud"]);
285    let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;
286
287    // 4. Check if the metadata in the token matches.
288    if !metadata.iter().all(
289        |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
290    ) {
291        return Err("metadata in jwt does not match with metadata declared with user".into());
292    }
293    Ok(true)
294}
295
296impl UserAuthenticator {
297    pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> {
298        let success = match self {
299            UserAuthenticator::None => true,
300            UserAuthenticator::ClearText(text) => password == text,
301            UserAuthenticator::Md5WithSalt {
302                encrypted_password, ..
303            } => encrypted_password == password,
304            UserAuthenticator::OAuth {
305                metadata,
306                cluster_id,
307            } => {
308                let mut metadata = metadata.clone();
309                let jwks_url = metadata.remove("jwks_url").unwrap();
310                let issuer = metadata.remove("issuer").unwrap();
311                validate_jwt(
312                    &String::from_utf8_lossy(password),
313                    &jwks_url,
314                    &issuer,
315                    cluster_id,
316                    &metadata,
317                )
318                .await
319                .map_err(PsqlError::StartupError)?
320            }
321            UserAuthenticator::Ldap(user_name, hba_entry) => {
322                let ldap_auth = LdapAuthenticator::new(hba_entry)?;
323                // Convert password to string, defaulting to empty if not valid UTF-8
324                let password_str = String::from_utf8_lossy(password).into_owned();
325                ldap_auth.authenticate(user_name, &password_str).await?
326            }
327        };
328        if !success {
329            return Err(PsqlError::PasswordError);
330        }
331        Ok(())
332    }
333}
334
335/// Binds a Tcp or Unix listener at `addr`. Spawn a coroutine to serve every new connection.
336///
337/// Returns when the `shutdown` token is triggered.
338pub async fn pg_serve(
339    addr: &str,
340    tcp_keepalive: TcpKeepalive,
341    session_mgr: Arc<impl SessionManager>,
342    context: ConnectionContext,
343    shutdown: CancellationToken,
344) -> Result<(), BoxedError> {
345    let listener = Listener::bind(addr).await?;
346    tracing::info!(addr, "server started");
347
348    let acceptor_runtime = BackgroundShutdownRuntime::from({
349        let mut builder = tokio::runtime::Builder::new_multi_thread();
350        builder.worker_threads(1);
351        builder
352            .thread_name("rw-acceptor")
353            .enable_all()
354            .build()
355            .unwrap()
356    });
357
358    #[cfg(not(madsim))]
359    let worker_runtime = tokio::runtime::Handle::current();
360    #[cfg(madsim)]
361    let worker_runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();
362    let session_mgr_clone = session_mgr.clone();
363    let f = async move {
364        loop {
365            let conn_ret = listener.accept(&tcp_keepalive).await;
366            match conn_ret {
367                Ok((stream, peer_addr)) => {
368                    tracing::info!(%peer_addr, "accept connection");
369                    worker_runtime.spawn(handle_connection(
370                        stream,
371                        session_mgr_clone.clone(),
372                        Arc::new(peer_addr),
373                        context.clone(),
374                    ));
375                }
376
377                Err(e) => {
378                    tracing::error!(error = %e.as_report(), "failed to accept connection",);
379                }
380            }
381        }
382    };
383    acceptor_runtime.spawn(f);
384
385    // Wait for the shutdown signal.
386    shutdown.cancelled().await;
387
388    // Stop accepting new connections.
389    drop(acceptor_runtime);
390    // Shutdown session manager, typically close all existing sessions.
391    session_mgr.shutdown().await;
392
393    Ok(())
394}
395
/// Serve a single client connection over `stream` until its protocol loop completes.
///
/// Thin wrapper that builds a [`PgProtocol`] for the connection and awaits its `run()`.
/// `pg_serve` spawns one of these per accepted connection.
pub async fn handle_connection<S, SM>(
    stream: S,
    session_mgr: Arc<SM>,
    peer_addr: AddressRef,
    context: ConnectionContext,
) where
    S: PgByteStream,
    SM: SessionManager,
{
    PgProtocol::new(stream, session_mgr, peer_addr, context)
        .run()
        .await;
}
409#[cfg(test)]
410mod tests {
411    use std::sync::Arc;
412    use std::time::Instant;
413
414    use bytes::Bytes;
415    use futures::StreamExt;
416    use futures::stream::BoxStream;
417    use risingwave_common::id::DatabaseId;
418    use risingwave_common::types::DataType;
419    use risingwave_common::util::tokio_util::sync::CancellationToken;
420    use risingwave_sqlparser::ast::Statement;
421    use tokio_postgres::NoTls;
422
423    use crate::error::PsqlResult;
424    use crate::memory_manager::MessageMemoryManager;
425    use crate::pg_field_descriptor::PgFieldDescriptor;
426    use crate::pg_message::TransactionStatus;
427    use crate::pg_protocol::ConnectionContext;
428    use crate::pg_response::{PgResponse, RowSetResult, StatementType};
429    use crate::pg_server::{
430        BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
431        UserAuthenticator, pg_serve,
432    };
433    use crate::types;
434    use crate::types::Row;
435
    /// Session manager for the pgwire tests: every `connect` yields a fresh [`MockSession`].
    struct MockSessionManager {}
    /// Session that answers queries with one row holding a single empty value.
    struct MockSession {}

    impl SessionManager for MockSessionManager {
        type Error = BoxedError;
        type Session = MockSession;

        // Panics if called.
        fn create_dummy_session(
            &self,
            _database_id: DatabaseId,
        ) -> Result<Arc<Self::Session>, Self::Error> {
            unimplemented!()
        }

        // Ignores all arguments and always succeeds.
        fn connect(
            &self,
            _database: &str,
            _user_name: &str,
            _peer_addr: crate::net::AddressRef,
        ) -> Result<Arc<Self::Session>, Self::Error> {
            Ok(Arc::new(MockSession {}))
        }

        // Cancellation is not implemented by the mock; calling it panics via `todo!()`.
        fn cancel_queries_in_session(&self, _session_id: SessionId) {
            todo!()
        }

        fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
            todo!()
        }

        // Nothing to clean up for the mock.
        fn end_session(&self, _session: &Self::Session) {}
    }
469
470    impl Session for MockSession {
471        type Error = BoxedError;
472        type Portal = String;
473        type PreparedStatement = String;
474        type ValuesStream = BoxStream<'static, RowSetResult>;
475
476        async fn run_one_query(
477            self: Arc<Self>,
478            _stmt: Statement,
479            _format: types::Format,
480        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, Self::Error> {
481            Ok(PgResponse::builder(StatementType::SELECT)
482                .values(
483                    futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
484                        .boxed(),
485                    vec![
486                        // 1043 is the oid of varchar type.
487                        // -1 is the type len of varchar type.
488                        PgFieldDescriptor::new("".to_owned(), 1043, -1);
489                        1
490                    ],
491                )
492                .into())
493        }
494
495        async fn parse(
496            self: Arc<Self>,
497            _sql: Option<Statement>,
498            _params_types: Vec<Option<DataType>>,
499        ) -> Result<String, Self::Error> {
500            Ok(String::new())
501        }
502
503        fn bind(
504            self: Arc<Self>,
505            _prepare_statement: String,
506            _params: Vec<Option<Bytes>>,
507            _param_formats: Vec<types::Format>,
508            _result_formats: Vec<types::Format>,
509        ) -> Result<String, Self::Error> {
510            Ok(String::new())
511        }
512
513        async fn execute(
514            self: Arc<Self>,
515            _portal: String,
516        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, Self::Error> {
517            Ok(PgResponse::builder(StatementType::SELECT)
518                .values(
519                    futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
520                        .boxed(),
521                    vec![
522                    // 1043 is the oid of varchar type.
523                    // -1 is the type len of varchar type.
524                    PgFieldDescriptor::new("".to_owned(), 1043, -1);
525                    1
526                ],
527                )
528                .into())
529        }
530
531        fn describe_statement(
532            self: Arc<Self>,
533            _statement: String,
534        ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), Self::Error> {
535            Ok((
536                vec![],
537                vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)],
538            ))
539        }
540
541        fn describe_portal(
542            self: Arc<Self>,
543            _portal: String,
544        ) -> Result<Vec<PgFieldDescriptor>, Self::Error> {
545            Ok(vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)])
546        }
547
548        fn user_authenticator(&self) -> &UserAuthenticator {
549            &UserAuthenticator::None
550        }
551
552        fn id(&self) -> SessionId {
553            (0, 0)
554        }
555
556        fn get_config(&self, key: &str) -> Result<String, Self::Error> {
557            match key {
558                "timezone" => Ok("UTC".to_owned()),
559                _ => Err(format!("Unknown config key: {key}").into()),
560            }
561        }
562
563        fn set_config(&self, _key: &str, _value: String) -> Result<String, Self::Error> {
564            Ok("".to_owned())
565        }
566
567        async fn next_notice(self: &Arc<Self>) -> String {
568            std::future::pending().await
569        }
570
571        fn transaction_status(&self) -> TransactionStatus {
572            TransactionStatus::Idle
573        }
574
575        fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
576            let exec_context = Arc::new(ExecContext {
577                running_sql: sql,
578                last_instant: Instant::now(),
579                last_idle_instant: Default::default(),
580            });
581            ExecContextGuard::new(exec_context)
582        }
583
584        fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
585            Ok(())
586        }
587
588        fn user(&self) -> String {
589            "mock".to_owned()
590        }
591    }
592
    /// Start a pgwire server backed by [`MockSessionManager`] at `bind_addr`, connect to it with
    /// `tokio_postgres` using `pg_config`, and check that both `simple_query` and `query` return
    /// the mock's single row.
    ///
    /// NOTE(review): the server's `CancellationToken` is never triggered, and readiness is
    /// approximated by a fixed 100ms sleep, which may be flaky on a slow machine.
    async fn do_test_query(bind_addr: impl Into<String>, pg_config: impl Into<String>) {
        let bind_addr = bind_addr.into();
        let pg_config = pg_config.into();

        let session_mgr = MockSessionManager {};
        tokio::spawn(async move {
            pg_serve(
                &bind_addr,
                socket2::TcpKeepalive::new(),
                Arc::new(session_mgr),
                ConnectionContext {
                    tls_config: None,
                    redact_sql_option_keywords: None,
                    // Effectively unlimited memory budget for the test.
                    message_memory_manager: MessageMemoryManager::new(u64::MAX, u64::MAX, u64::MAX)
                        .into(),
                },
                CancellationToken::new(), // dummy
            )
            .await
        });
        // wait for server to start
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        // Connect to the database.
        let (client, connection) = tokio_postgres::connect(&pg_config, NoTls).await.unwrap();

        // The connection object performs the actual communication with the database,
        // so spawn it off to run on its own.
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                eprintln!("connection error: {}", e);
            }
        });

        let rows = client
            .simple_query("SELECT ''")
            .await
            .expect("Error executing query");
        // Row + CommandComplete
        assert_eq!(rows.len(), 2);

        let rows = client
            .query("SELECT ''", &[])
            .await
            .expect("Error executing query");
        assert_eq!(rows.len(), 1);
    }
640
641    #[tokio::test]
642    async fn test_query_tcp() {
643        do_test_query("127.0.0.1:10000", "host=localhost port=10000").await;
644    }
645
646    #[cfg(not(madsim))]
647    #[tokio::test]
648    async fn test_query_unix() {
649        let port: i16 = 10000;
650        let dir = tempfile::TempDir::new().unwrap();
651        let sock = dir.path().join(format!(".s.PGSQL.{port}"));
652
653        do_test_query(
654            format!("unix:{}", sock.to_str().unwrap()),
655            format!("host={} port={}", dir.path().to_str().unwrap(), port),
656        )
657        .await;
658    }
659
660    mod jwt_validation_tests {
661        use std::collections::HashMap;
662        use std::time::{SystemTime, UNIX_EPOCH};
663
664        use base64::Engine;
665        use jsonwebtoken::{Algorithm, EncodingKey, Header};
666        use rsa::pkcs1::EncodeRsaPrivateKey;
667        use rsa::traits::PublicKeyParts;
668        use rsa::{RsaPrivateKey, RsaPublicKey};
669        use serde_json::json;
670
671        use crate::pg_server::{Jwk, Jwks, validate_jwt_with_jwks};
672
673        fn create_test_rsa_keys() -> (RsaPrivateKey, RsaPublicKey) {
674            let mut rng = rand::thread_rng();
675            let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("failed to generate a key");
676            let public_key = RsaPublicKey::from(&private_key);
677            (private_key, public_key)
678        }
679
680        fn create_test_jwks(public_key: &RsaPublicKey, kid: &str, alg: Option<&str>) -> Jwks {
681            let n = base64::engine::general_purpose::URL_SAFE_NO_PAD
682                .encode(public_key.n().to_bytes_be());
683            let e = base64::engine::general_purpose::URL_SAFE_NO_PAD
684                .encode(public_key.e().to_bytes_be());
685
686            Jwks {
687                keys: vec![Jwk {
688                    kid: kid.to_owned(),
689                    alg: alg.map(ToOwned::to_owned),
690                    n,
691                    e,
692                }],
693            }
694        }
695
696        fn create_jwt_token(
697            private_key: &RsaPrivateKey,
698            kid: &str,
699            algorithm: Algorithm,
700            issuer: &str,
701            audience: Option<&str>,
702            exp: u64,
703            additional_claims: HashMap<String, serde_json::Value>,
704        ) -> String {
705            let mut header = Header::new(algorithm);
706            header.kid = Some(kid.to_owned());
707
708            let mut claims = json!({
709                "iss": issuer,
710                "exp": exp,
711            });
712
713            if let Some(aud) = audience {
714                claims["aud"] = json!(aud);
715            }
716
717            for (key, value) in additional_claims {
718                claims[key] = value;
719            }
720
721            let encoding_key = EncodingKey::from_rsa_pem(
722                private_key
723                    .to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
724                    .unwrap()
725                    .as_bytes(),
726            )
727            .unwrap();
728
729            jsonwebtoken::encode(&header, &claims, &encoding_key).unwrap()
730        }
731
732        fn get_future_timestamp() -> u64 {
733            SystemTime::now()
734                .duration_since(UNIX_EPOCH)
735                .unwrap()
736                .as_secs()
737                + 3600 // 1 hour from now
738        }
739
740        fn get_past_timestamp() -> u64 {
741            SystemTime::now()
742                .duration_since(UNIX_EPOCH)
743                .unwrap()
744                .as_secs()
745                - 3600 // 1 hour ago
746        }
747
748        #[test]
749        fn test_jwt_with_invalid_audience() {
750            let (private_key, public_key) = create_test_rsa_keys();
751            let jwks = create_test_jwks(&public_key, "test-kid", Some("RS256"));
752
753            let metadata = HashMap::new();
754
755            let jwt = create_jwt_token(
756                &private_key,
757                "test-kid",
758                Algorithm::RS256,
759                "https://test-issuer.com",
760                Some("urn:risingwave:cluster:wrong-cluster-id"),
761                get_future_timestamp(),
762                HashMap::new(),
763            );
764
765            let result = validate_jwt_with_jwks(
766                &jwt,
767                &jwks,
768                "https://test-issuer.com",
769                "test-cluster-id",
770                &metadata,
771            );
772
773            let error = result.unwrap_err();
774            assert!(error.to_string().contains("InvalidAudience"));
775        }
776
777        #[test]
778        fn test_jwt_with_missing_audience() {
779            let (private_key, public_key) = create_test_rsa_keys();
780            let jwks = create_test_jwks(&public_key, "test-kid", Some("RS256"));
781
782            let metadata = HashMap::new();
783
784            let jwt = create_jwt_token(
785                &private_key,
786                "test-kid",
787                Algorithm::RS256,
788                "https://test-issuer.com",
789                None, // No audience claim
790                get_future_timestamp(),
791                HashMap::new(),
792            );
793
794            let result = validate_jwt_with_jwks(
795                &jwt,
796                &jwks,
797                "https://test-issuer.com",
798                "test-cluster-id",
799                &metadata,
800            );
801
802            let error = result.unwrap_err();
803            assert!(error.to_string().contains("Missing required claim: aud"));
804        }
805
806        #[test]
807        fn test_jwt_with_invalid_issuer() {
808            let (private_key, public_key) = create_test_rsa_keys();
809            let jwks = create_test_jwks(&public_key, "test-kid", Some("RS256"));
810
811            let metadata = HashMap::new();
812
813            let jwt = create_jwt_token(
814                &private_key,
815                "test-kid",
816                Algorithm::RS256,
817                "https://wrong-issuer.com",
818                Some("urn:risingwave:cluster:test-cluster-id"),
819                get_future_timestamp(),
820                HashMap::new(),
821            );
822
823            let result = validate_jwt_with_jwks(
824                &jwt,
825                &jwks,
826                "https://test-issuer.com",
827                "test-cluster-id",
828                &metadata,
829            );
830
831            let error = result.unwrap_err();
832            assert!(error.to_string().contains("InvalidIssuer"));
833        }
834
835        #[test]
836        fn test_jwt_with_kid_not_found_in_jwks() {
837            let (private_key, public_key) = create_test_rsa_keys();
838            let jwks = create_test_jwks(&public_key, "different-kid", Some("RS256"));
839
840            let metadata = HashMap::new();
841
842            let jwt = create_jwt_token(
843                &private_key,
844                "missing-kid",
845                Algorithm::RS256,
846                "https://test-issuer.com",
847                Some("urn:risingwave:cluster:test-cluster-id"),
848                get_future_timestamp(),
849                HashMap::new(),
850            );
851
852            let result = validate_jwt_with_jwks(
853                &jwt,
854                &jwks,
855                "https://test-issuer.com",
856                "test-cluster-id",
857                &metadata,
858            );
859
860            let error = result.unwrap_err();
861            assert!(
862                error
863                    .to_string()
864                    .contains("No matching key found in JWKS for kid: 'missing-kid'")
865            );
866        }
867
868        #[test]
869        fn test_jwt_with_expired_token() {
870            let (private_key, public_key) = create_test_rsa_keys();
871            let jwks = create_test_jwks(&public_key, "test-kid", Some("RS256"));
872
873            let metadata = HashMap::new();
874
875            let jwt = create_jwt_token(
876                &private_key,
877                "test-kid",
878                Algorithm::RS256,
879                "https://test-issuer.com",
880                Some("urn:risingwave:cluster:test-cluster-id"),
881                get_past_timestamp(), // Expired token
882                HashMap::new(),
883            );
884
885            let result = validate_jwt_with_jwks(
886                &jwt,
887                &jwks,
888                "https://test-issuer.com",
889                "test-cluster-id",
890                &metadata,
891            );
892
893            let error = result.unwrap_err();
894            assert!(error.to_string().contains("ExpiredSignature"));
895        }
896
897        #[test]
898        fn test_jwt_with_invalid_signature() {
899            let (_, public_key) = create_test_rsa_keys();
900            let (wrong_private_key, _) = create_test_rsa_keys(); // Different key pair
901            let jwks = create_test_jwks(&public_key, "test-kid", Some("RS256"));
902
903            let metadata = HashMap::new();
904
905            // Sign with wrong private key
906            let jwt = create_jwt_token(
907                &wrong_private_key,
908                "test-kid",
909                Algorithm::RS256,
910                "https://test-issuer.com",
911                Some("urn:risingwave:cluster:test-cluster-id"),
912                get_future_timestamp(),
913                HashMap::new(),
914            );
915
916            let result = validate_jwt_with_jwks(
917                &jwt,
918                &jwks,
919                "https://test-issuer.com",
920                "test-cluster-id",
921                &metadata,
922            );
923
924            let error = result.unwrap_err();
925            assert!(error.to_string().contains("InvalidSignature"));
926        }
927
928        #[test]
929        fn test_metadata_validation_success() {
930            let (private_key, public_key) = create_test_rsa_keys();
931            let jwks = create_test_jwks(&public_key, "test-kid", Some("RS256"));
932
933            let mut metadata = HashMap::new();
934            metadata.insert("role".to_owned(), "admin".to_owned());
935            metadata.insert("department".to_owned(), "security".to_owned());
936
937            let mut claims = HashMap::new();
938            claims.insert("role".to_owned(), json!("admin"));
939            claims.insert("department".to_owned(), json!("security"));
940            claims.insert("extra_claim".to_owned(), json!("ignored")); // Extra claims are fine
941
942            let jwt = create_jwt_token(
943                &private_key,
944                "test-kid",
945                Algorithm::RS256,
946                "https://test-issuer.com",
947                Some("urn:risingwave:cluster:test-cluster-id"),
948                get_future_timestamp(),
949                claims,
950            );
951
952            let result = validate_jwt_with_jwks(
953                &jwt,
954                &jwks,
955                "https://test-issuer.com",
956                "test-cluster-id",
957                &metadata,
958            );
959
960            assert!(result.unwrap());
961        }
962
963        #[test]
964        fn test_metadata_validation_failure() {
965            let (private_key, public_key) = create_test_rsa_keys();
966            let jwks = create_test_jwks(&public_key, "test-kid", Some("RS256"));
967
968            let mut metadata = HashMap::new();
969            metadata.insert("role".to_owned(), "admin".to_owned());
970            metadata.insert("department".to_owned(), "security".to_owned());
971
972            let mut claims = HashMap::new();
973            claims.insert("role".to_owned(), json!("user")); // Wrong role
974            claims.insert("department".to_owned(), json!("security"));
975
976            let jwt = create_jwt_token(
977                &private_key,
978                "test-kid",
979                Algorithm::RS256,
980                "https://test-issuer.com",
981                Some("urn:risingwave:cluster:test-cluster-id"),
982                get_future_timestamp(),
983                claims,
984            );
985
986            let result = validate_jwt_with_jwks(
987                &jwt,
988                &jwks,
989                "https://test-issuer.com",
990                "test-cluster-id",
991                &metadata,
992            );
993
994            let error = result.unwrap_err();
995            assert_eq!(
996                error.to_string(),
997                "metadata in jwt does not match with metadata declared with user"
998            );
999        }
1000
1001        #[test]
1002        fn test_jwt_with_jwk_missing_alg_succeeds() {
1003            let (private_key, public_key) = create_test_rsa_keys();
1004            let jwks = create_test_jwks(&public_key, "test-kid", None);
1005
1006            let jwt = create_jwt_token(
1007                &private_key,
1008                "test-kid",
1009                Algorithm::RS256,
1010                "https://test-issuer.com",
1011                Some("urn:risingwave:cluster:test-cluster-id"),
1012                get_future_timestamp(),
1013                HashMap::new(),
1014            );
1015
1016            let result = validate_jwt_with_jwks(
1017                &jwt,
1018                &jwks,
1019                "https://test-issuer.com",
1020                "test-cluster-id",
1021                &HashMap::new(),
1022            );
1023
1024            assert!(result.unwrap());
1025        }
1026
1027        #[test]
1028        fn test_jwt_with_jwk_missing_alg_rejects_disallowed_header_alg() {
1029            let (_, public_key) = create_test_rsa_keys();
1030            let jwks = create_test_jwks(&public_key, "test-kid", None);
1031
1032            // Craft a token whose header claims HS256. The allow-list check
1033            // must reject it before the signature is ever verified — this is
1034            // the defence against the classic "alg=HS256 forged with the RSA
1035            // public key as the HMAC secret" confusion attack.
1036            let mut header = Header::new(Algorithm::HS256);
1037            header.kid = Some("test-kid".to_owned());
1038            let claims = json!({
1039                "iss": "https://test-issuer.com",
1040                "aud": "urn:risingwave:cluster:test-cluster-id",
1041                "exp": get_future_timestamp(),
1042            });
1043            let jwt = jsonwebtoken::encode(
1044                &header,
1045                &claims,
1046                &EncodingKey::from_secret(b"attacker-chosen"),
1047            )
1048            .unwrap();
1049
1050            let result = validate_jwt_with_jwks(
1051                &jwt,
1052                &jwks,
1053                "https://test-issuer.com",
1054                "test-cluster-id",
1055                &HashMap::new(),
1056            );
1057
1058            let error = result.unwrap_err();
1059            assert!(
1060                error.to_string().contains("is not allowed"),
1061                "unexpected error: {}",
1062                error
1063            );
1064        }
1065
1066        #[test]
1067        fn test_jwt_alg_mismatch_between_header_and_jwk() {
1068            let (private_key, public_key) = create_test_rsa_keys();
1069            // JWK pins RS384; token header declares RS256 — must be rejected.
1070            let jwks = create_test_jwks(&public_key, "test-kid", Some("RS384"));
1071
1072            let jwt = create_jwt_token(
1073                &private_key,
1074                "test-kid",
1075                Algorithm::RS256,
1076                "https://test-issuer.com",
1077                Some("urn:risingwave:cluster:test-cluster-id"),
1078                get_future_timestamp(),
1079                HashMap::new(),
1080            );
1081
1082            let result = validate_jwt_with_jwks(
1083                &jwt,
1084                &jwks,
1085                "https://test-issuer.com",
1086                "test-cluster-id",
1087                &HashMap::new(),
1088            );
1089
1090            let error = result.unwrap_err();
1091            assert_eq!(
1092                error.to_string(),
1093                "alg in jwt header does not match with alg in jwk"
1094            );
1095        }
1096    }
1097}