pgwire/
pg_server.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::future::Future;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::time::Instant;
20
21use bytes::Bytes;
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
23use parking_lot::Mutex;
24use risingwave_common::types::DataType;
25use risingwave_common::util::runtime::BackgroundShutdownRuntime;
26use risingwave_common::util::tokio_util::sync::CancellationToken;
27use risingwave_jni_core::jvm_runtime::register_jvm_builder;
28use risingwave_sqlparser::ast::Statement;
29use serde::Deserialize;
30use thiserror_ext::AsReport;
31
32use crate::error::{PsqlError, PsqlResult};
33use crate::net::{AddressRef, Listener, TcpKeepalive};
34use crate::pg_field_descriptor::PgFieldDescriptor;
35use crate::pg_message::TransactionStatus;
36use crate::pg_protocol::{ConnectionContext, PgByteStream, PgProtocol};
37use crate::pg_response::{PgResponse, ValuesStream};
38use crate::types::Format;
39
/// Type-erased, thread-safe error type used throughout the pgwire session interfaces.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;

/// First half of a [`SessionId`].
type ProcessId = i32;

/// Second half of a [`SessionId`].
type SecretKey = i32;

/// Identifies a session as a `(process id, secret key)` pair.
///
/// NOTE(review): presumably the pair a client echoes back in a PostgreSQL cancel request —
/// confirm in the protocol code.
pub type SessionId = (ProcessId, SecretKey);
44
45/// The interface for a database system behind pgwire protocol.
46/// We can mock it for testing purpose.
47pub trait SessionManager: Send + Sync + 'static {
48    type Session: Session;
49
50    /// In the process of auto schema change, we need a dummy session to access
51    /// catalog information in frontend and build a replace plan for the table.
52    fn create_dummy_session(
53        &self,
54        database_id: u32,
55        user_id: u32,
56    ) -> Result<Arc<Self::Session>, BoxedError>;
57
58    fn connect(
59        &self,
60        database: &str,
61        user_name: &str,
62        peer_addr: AddressRef,
63    ) -> Result<Arc<Self::Session>, BoxedError>;
64
65    fn cancel_queries_in_session(&self, session_id: SessionId);
66
67    fn cancel_creating_jobs_in_session(&self, session_id: SessionId);
68
69    fn end_session(&self, session: &Self::Session);
70
71    /// Run some cleanup tasks before the server shutdown.
72    fn shutdown(&self) -> impl Future<Output = ()> + Send {
73        async {}
74    }
75}
76
77/// A psql connection. Each connection binds with a database. Switching database will need to
78/// recreate another connection.
79pub trait Session: Send + Sync {
80    type ValuesStream: ValuesStream;
81    type PreparedStatement: Send + Clone + 'static;
82    type Portal: Send + Clone + std::fmt::Display + 'static;
83
84    /// The str sql can not use the unparse from AST: There is some problem when dealing with create
85    /// view, see <https://github.com/risingwavelabs/risingwave/issues/6801>.
86    fn run_one_query(
87        self: Arc<Self>,
88        stmt: Statement,
89        format: Format,
90    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;
91
92    fn parse(
93        self: Arc<Self>,
94        sql: Option<Statement>,
95        params_types: Vec<Option<DataType>>,
96    ) -> impl Future<Output = Result<Self::PreparedStatement, BoxedError>> + Send;
97
98    /// Receive the next notice message to send to the client.
99    ///
100    /// This function should be cancellation-safe.
101    fn next_notice(self: &Arc<Self>) -> impl Future<Output = String> + Send;
102
103    fn bind(
104        self: Arc<Self>,
105        prepare_statement: Self::PreparedStatement,
106        params: Vec<Option<Bytes>>,
107        param_formats: Vec<Format>,
108        result_formats: Vec<Format>,
109    ) -> Result<Self::Portal, BoxedError>;
110
111    fn execute(
112        self: Arc<Self>,
113        portal: Self::Portal,
114    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;
115
116    fn describe_statement(
117        self: Arc<Self>,
118        prepare_statement: Self::PreparedStatement,
119    ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError>;
120
121    fn describe_portal(
122        self: Arc<Self>,
123        portal: Self::Portal,
124    ) -> Result<Vec<PgFieldDescriptor>, BoxedError>;
125
126    fn user_authenticator(&self) -> &UserAuthenticator;
127
128    fn id(&self) -> SessionId;
129
130    fn get_config(&self, key: &str) -> Result<String, BoxedError>;
131
132    fn set_config(&self, key: &str, value: String) -> Result<String, BoxedError>;
133
134    fn transaction_status(&self) -> TransactionStatus;
135
136    fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard;
137
138    fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()>;
139}
140
141/// Each session could run different SQLs multiple times.
142/// `ExecContext` represents the lifetime of a running SQL in the current session.
143pub struct ExecContext {
144    pub running_sql: Arc<str>,
145    /// The instant of the running sql
146    pub last_instant: Instant,
147    /// A reference used to update when `ExecContext` is dropped
148    pub last_idle_instant: Arc<Mutex<Option<Instant>>>,
149}
150
151/// `ExecContextGuard` holds a `Arc` pointer. Once `ExecContextGuard` is dropped,
152/// the inner `Arc<ExecContext>` should not be referred anymore, so that its `Weak` reference (used in `SessionImpl`) will be the same lifecycle of the running sql execution context.
153pub struct ExecContextGuard(#[allow(dead_code)] Arc<ExecContext>);
154
155impl ExecContextGuard {
156    pub fn new(exec_context: Arc<ExecContext>) -> Self {
157        Self(exec_context)
158    }
159}
160
161impl Drop for ExecContext {
162    fn drop(&mut self) {
163        *self.last_idle_instant.lock() = Some(Instant::now());
164    }
165}
166
/// How a connecting user must prove their identity.
#[derive(Debug, Clone)]
pub enum UserAuthenticator {
    /// No need to authenticate: any password is accepted.
    None,
    /// The expected password in clear-text form, compared byte-for-byte with the client's.
    ClearText(Vec<u8>),
    /// The expected password, encrypted with a random salt.
    ///
    /// `encrypted_password` is compared directly with what the client sends; `salt` is not
    /// used by that comparison (presumably it is sent to the client so it can compute the
    /// same value — confirm in the protocol code).
    Md5WithSalt {
        encrypted_password: Vec<u8>,
        salt: [u8; 4],
    },
    /// JWT-based authentication. The map must contain `jwks_url` and `issuer`; every other
    /// entry must appear as an identical string claim in the client's token.
    OAuth(HashMap<String, String>),
}
180
181/// A JWK Set is a JSON object that represents a set of JWKs.
182/// The JSON object MUST have a "keys" member, with its value being an array of JWKs.
183/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-5> for more details.
184#[derive(Debug, Deserialize)]
185struct Jwks {
186    keys: Vec<Jwk>,
187}
188
189/// A JSON Web Key (JWK) is a JSON object that represents a cryptographic key.
190/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-4> for more details.
191#[derive(Debug, Deserialize)]
192struct Jwk {
193    kid: String, // Key ID
194    alg: String, // Algorithm
195    n: String,   // Modulus
196    e: String,   // Exponent
197}
198
199async fn validate_jwt(
200    jwt: &str,
201    jwks_url: &str,
202    issuer: &str,
203    metadata: &HashMap<String, String>,
204) -> Result<bool, BoxedError> {
205    let header = decode_header(jwt)?;
206    let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;
207
208    // 1. Retrieve the kid from the header to find the right JWK in the JWK Set.
209    let kid = header.kid.ok_or("kid not found in jwt header")?;
210    let jwk = jwks
211        .keys
212        .into_iter()
213        .find(|k| k.kid == kid)
214        .ok_or("kid not found in jwks")?;
215
216    // 2. Check if the algorithms are matched.
217    if Algorithm::from_str(&jwk.alg)? != header.alg {
218        return Err("alg in jwt header does not match with alg in jwk".into());
219    }
220
221    // 3. Decode the JWT and validate the claims.
222    let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
223    let mut validation = Validation::new(header.alg);
224    validation.set_issuer(&[issuer]);
225    validation.set_required_spec_claims(&["exp", "iss"]);
226    let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;
227
228    // 4. Check if the metadata in the token matches.
229    if !metadata.iter().all(
230        |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
231    ) {
232        return Err("metadata in jwt does not match with metadata declared with user".into());
233    }
234    Ok(true)
235}
236
237impl UserAuthenticator {
238    pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> {
239        let success = match self {
240            UserAuthenticator::None => true,
241            UserAuthenticator::ClearText(text) => password == text,
242            UserAuthenticator::Md5WithSalt {
243                encrypted_password, ..
244            } => encrypted_password == password,
245            UserAuthenticator::OAuth(metadata) => {
246                let mut metadata = metadata.clone();
247                let jwks_url = metadata.remove("jwks_url").unwrap();
248                let issuer = metadata.remove("issuer").unwrap();
249                validate_jwt(
250                    &String::from_utf8_lossy(password),
251                    &jwks_url,
252                    &issuer,
253                    &metadata,
254                )
255                .await
256                .map_err(PsqlError::StartupError)?
257            }
258        };
259        if !success {
260            return Err(PsqlError::PasswordError);
261        }
262        Ok(())
263    }
264}
265
266/// Binds a Tcp or Unix listener at `addr`. Spawn a coroutine to serve every new connection.
267///
268/// Returns when the `shutdown` token is triggered.
269pub async fn pg_serve(
270    addr: &str,
271    tcp_keepalive: TcpKeepalive,
272    session_mgr: Arc<impl SessionManager>,
273    context: ConnectionContext,
274    shutdown: CancellationToken,
275) -> Result<(), BoxedError> {
276    let listener = Listener::bind(addr).await?;
277    tracing::info!(addr, "server started");
278    register_jvm_builder();
279    let acceptor_runtime = BackgroundShutdownRuntime::from({
280        let mut builder = tokio::runtime::Builder::new_multi_thread();
281        builder.worker_threads(1);
282        builder
283            .thread_name("rw-acceptor")
284            .enable_all()
285            .build()
286            .unwrap()
287    });
288
289    #[cfg(not(madsim))]
290    let worker_runtime = tokio::runtime::Handle::current();
291    #[cfg(madsim)]
292    let worker_runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();
293    let session_mgr_clone = session_mgr.clone();
294    let f = async move {
295        loop {
296            let conn_ret = listener.accept(&tcp_keepalive).await;
297            match conn_ret {
298                Ok((stream, peer_addr)) => {
299                    tracing::info!(%peer_addr, "accept connection");
300                    worker_runtime.spawn(handle_connection(
301                        stream,
302                        session_mgr_clone.clone(),
303                        Arc::new(peer_addr),
304                        context.clone(),
305                    ));
306                }
307
308                Err(e) => {
309                    tracing::error!(error = %e.as_report(), "failed to accept connection",);
310                }
311            }
312        }
313    };
314    acceptor_runtime.spawn(f);
315
316    // Wait for the shutdown signal.
317    shutdown.cancelled().await;
318
319    // Stop accepting new connections.
320    drop(acceptor_runtime);
321    // Shutdown session manager, typically close all existing sessions.
322    session_mgr.shutdown().await;
323
324    Ok(())
325}
326
327pub async fn handle_connection<S, SM>(
328    stream: S,
329    session_mgr: Arc<SM>,
330    peer_addr: AddressRef,
331    context: ConnectionContext,
332) where
333    S: PgByteStream,
334    SM: SessionManager,
335{
336    PgProtocol::new(stream, session_mgr, peer_addr, context)
337        .run()
338        .await;
339}
#[cfg(test)]
mod tests {
    use std::error::Error;
    use std::sync::Arc;
    use std::time::Instant;

    use bytes::Bytes;
    use futures::StreamExt;
    use futures::stream::BoxStream;
    use risingwave_common::types::DataType;
    use risingwave_common::util::tokio_util::sync::CancellationToken;
    use risingwave_sqlparser::ast::Statement;
    use tokio_postgres::NoTls;

    use crate::error::PsqlResult;
    use crate::memory_manager::MessageMemoryManager;
    use crate::pg_field_descriptor::PgFieldDescriptor;
    use crate::pg_message::TransactionStatus;
    use crate::pg_protocol::ConnectionContext;
    use crate::pg_response::{PgResponse, RowSetResult, StatementType};
    use crate::pg_server::{
        BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
        UserAuthenticator, pg_serve,
    };
    use crate::types;
    use crate::types::Row;

    /// Result stream type produced by [`MockSession`].
    type MockValuesStream = BoxStream<'static, RowSetResult>;

    /// Descriptor of the single unnamed `varchar` column every mock result has.
    fn varchar_field() -> PgFieldDescriptor {
        // 1043 is the oid of varchar type.
        // -1 is the type len of varchar type.
        PgFieldDescriptor::new("".to_owned(), 1043, -1)
    }

    /// Response shared by simple and extended queries: one row holding one empty value.
    fn single_empty_row_response() -> PgResponse<MockValuesStream> {
        PgResponse::builder(StatementType::SELECT)
            .values(
                futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
                    .boxed(),
                vec![varchar_field()],
            )
            .into()
    }

    struct MockSessionManager {}
    struct MockSession {}

    impl SessionManager for MockSessionManager {
        type Session = MockSession;

        fn create_dummy_session(
            &self,
            _database_id: u32,
            _user_id: u32,
        ) -> Result<Arc<Self::Session>, BoxedError> {
            unimplemented!()
        }

        fn connect(
            &self,
            _database: &str,
            _user_name: &str,
            _peer_addr: crate::net::AddressRef,
        ) -> Result<Arc<Self::Session>, Box<dyn Error + Send + Sync>> {
            Ok(Arc::new(MockSession {}))
        }

        fn cancel_queries_in_session(&self, _session_id: SessionId) {
            todo!()
        }

        fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
            todo!()
        }

        fn end_session(&self, _session: &Self::Session) {}
    }

    impl Session for MockSession {
        type Portal = String;
        type PreparedStatement = String;
        type ValuesStream = MockValuesStream;

        async fn run_one_query(
            self: Arc<Self>,
            _stmt: Statement,
            _format: types::Format,
        ) -> Result<PgResponse<MockValuesStream>, BoxedError> {
            Ok(single_empty_row_response())
        }

        async fn parse(
            self: Arc<Self>,
            _sql: Option<Statement>,
            _params_types: Vec<Option<DataType>>,
        ) -> Result<String, BoxedError> {
            Ok(String::new())
        }

        fn bind(
            self: Arc<Self>,
            _prepare_statement: String,
            _params: Vec<Option<Bytes>>,
            _param_formats: Vec<types::Format>,
            _result_formats: Vec<types::Format>,
        ) -> Result<String, BoxedError> {
            Ok(String::new())
        }

        async fn execute(
            self: Arc<Self>,
            _portal: String,
        ) -> Result<PgResponse<MockValuesStream>, BoxedError> {
            Ok(single_empty_row_response())
        }

        fn describe_statement(
            self: Arc<Self>,
            _statement: String,
        ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError> {
            Ok((vec![], vec![varchar_field()]))
        }

        fn describe_portal(
            self: Arc<Self>,
            _portal: String,
        ) -> Result<Vec<PgFieldDescriptor>, BoxedError> {
            Ok(vec![varchar_field()])
        }

        fn user_authenticator(&self) -> &UserAuthenticator {
            &UserAuthenticator::None
        }

        fn id(&self) -> SessionId {
            (0, 0)
        }

        fn get_config(&self, key: &str) -> Result<String, BoxedError> {
            match key {
                "timezone" => Ok("UTC".to_owned()),
                _ => Err(format!("Unknown config key: {key}").into()),
            }
        }

        fn set_config(&self, _key: &str, _value: String) -> Result<String, BoxedError> {
            Ok("".to_owned())
        }

        async fn next_notice(self: &Arc<Self>) -> String {
            std::future::pending().await
        }

        fn transaction_status(&self) -> TransactionStatus {
            TransactionStatus::Idle
        }

        fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
            let exec_context = Arc::new(ExecContext {
                running_sql: sql,
                last_instant: Instant::now(),
                last_idle_instant: Default::default(),
            });
            ExecContextGuard::new(exec_context)
        }

        fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
            Ok(())
        }
    }

    /// Starts a mock server on `bind_addr`, connects with `pg_config`, and runs one query
    /// through both the simple and the extended query protocol.
    async fn do_test_query(bind_addr: impl Into<String>, pg_config: impl Into<String>) {
        let bind_addr = bind_addr.into();
        let pg_config = pg_config.into();

        let session_mgr = MockSessionManager {};
        tokio::spawn(async move {
            pg_serve(
                &bind_addr,
                socket2::TcpKeepalive::new(),
                Arc::new(session_mgr),
                ConnectionContext {
                    tls_config: None,
                    redact_sql_option_keywords: None,
                    message_memory_manager: MessageMemoryManager::new(u64::MAX, u64::MAX, u64::MAX)
                        .into(),
                },
                CancellationToken::new(), // dummy
            )
            .await
        });
        // wait for server to start
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        // Connect to the database.
        let (client, connection) = tokio_postgres::connect(&pg_config, NoTls).await.unwrap();

        // The connection object performs the actual communication with the database,
        // so spawn it off to run on its own.
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                eprintln!("connection error: {e}");
            }
        });

        let rows = client
            .simple_query("SELECT ''")
            .await
            .expect("Error executing query");
        // Row + CommandComplete
        assert_eq!(rows.len(), 2);

        let rows = client
            .query("SELECT ''", &[])
            .await
            .expect("Error executing query");
        assert_eq!(rows.len(), 1);
    }

    #[tokio::test]
    async fn test_query_tcp() {
        do_test_query("127.0.0.1:10000", "host=localhost port=10000").await;
    }

    #[cfg(not(madsim))]
    #[tokio::test]
    async fn test_query_unix() {
        let port: u16 = 10000;
        let dir = tempfile::TempDir::new().unwrap();
        let sock = dir.path().join(format!(".s.PGSQL.{port}"));

        do_test_query(
            format!("unix:{}", sock.to_str().unwrap()),
            format!("host={} port={}", dir.path().to_str().unwrap(), port),
        )
        .await;
    }
}