pgwire/
pg_server.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::future::Future;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::time::Instant;
20
21use bytes::Bytes;
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
23use parking_lot::Mutex;
24use risingwave_common::types::DataType;
25use risingwave_common::util::runtime::BackgroundShutdownRuntime;
26use risingwave_common::util::tokio_util::sync::CancellationToken;
27use risingwave_sqlparser::ast::Statement;
28use serde::Deserialize;
29use thiserror_ext::AsReport;
30
31use crate::error::{PsqlError, PsqlResult};
32use crate::net::{AddressRef, Listener, TcpKeepalive};
33use crate::pg_field_descriptor::PgFieldDescriptor;
34use crate::pg_message::TransactionStatus;
35use crate::pg_protocol::{ConnectionContext, PgByteStream, PgProtocol};
36use crate::pg_response::{PgResponse, ValuesStream};
37use crate::types::Format;
38
/// A type-erased error that can be sent across threads: any boxed
/// `std::error::Error` that is also `Send + Sync`.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;
/// First component of a [`SessionId`].
type ProcessId = i32;
/// Second component of a [`SessionId`].
type SecretKey = i32;
/// Identifier of a session, expressed as a `(ProcessId, SecretKey)` pair.
///
/// NOTE(review): presumably mirrors the process-id / secret-key pair that the Postgres
/// protocol uses for cancel requests -- confirm against `pg_protocol`.
pub type SessionId = (ProcessId, SecretKey);
43
/// The interface for a database system behind pgwire protocol.
/// We can mock it for testing purpose.
///
/// A single manager is shared (behind an `Arc`) by every connection served by
/// [`pg_serve`], hence the `Send + Sync + 'static` bound.
pub trait SessionManager: Send + Sync + 'static {
    /// The concrete session type handed out by this manager.
    type Session: Session;

    /// In the process of auto schema change, we need a dummy session to access
    /// catalog information in frontend and build a replace plan for the table.
    fn create_dummy_session(
        &self,
        database_id: u32,
        user_id: u32,
    ) -> Result<Arc<Self::Session>, BoxedError>;

    /// Creates a session for a client at `peer_addr` that connects to `database`
    /// as `user_name`.
    ///
    /// NOTE(review): the failure conditions are up to the implementor; presumably an
    /// unknown database or user is reported as an error -- confirm against the
    /// frontend implementation.
    fn connect(
        &self,
        database: &str,
        user_name: &str,
        peer_addr: AddressRef,
    ) -> Result<Arc<Self::Session>, BoxedError>;

    /// Requests cancellation of the queries of the session identified by `session_id`.
    fn cancel_queries_in_session(&self, session_id: SessionId);

    /// Requests cancellation of the creating jobs of the session identified by
    /// `session_id`.
    ///
    /// NOTE(review): "creating jobs" presumably refers to in-progress `CREATE ...`
    /// jobs -- confirm against the frontend implementation.
    fn cancel_creating_jobs_in_session(&self, session_id: SessionId);

    /// Tells the manager that `session` has ended.
    fn end_session(&self, session: &Self::Session);

    /// Run some cleanup tasks before the server shutdown.
    ///
    /// The default implementation does nothing.
    fn shutdown(&self) -> impl Future<Output = ()> + Send {
        async {}
    }
}
75
/// A psql connection. Each connection binds with a database. Switching database will need to
/// recreate another connection.
pub trait Session: Send + Sync {
    /// Stream type carried by the [`PgResponse`] of a query.
    type ValuesStream: ValuesStream;
    /// Result of [`Session::parse`]; the input of [`Session::bind`] and
    /// [`Session::describe_statement`].
    type PreparedStatement: Send + Clone + 'static;
    /// Result of [`Session::bind`]; the input of [`Session::execute`] and
    /// [`Session::describe_portal`].
    type Portal: Send + Clone + std::fmt::Display + 'static;

    /// Runs a single, already parsed statement and returns its response, with values
    /// encoded in `format`.
    ///
    /// The str sql can not use the unparse from AST: There is some problem when dealing with create
    /// view, see <https://github.com/risingwavelabs/risingwave/issues/6801>.
    fn run_one_query(
        self: Arc<Self>,
        stmt: Statement,
        format: Format,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;

    /// Turns an optional statement plus optional parameter types into a prepared statement.
    ///
    /// NOTE(review): `sql` being `None` presumably stands for an empty query string --
    /// confirm against the caller in `pg_protocol`.
    fn parse(
        self: Arc<Self>,
        sql: Option<Statement>,
        params_types: Vec<Option<DataType>>,
    ) -> impl Future<Output = Result<Self::PreparedStatement, BoxedError>> + Send;

    /// Receive the next notice message to send to the client.
    ///
    /// This function should be cancellation-safe.
    fn next_notice(self: &Arc<Self>) -> impl Future<Output = String> + Send;

    /// Binds parameter values to a prepared statement, producing a portal.
    ///
    /// `params` holds the raw parameter values (`None` presumably meaning SQL NULL --
    /// confirm), `param_formats` their encodings and `result_formats` the encodings
    /// requested for the result columns.
    fn bind(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
        params: Vec<Option<Bytes>>,
        param_formats: Vec<Format>,
        result_formats: Vec<Format>,
    ) -> Result<Self::Portal, BoxedError>;

    /// Executes a portal previously created by [`Session::bind`].
    fn execute(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;

    /// Describes a prepared statement as a pair of data types and field descriptors.
    ///
    /// NOTE(review): presumably the parameter types and the result columns,
    /// respectively -- confirm against the implementor.
    fn describe_statement(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
    ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError>;

    /// Describes a portal as a list of field descriptors.
    fn describe_portal(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> Result<Vec<PgFieldDescriptor>, BoxedError>;

    /// Returns the authenticator to check this session's user with.
    fn user_authenticator(&self) -> &UserAuthenticator;

    /// Returns the identifier of this session.
    fn id(&self) -> SessionId;

    /// Sets the session configuration entry `key` to `value`.
    ///
    /// NOTE(review): the returned `String` is presumably the value that took effect --
    /// confirm against the implementor.
    fn set_config(&self, key: &str, value: String) -> Result<String, BoxedError>;

    /// Returns the current transaction status of this session.
    fn transaction_status(&self) -> TransactionStatus;

    /// Starts tracking the execution of `sql`; the returned guard represents the
    /// lifetime of that execution (see [`ExecContextGuard`]).
    fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard;

    /// Checks the idle-in-transaction timeout of this session.
    ///
    /// NOTE(review): presumably returns an error once the session has been idle inside
    /// a transaction for longer than the configured timeout -- confirm.
    fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()>;
}
137
/// Each session could run different SQLs multiple times.
/// `ExecContext` represents the lifetime of a running SQL in the current session.
///
/// When an `ExecContext` is dropped, its `Drop` implementation stores the current
/// instant into `last_idle_instant`.
pub struct ExecContext {
    /// The SQL text this context was created for.
    pub running_sql: Arc<str>,
    /// The instant associated with the running sql.
    ///
    /// NOTE(review): presumably the instant at which execution started (the mock
    /// session initializes it with `Instant::now()`) -- confirm against `SessionImpl`.
    pub last_instant: Instant,
    /// A shared slot that is overwritten with `Some(Instant::now())` when this
    /// `ExecContext` is dropped.
    pub last_idle_instant: Arc<Mutex<Option<Instant>>>,
}
147
/// `ExecContextGuard` holds a `Arc` pointer. Once `ExecContextGuard` is dropped,
/// the inner `Arc<ExecContext>` should not be referred anymore, so that its `Weak` reference (used in `SessionImpl`) will be the same lifecycle of the running sql execution context.
///
/// The field is never read: it exists only to keep the `Arc<ExecContext>` alive for as
/// long as the guard lives, which is why `dead_code` is allowed on it.
pub struct ExecContextGuard(#[allow(dead_code)] Arc<ExecContext>);

impl ExecContextGuard {
    /// Wraps `exec_context` so that it stays alive until the returned guard is dropped.
    pub fn new(exec_context: Arc<ExecContext>) -> Self {
        Self(exec_context)
    }
}
157
158impl Drop for ExecContext {
159    fn drop(&mut self) {
160        *self.last_idle_instant.lock() = Some(Instant::now());
161    }
162}
163
/// The way a connecting user is authenticated; see `UserAuthenticator::authenticate`.
#[derive(Debug, Clone)]
pub enum UserAuthenticator {
    /// No need to authenticate.
    None,
    /// Raw password in clear-text form.
    ClearText(Vec<u8>),
    /// Password encrypted with random salt.
    Md5WithSalt {
        /// Compared byte-for-byte with the password supplied by the client.
        encrypted_password: Vec<u8>,
        /// The salt.
        ///
        /// NOTE(review): not read by `authenticate`; presumably sent to the client
        /// during the handshake elsewhere -- confirm against `pg_protocol`.
        salt: [u8; 4],
    },
    /// OAuth authentication with a JWT supplied as the password.
    ///
    /// The map is expected to contain the keys `jwks_url` and `issuer`; every other
    /// entry must be matched by an equal string claim in the token.
    OAuth(HashMap<String, String>),
}
177
/// A JWK Set is a JSON object that represents a set of JWKs.
/// The JSON object MUST have a "keys" member, with its value being an array of JWKs.
/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-5> for more details.
///
/// Deserialized from the JSON document fetched from the configured `jwks_url`.
#[derive(Debug, Deserialize)]
struct Jwks {
    // The keys of the set; searched by `kid` when validating a token.
    keys: Vec<Jwk>,
}
185
/// A JSON Web Key (JWK) is a JSON object that represents a cryptographic key.
/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-4> for more details.
///
/// Only the members needed to verify an RSA signature are kept.
///
/// NOTE(review): all four members are mandatory here, so a key set that contains a key
/// without any of them (e.g. a non-RSA key) presumably fails to deserialize as a whole --
/// confirm whether that is intended.
#[derive(Debug, Deserialize)]
struct Jwk {
    kid: String, // Key ID, matched against the `kid` of the JWT header
    alg: String, // Algorithm name, must parse to the algorithm of the JWT header
    n: String,   // Modulus, passed to `DecodingKey::from_rsa_components`
    e: String,   // Exponent, passed to `DecodingKey::from_rsa_components`
}
195
196async fn validate_jwt(
197    jwt: &str,
198    jwks_url: &str,
199    issuer: &str,
200    metadata: &HashMap<String, String>,
201) -> Result<bool, BoxedError> {
202    let header = decode_header(jwt)?;
203    let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;
204
205    // 1. Retrieve the kid from the header to find the right JWK in the JWK Set.
206    let kid = header.kid.ok_or("kid not found in jwt header")?;
207    let jwk = jwks
208        .keys
209        .into_iter()
210        .find(|k| k.kid == kid)
211        .ok_or("kid not found in jwks")?;
212
213    // 2. Check if the algorithms are matched.
214    if Algorithm::from_str(&jwk.alg)? != header.alg {
215        return Err("alg in jwt header does not match with alg in jwk".into());
216    }
217
218    // 3. Decode the JWT and validate the claims.
219    let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
220    let mut validation = Validation::new(header.alg);
221    validation.set_issuer(&[issuer]);
222    validation.set_required_spec_claims(&["exp", "iss"]);
223    let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;
224
225    // 4. Check if the metadata in the token matches.
226    if !metadata.iter().all(
227        |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
228    ) {
229        return Err("metadata in jwt does not match with metadata declared with user".into());
230    }
231    Ok(true)
232}
233
234impl UserAuthenticator {
235    pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> {
236        let success = match self {
237            UserAuthenticator::None => true,
238            UserAuthenticator::ClearText(text) => password == text,
239            UserAuthenticator::Md5WithSalt {
240                encrypted_password, ..
241            } => encrypted_password == password,
242            UserAuthenticator::OAuth(metadata) => {
243                let mut metadata = metadata.clone();
244                let jwks_url = metadata.remove("jwks_url").unwrap();
245                let issuer = metadata.remove("issuer").unwrap();
246                validate_jwt(
247                    &String::from_utf8_lossy(password),
248                    &jwks_url,
249                    &issuer,
250                    &metadata,
251                )
252                .await
253                .map_err(PsqlError::StartupError)?
254            }
255        };
256        if !success {
257            return Err(PsqlError::PasswordError);
258        }
259        Ok(())
260    }
261}
262
263/// Binds a Tcp or Unix listener at `addr`. Spawn a coroutine to serve every new connection.
264///
265/// Returns when the `shutdown` token is triggered.
266pub async fn pg_serve(
267    addr: &str,
268    tcp_keepalive: TcpKeepalive,
269    session_mgr: Arc<impl SessionManager>,
270    context: ConnectionContext,
271    shutdown: CancellationToken,
272) -> Result<(), BoxedError> {
273    let listener = Listener::bind(addr).await?;
274    tracing::info!(addr, "server started");
275
276    let acceptor_runtime = BackgroundShutdownRuntime::from({
277        let mut builder = tokio::runtime::Builder::new_multi_thread();
278        builder.worker_threads(1);
279        builder
280            .thread_name("rw-acceptor")
281            .enable_all()
282            .build()
283            .unwrap()
284    });
285
286    #[cfg(not(madsim))]
287    let worker_runtime = tokio::runtime::Handle::current();
288    #[cfg(madsim)]
289    let worker_runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();
290    let session_mgr_clone = session_mgr.clone();
291    let f = async move {
292        loop {
293            let conn_ret = listener.accept(&tcp_keepalive).await;
294            match conn_ret {
295                Ok((stream, peer_addr)) => {
296                    tracing::info!(%peer_addr, "accept connection");
297                    worker_runtime.spawn(handle_connection(
298                        stream,
299                        session_mgr_clone.clone(),
300                        Arc::new(peer_addr),
301                        context.clone(),
302                    ));
303                }
304
305                Err(e) => {
306                    tracing::error!(error = %e.as_report(), "failed to accept connection",);
307                }
308            }
309        }
310    };
311    acceptor_runtime.spawn(f);
312
313    // Wait for the shutdown signal.
314    shutdown.cancelled().await;
315
316    // Stop accepting new connections.
317    drop(acceptor_runtime);
318    // Shutdown session manager, typically close all existing sessions.
319    session_mgr.shutdown().await;
320
321    Ok(())
322}
323
324pub async fn handle_connection<S, SM>(
325    stream: S,
326    session_mgr: Arc<SM>,
327    peer_addr: AddressRef,
328    context: ConnectionContext,
329) where
330    S: PgByteStream,
331    SM: SessionManager,
332{
333    PgProtocol::new(stream, session_mgr, peer_addr, context)
334        .run()
335        .await;
336}
337#[cfg(test)]
338mod tests {
339    use std::error::Error;
340    use std::sync::Arc;
341    use std::time::Instant;
342
343    use bytes::Bytes;
344    use futures::StreamExt;
345    use futures::stream::BoxStream;
346    use risingwave_common::types::DataType;
347    use risingwave_common::util::tokio_util::sync::CancellationToken;
348    use risingwave_sqlparser::ast::Statement;
349    use tokio_postgres::NoTls;
350
351    use crate::error::PsqlResult;
352    use crate::memory_manager::MessageMemoryManager;
353    use crate::pg_field_descriptor::PgFieldDescriptor;
354    use crate::pg_message::TransactionStatus;
355    use crate::pg_protocol::ConnectionContext;
356    use crate::pg_response::{PgResponse, RowSetResult, StatementType};
357    use crate::pg_server::{
358        BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
359        UserAuthenticator, pg_serve,
360    };
361    use crate::types;
362    use crate::types::Row;
363
    /// Session manager stub: `connect` always succeeds with a new [`MockSession`].
    struct MockSessionManager {}
    /// Session stub that performs no authentication and answers queries with canned data.
    struct MockSession {}
366
367    impl SessionManager for MockSessionManager {
368        type Session = MockSession;
369
370        fn create_dummy_session(
371            &self,
372            _database_id: u32,
373            _user_name: u32,
374        ) -> Result<Arc<Self::Session>, BoxedError> {
375            unimplemented!()
376        }
377
378        fn connect(
379            &self,
380            _database: &str,
381            _user_name: &str,
382            _peer_addr: crate::net::AddressRef,
383        ) -> Result<Arc<Self::Session>, Box<dyn Error + Send + Sync>> {
384            Ok(Arc::new(MockSession {}))
385        }
386
387        fn cancel_queries_in_session(&self, _session_id: SessionId) {
388            todo!()
389        }
390
391        fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
392            todo!()
393        }
394
395        fn end_session(&self, _session: &Self::Session) {}
396    }
397
398    impl Session for MockSession {
399        type Portal = String;
400        type PreparedStatement = String;
401        type ValuesStream = BoxStream<'static, RowSetResult>;
402
403        async fn run_one_query(
404            self: Arc<Self>,
405            _stmt: Statement,
406            _format: types::Format,
407        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, BoxedError> {
408            Ok(PgResponse::builder(StatementType::SELECT)
409                .values(
410                    futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
411                        .boxed(),
412                    vec![
413                        // 1043 is the oid of varchar type.
414                        // -1 is the type len of varchar type.
415                        PgFieldDescriptor::new("".to_owned(), 1043, -1);
416                        1
417                    ],
418                )
419                .into())
420        }
421
422        async fn parse(
423            self: Arc<Self>,
424            _sql: Option<Statement>,
425            _params_types: Vec<Option<DataType>>,
426        ) -> Result<String, BoxedError> {
427            Ok(String::new())
428        }
429
430        fn bind(
431            self: Arc<Self>,
432            _prepare_statement: String,
433            _params: Vec<Option<Bytes>>,
434            _param_formats: Vec<types::Format>,
435            _result_formats: Vec<types::Format>,
436        ) -> Result<String, BoxedError> {
437            Ok(String::new())
438        }
439
440        async fn execute(
441            self: Arc<Self>,
442            _portal: String,
443        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, BoxedError> {
444            Ok(PgResponse::builder(StatementType::SELECT)
445                .values(
446                    futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
447                        .boxed(),
448                    vec![
449                    // 1043 is the oid of varchar type.
450                    // -1 is the type len of varchar type.
451                    PgFieldDescriptor::new("".to_owned(), 1043, -1);
452                    1
453                ],
454                )
455                .into())
456        }
457
458        fn describe_statement(
459            self: Arc<Self>,
460            _statement: String,
461        ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError> {
462            Ok((
463                vec![],
464                vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)],
465            ))
466        }
467
468        fn describe_portal(
469            self: Arc<Self>,
470            _portal: String,
471        ) -> Result<Vec<PgFieldDescriptor>, BoxedError> {
472            Ok(vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)])
473        }
474
475        fn user_authenticator(&self) -> &UserAuthenticator {
476            &UserAuthenticator::None
477        }
478
479        fn id(&self) -> SessionId {
480            (0, 0)
481        }
482
483        fn set_config(&self, _key: &str, _value: String) -> Result<String, BoxedError> {
484            Ok("".to_owned())
485        }
486
487        async fn next_notice(self: &Arc<Self>) -> String {
488            std::future::pending().await
489        }
490
491        fn transaction_status(&self) -> TransactionStatus {
492            TransactionStatus::Idle
493        }
494
495        fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
496            let exec_context = Arc::new(ExecContext {
497                running_sql: sql,
498                last_instant: Instant::now(),
499                last_idle_instant: Default::default(),
500            });
501            ExecContextGuard::new(exec_context)
502        }
503
504        fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
505            Ok(())
506        }
507    }
508
509    async fn do_test_query(bind_addr: impl Into<String>, pg_config: impl Into<String>) {
510        let bind_addr = bind_addr.into();
511        let pg_config = pg_config.into();
512
513        let session_mgr = MockSessionManager {};
514        tokio::spawn(async move {
515            pg_serve(
516                &bind_addr,
517                socket2::TcpKeepalive::new(),
518                Arc::new(session_mgr),
519                ConnectionContext {
520                    tls_config: None,
521                    redact_sql_option_keywords: None,
522                    message_memory_manager: MessageMemoryManager::new(u64::MAX, u64::MAX, u64::MAX)
523                        .into(),
524                },
525                CancellationToken::new(), // dummy
526            )
527            .await
528        });
529        // wait for server to start
530        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
531
532        // Connect to the database.
533        let (client, connection) = tokio_postgres::connect(&pg_config, NoTls).await.unwrap();
534
535        // The connection object performs the actual communication with the database,
536        // so spawn it off to run on its own.
537        tokio::spawn(async move {
538            if let Err(e) = connection.await {
539                eprintln!("connection error: {}", e);
540            }
541        });
542
543        let rows = client
544            .simple_query("SELECT ''")
545            .await
546            .expect("Error executing query");
547        // Row + CommandComplete
548        assert_eq!(rows.len(), 2);
549
550        let rows = client
551            .query("SELECT ''", &[])
552            .await
553            .expect("Error executing query");
554        assert_eq!(rows.len(), 1);
555    }
556
557    #[tokio::test]
558    async fn test_query_tcp() {
559        do_test_query("127.0.0.1:10000", "host=localhost port=10000").await;
560    }
561
562    #[cfg(not(madsim))]
563    #[tokio::test]
564    async fn test_query_unix() {
565        let port: i16 = 10000;
566        let dir = tempfile::TempDir::new().unwrap();
567        let sock = dir.path().join(format!(".s.PGSQL.{port}"));
568
569        do_test_query(
570            format!("unix:{}", sock.to_str().unwrap()),
571            format!("host={} port={}", dir.path().to_str().unwrap(), port),
572        )
573        .await;
574    }
575}