pgwire/
pg_server.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::future::Future;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::time::Instant;
20
21use bytes::Bytes;
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
23use parking_lot::Mutex;
24use risingwave_common::types::DataType;
25use risingwave_common::util::runtime::BackgroundShutdownRuntime;
26use risingwave_common::util::tokio_util::sync::CancellationToken;
27use risingwave_sqlparser::ast::Statement;
28use serde::Deserialize;
29use thiserror_ext::AsReport;
30
31use crate::error::{PsqlError, PsqlResult};
32use crate::net::{AddressRef, Listener, TcpKeepalive};
33use crate::pg_field_descriptor::PgFieldDescriptor;
34use crate::pg_message::TransactionStatus;
35use crate::pg_protocol::{ConnectionContext, PgByteStream, PgProtocol};
36use crate::pg_response::{PgResponse, ValuesStream};
37use crate::types::Format;
38
/// Boxed, thread-safe error type returned by the pgwire session interfaces.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync + 'static>;
/// Process-ID half of a [`SessionId`].
type ProcessId = i32;
/// Secret-key half of a [`SessionId`].
type SecretKey = i32;
/// Identifies a session as a `(process id, secret key)` pair.
pub type SessionId = (ProcessId, SecretKey);
43
44/// The interface for a database system behind pgwire protocol.
45/// We can mock it for testing purpose.
46pub trait SessionManager: Send + Sync + 'static {
47    type Session: Session;
48
49    /// In the process of auto schema change, we need a dummy session to access
50    /// catalog information in frontend and build a replace plan for the table.
51    fn create_dummy_session(
52        &self,
53        database_id: u32,
54        user_id: u32,
55    ) -> Result<Arc<Self::Session>, BoxedError>;
56
57    fn connect(
58        &self,
59        database: &str,
60        user_name: &str,
61        peer_addr: AddressRef,
62    ) -> Result<Arc<Self::Session>, BoxedError>;
63
64    fn cancel_queries_in_session(&self, session_id: SessionId);
65
66    fn cancel_creating_jobs_in_session(&self, session_id: SessionId);
67
68    fn end_session(&self, session: &Self::Session);
69
70    /// Run some cleanup tasks before the server shutdown.
71    fn shutdown(&self) -> impl Future<Output = ()> + Send {
72        async {}
73    }
74}
75
76/// A psql connection. Each connection binds with a database. Switching database will need to
77/// recreate another connection.
78pub trait Session: Send + Sync {
79    type ValuesStream: ValuesStream;
80    type PreparedStatement: Send + Clone + 'static;
81    type Portal: Send + Clone + std::fmt::Display + 'static;
82
83    /// The str sql can not use the unparse from AST: There is some problem when dealing with create
84    /// view, see <https://github.com/risingwavelabs/risingwave/issues/6801>.
85    fn run_one_query(
86        self: Arc<Self>,
87        stmt: Statement,
88        format: Format,
89    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;
90
91    fn parse(
92        self: Arc<Self>,
93        sql: Option<Statement>,
94        params_types: Vec<Option<DataType>>,
95    ) -> impl Future<Output = Result<Self::PreparedStatement, BoxedError>> + Send;
96
97    /// Receive the next notice message to send to the client.
98    ///
99    /// This function should be cancellation-safe.
100    fn next_notice(self: &Arc<Self>) -> impl Future<Output = String> + Send;
101
102    fn bind(
103        self: Arc<Self>,
104        prepare_statement: Self::PreparedStatement,
105        params: Vec<Option<Bytes>>,
106        param_formats: Vec<Format>,
107        result_formats: Vec<Format>,
108    ) -> Result<Self::Portal, BoxedError>;
109
110    fn execute(
111        self: Arc<Self>,
112        portal: Self::Portal,
113    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;
114
115    fn describe_statement(
116        self: Arc<Self>,
117        prepare_statement: Self::PreparedStatement,
118    ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError>;
119
120    fn describe_portal(
121        self: Arc<Self>,
122        portal: Self::Portal,
123    ) -> Result<Vec<PgFieldDescriptor>, BoxedError>;
124
125    fn user_authenticator(&self) -> &UserAuthenticator;
126
127    fn id(&self) -> SessionId;
128
129    fn get_config(&self, key: &str) -> Result<String, BoxedError>;
130
131    fn set_config(&self, key: &str, value: String) -> Result<String, BoxedError>;
132
133    fn transaction_status(&self) -> TransactionStatus;
134
135    fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard;
136
137    fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()>;
138}
139
140/// Each session could run different SQLs multiple times.
141/// `ExecContext` represents the lifetime of a running SQL in the current session.
142pub struct ExecContext {
143    pub running_sql: Arc<str>,
144    /// The instant of the running sql
145    pub last_instant: Instant,
146    /// A reference used to update when `ExecContext` is dropped
147    pub last_idle_instant: Arc<Mutex<Option<Instant>>>,
148}
149
150/// `ExecContextGuard` holds a `Arc` pointer. Once `ExecContextGuard` is dropped,
151/// the inner `Arc<ExecContext>` should not be referred anymore, so that its `Weak` reference (used in `SessionImpl`) will be the same lifecycle of the running sql execution context.
152pub struct ExecContextGuard(#[allow(dead_code)] Arc<ExecContext>);
153
154impl ExecContextGuard {
155    pub fn new(exec_context: Arc<ExecContext>) -> Self {
156        Self(exec_context)
157    }
158}
159
160impl Drop for ExecContext {
161    fn drop(&mut self) {
162        *self.last_idle_instant.lock() = Some(Instant::now());
163    }
164}
165
/// How a session checks the credential presented by a client.
#[derive(Clone, Debug)]
pub enum UserAuthenticator {
    /// No authentication: any password is accepted.
    None,
    /// Stored password in clear text, compared byte-for-byte.
    ClearText(Vec<u8>),
    /// Password hashed with a random salt.
    Md5WithSalt {
        /// Expected encrypted password, compared byte-for-byte with the client's input.
        encrypted_password: Vec<u8>,
        /// 4-byte salt used for the encryption.
        salt: [u8; 4],
    },
    /// JWT-based authentication configured through key/value metadata.
    ///
    /// The map must provide `jwks_url` and `issuer`; every other entry is a claim that the
    /// token must carry with exactly the given string value.
    OAuth(HashMap<String, String>),
}
179
180/// A JWK Set is a JSON object that represents a set of JWKs.
181/// The JSON object MUST have a "keys" member, with its value being an array of JWKs.
182/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-5> for more details.
183#[derive(Debug, Deserialize)]
184struct Jwks {
185    keys: Vec<Jwk>,
186}
187
188/// A JSON Web Key (JWK) is a JSON object that represents a cryptographic key.
189/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-4> for more details.
190#[derive(Debug, Deserialize)]
191struct Jwk {
192    kid: String, // Key ID
193    alg: String, // Algorithm
194    n: String,   // Modulus
195    e: String,   // Exponent
196}
197
198async fn validate_jwt(
199    jwt: &str,
200    jwks_url: &str,
201    issuer: &str,
202    metadata: &HashMap<String, String>,
203) -> Result<bool, BoxedError> {
204    let header = decode_header(jwt)?;
205    let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;
206
207    // 1. Retrieve the kid from the header to find the right JWK in the JWK Set.
208    let kid = header.kid.ok_or("kid not found in jwt header")?;
209    let jwk = jwks
210        .keys
211        .into_iter()
212        .find(|k| k.kid == kid)
213        .ok_or("kid not found in jwks")?;
214
215    // 2. Check if the algorithms are matched.
216    if Algorithm::from_str(&jwk.alg)? != header.alg {
217        return Err("alg in jwt header does not match with alg in jwk".into());
218    }
219
220    // 3. Decode the JWT and validate the claims.
221    let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
222    let mut validation = Validation::new(header.alg);
223    validation.set_issuer(&[issuer]);
224    validation.set_required_spec_claims(&["exp", "iss"]);
225    let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;
226
227    // 4. Check if the metadata in the token matches.
228    if !metadata.iter().all(
229        |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
230    ) {
231        return Err("metadata in jwt does not match with metadata declared with user".into());
232    }
233    Ok(true)
234}
235
236impl UserAuthenticator {
237    pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> {
238        let success = match self {
239            UserAuthenticator::None => true,
240            UserAuthenticator::ClearText(text) => password == text,
241            UserAuthenticator::Md5WithSalt {
242                encrypted_password, ..
243            } => encrypted_password == password,
244            UserAuthenticator::OAuth(metadata) => {
245                let mut metadata = metadata.clone();
246                let jwks_url = metadata.remove("jwks_url").unwrap();
247                let issuer = metadata.remove("issuer").unwrap();
248                validate_jwt(
249                    &String::from_utf8_lossy(password),
250                    &jwks_url,
251                    &issuer,
252                    &metadata,
253                )
254                .await
255                .map_err(PsqlError::StartupError)?
256            }
257        };
258        if !success {
259            return Err(PsqlError::PasswordError);
260        }
261        Ok(())
262    }
263}
264
265/// Binds a Tcp or Unix listener at `addr`. Spawn a coroutine to serve every new connection.
266///
267/// Returns when the `shutdown` token is triggered.
268pub async fn pg_serve(
269    addr: &str,
270    tcp_keepalive: TcpKeepalive,
271    session_mgr: Arc<impl SessionManager>,
272    context: ConnectionContext,
273    shutdown: CancellationToken,
274) -> Result<(), BoxedError> {
275    let listener = Listener::bind(addr).await?;
276    tracing::info!(addr, "server started");
277
278    let acceptor_runtime = BackgroundShutdownRuntime::from({
279        let mut builder = tokio::runtime::Builder::new_multi_thread();
280        builder.worker_threads(1);
281        builder
282            .thread_name("rw-acceptor")
283            .enable_all()
284            .build()
285            .unwrap()
286    });
287
288    #[cfg(not(madsim))]
289    let worker_runtime = tokio::runtime::Handle::current();
290    #[cfg(madsim)]
291    let worker_runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();
292    let session_mgr_clone = session_mgr.clone();
293    let f = async move {
294        loop {
295            let conn_ret = listener.accept(&tcp_keepalive).await;
296            match conn_ret {
297                Ok((stream, peer_addr)) => {
298                    tracing::info!(%peer_addr, "accept connection");
299                    worker_runtime.spawn(handle_connection(
300                        stream,
301                        session_mgr_clone.clone(),
302                        Arc::new(peer_addr),
303                        context.clone(),
304                    ));
305                }
306
307                Err(e) => {
308                    tracing::error!(error = %e.as_report(), "failed to accept connection",);
309                }
310            }
311        }
312    };
313    acceptor_runtime.spawn(f);
314
315    // Wait for the shutdown signal.
316    shutdown.cancelled().await;
317
318    // Stop accepting new connections.
319    drop(acceptor_runtime);
320    // Shutdown session manager, typically close all existing sessions.
321    session_mgr.shutdown().await;
322
323    Ok(())
324}
325
326pub async fn handle_connection<S, SM>(
327    stream: S,
328    session_mgr: Arc<SM>,
329    peer_addr: AddressRef,
330    context: ConnectionContext,
331) where
332    S: PgByteStream,
333    SM: SessionManager,
334{
335    PgProtocol::new(stream, session_mgr, peer_addr, context)
336        .run()
337        .await;
338}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Instant;

    use bytes::Bytes;
    use futures::StreamExt;
    use futures::stream::BoxStream;
    use risingwave_common::types::DataType;
    use risingwave_common::util::tokio_util::sync::CancellationToken;
    use risingwave_sqlparser::ast::Statement;
    use tokio_postgres::NoTls;

    use crate::error::PsqlResult;
    use crate::memory_manager::MessageMemoryManager;
    use crate::pg_field_descriptor::PgFieldDescriptor;
    use crate::pg_message::TransactionStatus;
    use crate::pg_protocol::ConnectionContext;
    use crate::pg_response::{PgResponse, RowSetResult, StatementType};
    use crate::pg_server::{
        BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
        UserAuthenticator, pg_serve,
    };
    use crate::types;
    use crate::types::Row;

    struct MockSessionManager {}
    struct MockSession {}

    /// Descriptor of an unnamed `varchar` column.
    fn varchar_field() -> PgFieldDescriptor {
        // 1043 is the oid of varchar type.
        // -1 is the type len of varchar type.
        PgFieldDescriptor::new("".to_owned(), 1043, -1)
    }

    /// `SELECT` response with a single row holding one empty `varchar` value.
    fn single_empty_row_response() -> PgResponse<BoxStream<'static, RowSetResult>> {
        PgResponse::builder(StatementType::SELECT)
            .values(
                futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])]).boxed(),
                vec![varchar_field()],
            )
            .into()
    }

    impl SessionManager for MockSessionManager {
        type Session = MockSession;

        fn create_dummy_session(
            &self,
            _database_id: u32,
            _user_id: u32,
        ) -> Result<Arc<Self::Session>, BoxedError> {
            unimplemented!()
        }

        fn connect(
            &self,
            _database: &str,
            _user_name: &str,
            _peer_addr: crate::net::AddressRef,
        ) -> Result<Arc<Self::Session>, BoxedError> {
            Ok(Arc::new(MockSession {}))
        }

        fn cancel_queries_in_session(&self, _session_id: SessionId) {
            todo!()
        }

        fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
            todo!()
        }

        fn end_session(&self, _session: &Self::Session) {}
    }

    impl Session for MockSession {
        type Portal = String;
        type PreparedStatement = String;
        type ValuesStream = BoxStream<'static, RowSetResult>;

        async fn run_one_query(
            self: Arc<Self>,
            _stmt: Statement,
            _format: types::Format,
        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, BoxedError> {
            Ok(single_empty_row_response())
        }

        async fn parse(
            self: Arc<Self>,
            _sql: Option<Statement>,
            _params_types: Vec<Option<DataType>>,
        ) -> Result<String, BoxedError> {
            Ok(String::new())
        }

        fn bind(
            self: Arc<Self>,
            _prepare_statement: String,
            _params: Vec<Option<Bytes>>,
            _param_formats: Vec<types::Format>,
            _result_formats: Vec<types::Format>,
        ) -> Result<String, BoxedError> {
            Ok(String::new())
        }

        async fn execute(
            self: Arc<Self>,
            _portal: String,
        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, BoxedError> {
            Ok(single_empty_row_response())
        }

        fn describe_statement(
            self: Arc<Self>,
            _statement: String,
        ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError> {
            Ok((vec![], vec![varchar_field()]))
        }

        fn describe_portal(
            self: Arc<Self>,
            _portal: String,
        ) -> Result<Vec<PgFieldDescriptor>, BoxedError> {
            Ok(vec![varchar_field()])
        }

        fn user_authenticator(&self) -> &UserAuthenticator {
            &UserAuthenticator::None
        }

        fn id(&self) -> SessionId {
            (0, 0)
        }

        fn get_config(&self, key: &str) -> Result<String, BoxedError> {
            match key {
                "timezone" => Ok("UTC".to_owned()),
                _ => Err(format!("Unknown config key: {key}").into()),
            }
        }

        fn set_config(&self, _key: &str, _value: String) -> Result<String, BoxedError> {
            Ok("".to_owned())
        }

        async fn next_notice(self: &Arc<Self>) -> String {
            std::future::pending().await
        }

        fn transaction_status(&self) -> TransactionStatus {
            TransactionStatus::Idle
        }

        fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
            let exec_context = Arc::new(ExecContext {
                running_sql: sql,
                last_instant: Instant::now(),
                last_idle_instant: Default::default(),
            });
            ExecContextGuard::new(exec_context)
        }

        fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
            Ok(())
        }
    }

    async fn do_test_query(bind_addr: impl Into<String>, pg_config: impl Into<String>) {
        let bind_addr = bind_addr.into();
        let pg_config = pg_config.into();

        let session_mgr = MockSessionManager {};
        tokio::spawn(async move {
            pg_serve(
                &bind_addr,
                socket2::TcpKeepalive::new(),
                Arc::new(session_mgr),
                ConnectionContext {
                    tls_config: None,
                    redact_sql_option_keywords: None,
                    message_memory_manager: MessageMemoryManager::new(u64::MAX, u64::MAX, u64::MAX)
                        .into(),
                },
                CancellationToken::new(), // dummy
            )
            .await
        });
        // wait for server to start
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        // Connect to the database.
        let (client, connection) = tokio_postgres::connect(&pg_config, NoTls).await.unwrap();

        // The connection object performs the actual communication with the database,
        // so spawn it off to run on its own.
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                eprintln!("connection error: {e}");
            }
        });

        let rows = client
            .simple_query("SELECT ''")
            .await
            .expect("Error executing query");
        // Row + CommandComplete
        assert_eq!(rows.len(), 2);

        let rows = client
            .query("SELECT ''", &[])
            .await
            .expect("Error executing query");
        assert_eq!(rows.len(), 1);
    }

    #[tokio::test]
    async fn test_query_tcp() {
        do_test_query("127.0.0.1:10000", "host=localhost port=10000").await;
    }

    #[cfg(not(madsim))]
    #[tokio::test]
    async fn test_query_unix() {
        // Ports are unsigned 16-bit values.
        let port: u16 = 10000;
        let dir = tempfile::TempDir::new().unwrap();
        let sock = dir.path().join(format!(".s.PGSQL.{port}"));

        do_test_query(
            format!("unix:{}", sock.to_str().unwrap()),
            format!("host={} port={}", dir.path().to_str().unwrap(), port),
        )
        .await;
    }
}