1use std::collections::HashMap;
16use std::future::Future;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::time::Instant;
20
21use bytes::Bytes;
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
23use parking_lot::Mutex;
24use risingwave_common::types::DataType;
25use risingwave_common::util::runtime::BackgroundShutdownRuntime;
26use risingwave_common::util::tokio_util::sync::CancellationToken;
27use risingwave_sqlparser::ast::Statement;
28use serde::Deserialize;
29use thiserror_ext::AsReport;
30
31use crate::error::{PsqlError, PsqlResult};
32use crate::net::{AddressRef, Listener, TcpKeepalive};
33use crate::pg_field_descriptor::PgFieldDescriptor;
34use crate::pg_message::TransactionStatus;
35use crate::pg_protocol::{ConnectionContext, PgByteStream, PgProtocol};
36use crate::pg_response::{PgResponse, ValuesStream};
37use crate::types::Format;
38
/// Boxed, thread-safe error object used as the generic error type of this module's APIs.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;
/// First component of a [`SessionId`].
type ProcessId = i32;
/// Second component of a [`SessionId`].
type SecretKey = i32;
/// Identifier of a session: a `(ProcessId, SecretKey)` pair.
pub type SessionId = (ProcessId, SecretKey);
43
/// Creates, cancels and tears down [`Session`]s on behalf of the server.
///
/// The manager must be `Send + Sync + 'static`: `pg_serve` hands a clone of the same
/// `Arc<impl SessionManager>` to every connection task it spawns.
pub trait SessionManager: Send + Sync + 'static {
    /// Concrete session type produced by this manager.
    type Session: Session;

    /// Creates a session from a database id and a user id, without any peer address.
    ///
    /// NOTE(review): presumably meant for sessions that do not originate from a client
    /// connection — confirm against the implementors.
    fn create_dummy_session(
        &self,
        database_id: u32,
        user_id: u32,
    ) -> Result<Arc<Self::Session>, BoxedError>;

    /// Creates a session for `user_name` on `database`, for the client at `peer_addr`.
    ///
    /// Returns an error if the session cannot be established.
    fn connect(
        &self,
        database: &str,
        user_name: &str,
        peer_addr: AddressRef,
    ) -> Result<Arc<Self::Session>, BoxedError>;

    /// Requests cancellation of the queries of the session identified by `session_id`.
    fn cancel_queries_in_session(&self, session_id: SessionId);

    /// Requests cancellation of the creating jobs of the session identified by `session_id`.
    fn cancel_creating_jobs_in_session(&self, session_id: SessionId);

    /// Notifies the manager that `session` has ended.
    fn end_session(&self, session: &Self::Session);

    /// Shuts the manager down.
    ///
    /// `pg_serve` awaits this once its shutdown token has been cancelled. The default
    /// implementation returns a future that completes immediately and does nothing.
    fn shutdown(&self) -> impl Future<Output = ()> + Send {
        async {}
    }
}
75
/// A single client session: executes statements and exposes per-session state.
///
/// Most operations take `self: Arc<Self>`, so a session is always used behind an `Arc`.
pub trait Session: Send + Sync {
    /// Stream type carried by the [`PgResponse`]s this session produces.
    type ValuesStream: ValuesStream;
    /// Handle returned by [`Session::parse`] and consumed by [`Session::bind`].
    type PreparedStatement: Send + Clone + 'static;
    /// Handle returned by [`Session::bind`] and consumed by [`Session::execute`].
    type Portal: Send + Clone + std::fmt::Display + 'static;

    /// Runs one already-parsed statement and returns its response.
    ///
    /// `format` is the result format requested for the response.
    fn run_one_query(
        self: Arc<Self>,
        stmt: Statement,
        format: Format,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;

    /// Turns an optional statement plus (possibly unknown) parameter types into a
    /// prepared statement.
    ///
    /// NOTE(review): `sql == None` presumably stands for an empty query string — confirm.
    fn parse(
        self: Arc<Self>,
        sql: Option<Statement>,
        params_types: Vec<Option<DataType>>,
    ) -> impl Future<Output = Result<Self::PreparedStatement, BoxedError>> + Send;

    /// Returns a future resolving to the text of the next notice for this session.
    fn next_notice(self: &Arc<Self>) -> impl Future<Output = String> + Send;

    /// Binds parameter values and formats to a prepared statement, producing a portal.
    ///
    /// A `None` entry in `params` is a parameter without a value.
    fn bind(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
        params: Vec<Option<Bytes>>,
        param_formats: Vec<Format>,
        result_formats: Vec<Format>,
    ) -> Result<Self::Portal, BoxedError>;

    /// Executes a bound portal and returns its response.
    fn execute(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;

    /// Describes a prepared statement: its parameter types and its result fields.
    fn describe_statement(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
    ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError>;

    /// Describes the result fields of a portal.
    fn describe_portal(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> Result<Vec<PgFieldDescriptor>, BoxedError>;

    /// Returns the authenticator used to verify this session's user.
    fn user_authenticator(&self) -> &UserAuthenticator;

    /// Returns this session's identifier.
    fn id(&self) -> SessionId;

    /// Reads the session configuration entry named `key`.
    fn get_config(&self, key: &str) -> Result<String, BoxedError>;

    /// Sets the session configuration entry named `key` to `value`.
    ///
    /// NOTE(review): the returned `String` is presumably the value now in effect — confirm.
    fn set_config(&self, key: &str, value: String) -> Result<String, BoxedError>;

    /// Returns the session's current transaction status.
    fn transaction_status(&self) -> TransactionStatus;

    /// Starts tracking the execution of `sql` and returns a guard for that execution.
    fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard;

    /// Returns an error if the idle-in-transaction timeout check fails for this session.
    fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()>;
}
139
/// State describing one statement execution.
///
/// Dropping an `ExecContext` stores the current time into `last_idle_instant`
/// (see its `Drop` implementation).
pub struct ExecContext {
    /// Text of the SQL associated with this execution.
    pub running_sql: Arc<str>,
    /// NOTE(review): presumably the instant at which the execution started — confirm
    /// against the code that constructs this struct.
    pub last_instant: Instant,
    /// Shared slot that receives `Some(Instant::now())` when this context is dropped.
    pub last_idle_instant: Arc<Mutex<Option<Instant>>>,
}
149
/// Guard that keeps an [`ExecContext`] alive for as long as the guard is held.
///
/// The wrapped `Arc` is never read (hence the `dead_code` allowance); holding it is the
/// whole point, because the context's `Drop` runs only once the last reference is gone.
pub struct ExecContextGuard(#[allow(dead_code)] Arc<ExecContext>);
153
154impl ExecContextGuard {
155 pub fn new(exec_context: Arc<ExecContext>) -> Self {
156 Self(exec_context)
157 }
158}
159
160impl Drop for ExecContext {
161 fn drop(&mut self) {
162 *self.last_idle_instant.lock() = Some(Instant::now());
163 }
164}
165
/// How the credentials presented by a connecting user are verified
/// (see `UserAuthenticator::authenticate`).
#[derive(Debug, Clone)]
pub enum UserAuthenticator {
    /// No verification: authentication always succeeds.
    None,
    /// The presented bytes must equal the stored bytes.
    ClearText(Vec<u8>),
    /// The presented bytes must equal `encrypted_password`.
    Md5WithSalt {
        /// Value the presented bytes are compared against.
        encrypted_password: Vec<u8>,
        /// NOTE(review): not read by `authenticate`; presumably the salt sent to the client
        /// during the handshake — confirm against the protocol code.
        salt: [u8; 4],
    },
    /// The presented bytes are validated as a JWT. The map is expected to hold `jwks_url`
    /// and `issuer`; every other entry must match a string claim of the token.
    OAuth(HashMap<String, String>),
}
179
/// JSON Web Key Set document as deserialized from the JWKS endpoint by `validate_jwt`.
#[derive(Debug, Deserialize)]
struct Jwks {
    /// Keys contained in the set; `validate_jwt` searches them by `kid`.
    keys: Vec<Jwk>,
}
187
/// One key of a [`Jwks`] document, limited to the fields `validate_jwt` reads.
#[derive(Debug, Deserialize)]
struct Jwk {
    /// Key id; compared with the `kid` of the JWT header.
    kid: String,
    /// Algorithm name; parsed into an `Algorithm` and compared with the JWT header's `alg`.
    alg: String,
    /// First RSA component handed to `DecodingKey::from_rsa_components` (the modulus).
    n: String,
    /// Second RSA component handed to `DecodingKey::from_rsa_components` (the exponent).
    e: String,
}
197
198async fn validate_jwt(
199 jwt: &str,
200 jwks_url: &str,
201 issuer: &str,
202 metadata: &HashMap<String, String>,
203) -> Result<bool, BoxedError> {
204 let header = decode_header(jwt)?;
205 let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;
206
207 let kid = header.kid.ok_or("kid not found in jwt header")?;
209 let jwk = jwks
210 .keys
211 .into_iter()
212 .find(|k| k.kid == kid)
213 .ok_or("kid not found in jwks")?;
214
215 if Algorithm::from_str(&jwk.alg)? != header.alg {
217 return Err("alg in jwt header does not match with alg in jwk".into());
218 }
219
220 let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
222 let mut validation = Validation::new(header.alg);
223 validation.set_issuer(&[issuer]);
224 validation.set_required_spec_claims(&["exp", "iss"]);
225 let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;
226
227 if !metadata.iter().all(
229 |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
230 ) {
231 return Err("metadata in jwt does not match with metadata declared with user".into());
232 }
233 Ok(true)
234}
235
236impl UserAuthenticator {
237 pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> {
238 let success = match self {
239 UserAuthenticator::None => true,
240 UserAuthenticator::ClearText(text) => password == text,
241 UserAuthenticator::Md5WithSalt {
242 encrypted_password, ..
243 } => encrypted_password == password,
244 UserAuthenticator::OAuth(metadata) => {
245 let mut metadata = metadata.clone();
246 let jwks_url = metadata.remove("jwks_url").unwrap();
247 let issuer = metadata.remove("issuer").unwrap();
248 validate_jwt(
249 &String::from_utf8_lossy(password),
250 &jwks_url,
251 &issuer,
252 &metadata,
253 )
254 .await
255 .map_err(PsqlError::StartupError)?
256 }
257 };
258 if !success {
259 return Err(PsqlError::PasswordError);
260 }
261 Ok(())
262 }
263}
264
265pub async fn pg_serve(
269 addr: &str,
270 tcp_keepalive: TcpKeepalive,
271 session_mgr: Arc<impl SessionManager>,
272 context: ConnectionContext,
273 shutdown: CancellationToken,
274) -> Result<(), BoxedError> {
275 let listener = Listener::bind(addr).await?;
276 tracing::info!(addr, "server started");
277
278 let acceptor_runtime = BackgroundShutdownRuntime::from({
279 let mut builder = tokio::runtime::Builder::new_multi_thread();
280 builder.worker_threads(1);
281 builder
282 .thread_name("rw-acceptor")
283 .enable_all()
284 .build()
285 .unwrap()
286 });
287
288 #[cfg(not(madsim))]
289 let worker_runtime = tokio::runtime::Handle::current();
290 #[cfg(madsim)]
291 let worker_runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();
292 let session_mgr_clone = session_mgr.clone();
293 let f = async move {
294 loop {
295 let conn_ret = listener.accept(&tcp_keepalive).await;
296 match conn_ret {
297 Ok((stream, peer_addr)) => {
298 tracing::info!(%peer_addr, "accept connection");
299 worker_runtime.spawn(handle_connection(
300 stream,
301 session_mgr_clone.clone(),
302 Arc::new(peer_addr),
303 context.clone(),
304 ));
305 }
306
307 Err(e) => {
308 tracing::error!(error = %e.as_report(), "failed to accept connection",);
309 }
310 }
311 }
312 };
313 acceptor_runtime.spawn(f);
314
315 shutdown.cancelled().await;
317
318 drop(acceptor_runtime);
320 session_mgr.shutdown().await;
322
323 Ok(())
324}
325
326pub async fn handle_connection<S, SM>(
327 stream: S,
328 session_mgr: Arc<SM>,
329 peer_addr: AddressRef,
330 context: ConnectionContext,
331) where
332 S: PgByteStream,
333 SM: SessionManager,
334{
335 PgProtocol::new(stream, session_mgr, peer_addr, context)
336 .run()
337 .await;
338}
#[cfg(test)]
mod tests {
    use std::error::Error;
    use std::sync::Arc;
    use std::time::Instant;

    use bytes::Bytes;
    use futures::StreamExt;
    use futures::stream::BoxStream;
    use risingwave_common::types::DataType;
    use risingwave_common::util::tokio_util::sync::CancellationToken;
    use risingwave_sqlparser::ast::Statement;
    use tokio_postgres::NoTls;

    use crate::error::PsqlResult;
    use crate::memory_manager::MessageMemoryManager;
    use crate::pg_field_descriptor::PgFieldDescriptor;
    use crate::pg_message::TransactionStatus;
    use crate::pg_protocol::ConnectionContext;
    use crate::pg_response::{PgResponse, RowSetResult, StatementType};
    use crate::pg_server::{
        BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
        UserAuthenticator, pg_serve,
    };
    use crate::types;
    use crate::types::Row;

    /// Session manager whose `connect` always succeeds with a fresh [`MockSession`].
    struct MockSessionManager {}
    /// Session that answers every query with one fixed single-column row.
    struct MockSession {}

    /// Field descriptor used by every mock result: an unnamed column with type OID 1043
    /// (`varchar` in PostgreSQL) and type length -1.
    fn mock_field() -> PgFieldDescriptor {
        PgFieldDescriptor::new("".to_owned(), 1043, -1)
    }

    /// Builds the canned `SELECT` response: one row holding one empty value.
    fn mock_select_response() -> PgResponse<BoxStream<'static, RowSetResult>> {
        let rows: BoxStream<'static, RowSetResult> =
            futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])]).boxed();
        PgResponse::builder(StatementType::SELECT)
            .values(rows, vec![mock_field()])
            .into()
    }

    impl SessionManager for MockSessionManager {
        type Session = MockSession;

        /// Not needed by these tests.
        fn create_dummy_session(
            &self,
            _database_id: u32,
            _user_id: u32,
        ) -> Result<Arc<Self::Session>, BoxedError> {
            unimplemented!()
        }

        /// Accepts any database, user and peer.
        fn connect(
            &self,
            _database: &str,
            _user_name: &str,
            _peer_addr: crate::net::AddressRef,
        ) -> Result<Arc<Self::Session>, Box<dyn Error + Send + Sync>> {
            let session = MockSession {};
            Ok(Arc::new(session))
        }

        fn cancel_queries_in_session(&self, _session_id: SessionId) {
            todo!()
        }

        fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
            todo!()
        }

        fn end_session(&self, _session: &Self::Session) {}
    }

    impl Session for MockSession {
        type Portal = String;
        type PreparedStatement = String;
        type ValuesStream = BoxStream<'static, RowSetResult>;

        /// Ignores the statement and returns the canned response.
        async fn run_one_query(
            self: Arc<Self>,
            _stmt: Statement,
            _format: types::Format,
        ) -> Result<PgResponse<Self::ValuesStream>, BoxedError> {
            Ok(mock_select_response())
        }

        async fn parse(
            self: Arc<Self>,
            _sql: Option<Statement>,
            _params_types: Vec<Option<DataType>>,
        ) -> Result<String, BoxedError> {
            Ok(String::new())
        }

        fn bind(
            self: Arc<Self>,
            _prepare_statement: String,
            _params: Vec<Option<Bytes>>,
            _param_formats: Vec<types::Format>,
            _result_formats: Vec<types::Format>,
        ) -> Result<String, BoxedError> {
            Ok(String::new())
        }

        /// Ignores the portal and returns the canned response.
        async fn execute(
            self: Arc<Self>,
            _portal: String,
        ) -> Result<PgResponse<Self::ValuesStream>, BoxedError> {
            Ok(mock_select_response())
        }

        fn describe_statement(
            self: Arc<Self>,
            _statement: String,
        ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError> {
            let param_types = vec![];
            let fields = vec![mock_field()];
            Ok((param_types, fields))
        }

        fn describe_portal(
            self: Arc<Self>,
            _portal: String,
        ) -> Result<Vec<PgFieldDescriptor>, BoxedError> {
            Ok(vec![mock_field()])
        }

        fn user_authenticator(&self) -> &UserAuthenticator {
            &UserAuthenticator::None
        }

        fn id(&self) -> SessionId {
            (0, 0)
        }

        /// Only `timezone` is known; every other key is an error.
        fn get_config(&self, key: &str) -> Result<String, BoxedError> {
            if key == "timezone" {
                Ok("UTC".to_owned())
            } else {
                Err(format!("Unknown config key: {key}").into())
            }
        }

        fn set_config(&self, _key: &str, _value: String) -> Result<String, BoxedError> {
            Ok(String::new())
        }

        /// Never produces a notice.
        async fn next_notice(self: &Arc<Self>) -> String {
            std::future::pending().await
        }

        fn transaction_status(&self) -> TransactionStatus {
            TransactionStatus::Idle
        }

        fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
            ExecContextGuard::new(Arc::new(ExecContext {
                running_sql: sql,
                last_instant: Instant::now(),
                last_idle_instant: Default::default(),
            }))
        }

        fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
            Ok(())
        }
    }

    /// Starts a mock server on `bind_addr`, connects with `pg_config` and runs one simple
    /// query and one extended query against it.
    async fn do_test_query(bind_addr: impl Into<String>, pg_config: impl Into<String>) {
        let bind_addr = bind_addr.into();
        let pg_config = pg_config.into();

        let session_mgr = MockSessionManager {};
        tokio::spawn(async move {
            let context = ConnectionContext {
                tls_config: None,
                redact_sql_option_keywords: None,
                message_memory_manager: MessageMemoryManager::new(u64::MAX, u64::MAX, u64::MAX)
                    .into(),
            };
            // The token is never cancelled; the server lives as long as the test runtime.
            pg_serve(
                &bind_addr,
                socket2::TcpKeepalive::new(),
                Arc::new(session_mgr),
                context,
                CancellationToken::new(),
            )
            .await
        });
        // Wait a moment before connecting so the spawned server task has a chance to bind.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        let (client, connection) = tokio_postgres::connect(&pg_config, NoTls).await.unwrap();

        // The connection object must be polled for the client to make progress.
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                eprintln!("connection error: {}", e);
            }
        });

        // Simple query protocol.
        // NOTE(review): the 2 messages are presumably one row plus command-complete — confirm
        // against the tokio_postgres version in use.
        let simple_messages = client
            .simple_query("SELECT ''")
            .await
            .expect("Error executing query");
        assert_eq!(simple_messages.len(), 2);

        // Extended query protocol: exactly the one mock row.
        let extended_rows = client
            .query("SELECT ''", &[])
            .await
            .expect("Error executing query");
        assert_eq!(extended_rows.len(), 1);
    }

    #[tokio::test]
    async fn test_query_tcp() {
        do_test_query("127.0.0.1:10000", "host=localhost port=10000").await;
    }

    #[cfg(not(madsim))]
    #[tokio::test]
    async fn test_query_unix() {
        let port: i16 = 10000;
        let dir = tempfile::TempDir::new().unwrap();
        // The socket file is named `.s.PGSQL.<port>` inside the temporary directory, and the
        // client is pointed at that directory as its host.
        let sock = dir.path().join(format!(".s.PGSQL.{port}"));

        let bind_addr = format!("unix:{}", sock.to_str().unwrap());
        let pg_config = format!("host={} port={}", dir.path().to_str().unwrap(), port);
        do_test_query(bind_addr, pg_config).await;
    }
}