1use std::collections::HashMap;
16use std::future::Future;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::time::Instant;
20
21use bytes::Bytes;
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
23use parking_lot::Mutex;
24use risingwave_common::types::DataType;
25use risingwave_common::util::runtime::BackgroundShutdownRuntime;
26use risingwave_common::util::tokio_util::sync::CancellationToken;
27use risingwave_jni_core::jvm_runtime::register_jvm_builder;
28use risingwave_sqlparser::ast::Statement;
29use serde::Deserialize;
30use thiserror_ext::AsReport;
31
32use crate::error::{PsqlError, PsqlResult};
33use crate::net::{AddressRef, Listener, TcpKeepalive};
34use crate::pg_field_descriptor::PgFieldDescriptor;
35use crate::pg_message::TransactionStatus;
36use crate::pg_protocol::{ConnectionContext, PgByteStream, PgProtocol};
37use crate::pg_response::{PgResponse, ValuesStream};
38use crate::types::Format;
39
/// Type-erased, thread-safe error used across the session traits.
pub type BoxedError = Box<dyn std::error::Error + Sync + Send>;
/// Process id half of a [`SessionId`].
type ProcessId = i32;
/// Secret key half of a [`SessionId`].
type SecretKey = i32;
/// Identifies a session as a `(process id, secret key)` pair.
pub type SessionId = (ProcessId, SecretKey);
44
45pub trait SessionManager: Send + Sync + 'static {
48 type Session: Session;
49
50 fn create_dummy_session(
53 &self,
54 database_id: u32,
55 user_id: u32,
56 ) -> Result<Arc<Self::Session>, BoxedError>;
57
58 fn connect(
59 &self,
60 database: &str,
61 user_name: &str,
62 peer_addr: AddressRef,
63 ) -> Result<Arc<Self::Session>, BoxedError>;
64
65 fn cancel_queries_in_session(&self, session_id: SessionId);
66
67 fn cancel_creating_jobs_in_session(&self, session_id: SessionId);
68
69 fn end_session(&self, session: &Self::Session);
70
71 fn shutdown(&self) -> impl Future<Output = ()> + Send {
73 async {}
74 }
75}
76
77pub trait Session: Send + Sync {
80 type ValuesStream: ValuesStream;
81 type PreparedStatement: Send + Clone + 'static;
82 type Portal: Send + Clone + std::fmt::Display + 'static;
83
84 fn run_one_query(
87 self: Arc<Self>,
88 stmt: Statement,
89 format: Format,
90 ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;
91
92 fn parse(
93 self: Arc<Self>,
94 sql: Option<Statement>,
95 params_types: Vec<Option<DataType>>,
96 ) -> impl Future<Output = Result<Self::PreparedStatement, BoxedError>> + Send;
97
98 fn next_notice(self: &Arc<Self>) -> impl Future<Output = String> + Send;
102
103 fn bind(
104 self: Arc<Self>,
105 prepare_statement: Self::PreparedStatement,
106 params: Vec<Option<Bytes>>,
107 param_formats: Vec<Format>,
108 result_formats: Vec<Format>,
109 ) -> Result<Self::Portal, BoxedError>;
110
111 fn execute(
112 self: Arc<Self>,
113 portal: Self::Portal,
114 ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;
115
116 fn describe_statement(
117 self: Arc<Self>,
118 prepare_statement: Self::PreparedStatement,
119 ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError>;
120
121 fn describe_portal(
122 self: Arc<Self>,
123 portal: Self::Portal,
124 ) -> Result<Vec<PgFieldDescriptor>, BoxedError>;
125
126 fn user_authenticator(&self) -> &UserAuthenticator;
127
128 fn id(&self) -> SessionId;
129
130 fn get_config(&self, key: &str) -> Result<String, BoxedError>;
131
132 fn set_config(&self, key: &str, value: String) -> Result<String, BoxedError>;
133
134 fn transaction_status(&self) -> TransactionStatus;
135
136 fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard;
137
138 fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()>;
139}
140
141pub struct ExecContext {
144 pub running_sql: Arc<str>,
145 pub last_instant: Instant,
147 pub last_idle_instant: Arc<Mutex<Option<Instant>>>,
149}
150
151pub struct ExecContextGuard(#[allow(dead_code)] Arc<ExecContext>);
154
155impl ExecContextGuard {
156 pub fn new(exec_context: Arc<ExecContext>) -> Self {
157 Self(exec_context)
158 }
159}
160
161impl Drop for ExecContext {
162 fn drop(&mut self) {
163 *self.last_idle_instant.lock() = Some(Instant::now());
164 }
165}
166
/// How a session's user credentials are checked; see `UserAuthenticator::authenticate`.
#[derive(Debug, Clone)]
pub enum UserAuthenticator {
    /// No check: every password is accepted.
    None,
    /// The password must equal these bytes exactly.
    ClearText(Vec<u8>),
    /// The password must equal `encrypted_password` exactly.
    ///
    /// NOTE(review): `salt` is not consulted during the comparison. Presumably it is the
    /// 4-byte salt sent in the MD5 challenge, and `encrypted_password` is already salted.
    Md5WithSalt {
        encrypted_password: Vec<u8>,
        salt: [u8; 4],
    },
    /// The password is a JWT. The map must hold `jwks_url` and `issuer`; every other entry
    /// must appear as an identical string claim in the token.
    OAuth(HashMap<String, String>),
}
180
181#[derive(Debug, Deserialize)]
185struct Jwks {
186 keys: Vec<Jwk>,
187}
188
189#[derive(Debug, Deserialize)]
192struct Jwk {
193 kid: String, alg: String, n: String, e: String, }
198
199async fn validate_jwt(
200 jwt: &str,
201 jwks_url: &str,
202 issuer: &str,
203 metadata: &HashMap<String, String>,
204) -> Result<bool, BoxedError> {
205 let header = decode_header(jwt)?;
206 let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;
207
208 let kid = header.kid.ok_or("kid not found in jwt header")?;
210 let jwk = jwks
211 .keys
212 .into_iter()
213 .find(|k| k.kid == kid)
214 .ok_or("kid not found in jwks")?;
215
216 if Algorithm::from_str(&jwk.alg)? != header.alg {
218 return Err("alg in jwt header does not match with alg in jwk".into());
219 }
220
221 let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
223 let mut validation = Validation::new(header.alg);
224 validation.set_issuer(&[issuer]);
225 validation.set_required_spec_claims(&["exp", "iss"]);
226 let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;
227
228 if !metadata.iter().all(
230 |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
231 ) {
232 return Err("metadata in jwt does not match with metadata declared with user".into());
233 }
234 Ok(true)
235}
236
237impl UserAuthenticator {
238 pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> {
239 let success = match self {
240 UserAuthenticator::None => true,
241 UserAuthenticator::ClearText(text) => password == text,
242 UserAuthenticator::Md5WithSalt {
243 encrypted_password, ..
244 } => encrypted_password == password,
245 UserAuthenticator::OAuth(metadata) => {
246 let mut metadata = metadata.clone();
247 let jwks_url = metadata.remove("jwks_url").unwrap();
248 let issuer = metadata.remove("issuer").unwrap();
249 validate_jwt(
250 &String::from_utf8_lossy(password),
251 &jwks_url,
252 &issuer,
253 &metadata,
254 )
255 .await
256 .map_err(PsqlError::StartupError)?
257 }
258 };
259 if !success {
260 return Err(PsqlError::PasswordError);
261 }
262 Ok(())
263 }
264}
265
266pub async fn pg_serve(
270 addr: &str,
271 tcp_keepalive: TcpKeepalive,
272 session_mgr: Arc<impl SessionManager>,
273 context: ConnectionContext,
274 shutdown: CancellationToken,
275) -> Result<(), BoxedError> {
276 let listener = Listener::bind(addr).await?;
277 tracing::info!(addr, "server started");
278 register_jvm_builder();
279 let acceptor_runtime = BackgroundShutdownRuntime::from({
280 let mut builder = tokio::runtime::Builder::new_multi_thread();
281 builder.worker_threads(1);
282 builder
283 .thread_name("rw-acceptor")
284 .enable_all()
285 .build()
286 .unwrap()
287 });
288
289 #[cfg(not(madsim))]
290 let worker_runtime = tokio::runtime::Handle::current();
291 #[cfg(madsim)]
292 let worker_runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();
293 let session_mgr_clone = session_mgr.clone();
294 let f = async move {
295 loop {
296 let conn_ret = listener.accept(&tcp_keepalive).await;
297 match conn_ret {
298 Ok((stream, peer_addr)) => {
299 tracing::info!(%peer_addr, "accept connection");
300 worker_runtime.spawn(handle_connection(
301 stream,
302 session_mgr_clone.clone(),
303 Arc::new(peer_addr),
304 context.clone(),
305 ));
306 }
307
308 Err(e) => {
309 tracing::error!(error = %e.as_report(), "failed to accept connection",);
310 }
311 }
312 }
313 };
314 acceptor_runtime.spawn(f);
315
316 shutdown.cancelled().await;
318
319 drop(acceptor_runtime);
321 session_mgr.shutdown().await;
323
324 Ok(())
325}
326
327pub async fn handle_connection<S, SM>(
328 stream: S,
329 session_mgr: Arc<SM>,
330 peer_addr: AddressRef,
331 context: ConnectionContext,
332) where
333 S: PgByteStream,
334 SM: SessionManager,
335{
336 PgProtocol::new(stream, session_mgr, peer_addr, context)
337 .run()
338 .await;
339}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use bytes::Bytes;
    use futures::StreamExt;
    use futures::stream::BoxStream;
    use risingwave_common::types::DataType;
    use risingwave_common::util::tokio_util::sync::CancellationToken;
    use risingwave_sqlparser::ast::Statement;
    use tokio_postgres::NoTls;

    use crate::error::PsqlResult;
    use crate::memory_manager::MessageMemoryManager;
    use crate::pg_field_descriptor::PgFieldDescriptor;
    use crate::pg_message::TransactionStatus;
    use crate::pg_protocol::ConnectionContext;
    use crate::pg_response::{PgResponse, RowSetResult, StatementType};
    use crate::pg_server::{
        BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
        UserAuthenticator, pg_serve,
    };
    use crate::types;
    use crate::types::Row;

    struct MockSessionManager {}
    struct MockSession {}

    /// A single unnamed column of PostgreSQL type `varchar` (OID 1043).
    fn mock_field() -> PgFieldDescriptor {
        PgFieldDescriptor::new("".to_owned(), 1043, -1)
    }

    /// A `SELECT` response with one row holding one empty value.
    fn mock_select_response() -> PgResponse<BoxStream<'static, RowSetResult>> {
        PgResponse::builder(StatementType::SELECT)
            .values(
                futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
                    .boxed(),
                vec![mock_field()],
            )
            .into()
    }

    impl SessionManager for MockSessionManager {
        type Session = MockSession;

        fn create_dummy_session(
            &self,
            _database_id: u32,
            _user_id: u32,
        ) -> Result<Arc<Self::Session>, BoxedError> {
            unimplemented!()
        }

        fn connect(
            &self,
            _database: &str,
            _user_name: &str,
            _peer_addr: crate::net::AddressRef,
        ) -> Result<Arc<Self::Session>, BoxedError> {
            Ok(Arc::new(MockSession {}))
        }

        fn cancel_queries_in_session(&self, _session_id: SessionId) {
            todo!()
        }

        fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
            todo!()
        }

        fn end_session(&self, _session: &Self::Session) {}
    }

    impl Session for MockSession {
        type Portal = String;
        type PreparedStatement = String;
        type ValuesStream = BoxStream<'static, RowSetResult>;

        async fn run_one_query(
            self: Arc<Self>,
            _stmt: Statement,
            _format: types::Format,
        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, BoxedError> {
            Ok(mock_select_response())
        }

        async fn parse(
            self: Arc<Self>,
            _sql: Option<Statement>,
            _params_types: Vec<Option<DataType>>,
        ) -> Result<String, BoxedError> {
            Ok(String::new())
        }

        fn bind(
            self: Arc<Self>,
            _prepare_statement: String,
            _params: Vec<Option<Bytes>>,
            _param_formats: Vec<types::Format>,
            _result_formats: Vec<types::Format>,
        ) -> Result<String, BoxedError> {
            Ok(String::new())
        }

        async fn execute(
            self: Arc<Self>,
            _portal: String,
        ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, BoxedError> {
            Ok(mock_select_response())
        }

        fn describe_statement(
            self: Arc<Self>,
            _statement: String,
        ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError> {
            Ok((vec![], vec![mock_field()]))
        }

        fn describe_portal(
            self: Arc<Self>,
            _portal: String,
        ) -> Result<Vec<PgFieldDescriptor>, BoxedError> {
            Ok(vec![mock_field()])
        }

        fn user_authenticator(&self) -> &UserAuthenticator {
            &UserAuthenticator::None
        }

        fn id(&self) -> SessionId {
            (0, 0)
        }

        fn get_config(&self, key: &str) -> Result<String, BoxedError> {
            match key {
                "timezone" => Ok("UTC".to_owned()),
                _ => Err(format!("Unknown config key: {key}").into()),
            }
        }

        fn set_config(&self, _key: &str, _value: String) -> Result<String, BoxedError> {
            Ok("".to_owned())
        }

        async fn next_notice(self: &Arc<Self>) -> String {
            std::future::pending().await
        }

        fn transaction_status(&self) -> TransactionStatus {
            TransactionStatus::Idle
        }

        fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
            let exec_context = Arc::new(ExecContext {
                running_sql: sql,
                last_instant: Instant::now(),
                last_idle_instant: Default::default(),
            });
            ExecContextGuard::new(exec_context)
        }

        fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
            Ok(())
        }
    }

    /// Starts a mock server on `bind_addr`, connects with `pg_config`, and runs a `SELECT ''`
    /// through both the simple and the extended query protocol.
    async fn do_test_query(bind_addr: impl Into<String>, pg_config: impl Into<String>) {
        /// Connection attempts made while waiting for the server to start (100 ms apart).
        const MAX_CONNECT_ATTEMPTS: u32 = 50;

        let bind_addr = bind_addr.into();
        let pg_config = pg_config.into();

        let session_mgr = MockSessionManager {};
        tokio::spawn(async move {
            pg_serve(
                &bind_addr,
                socket2::TcpKeepalive::new(),
                Arc::new(session_mgr),
                ConnectionContext {
                    tls_config: None,
                    redact_sql_option_keywords: None,
                    message_memory_manager: MessageMemoryManager::new(u64::MAX, u64::MAX, u64::MAX)
                        .into(),
                },
                CancellationToken::new(),
            )
            .await
        });

        // The server binds asynchronously. Poll until it accepts connections instead of
        // relying on one fixed delay, which is flaky on slow or loaded machines.
        let mut attempt = 0;
        let (client, connection) = loop {
            attempt += 1;
            match tokio_postgres::connect(&pg_config, NoTls).await {
                Ok(conn) => break conn,
                Err(e) if attempt >= MAX_CONNECT_ATTEMPTS => {
                    panic!("failed to connect to the test server after {attempt} attempts: {e}")
                }
                Err(_) => tokio::time::sleep(Duration::from_millis(100)).await,
            }
        };

        // The connection object drives the socket; run it in the background.
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                eprintln!("connection error: {e}");
            }
        });

        // The simple protocol yields the data row plus the command-complete message.
        let rows = client
            .simple_query("SELECT ''")
            .await
            .expect("Error executing query");
        assert_eq!(rows.len(), 2);

        let rows = client
            .query("SELECT ''", &[])
            .await
            .expect("Error executing query");
        assert_eq!(rows.len(), 1);
    }

    #[tokio::test]
    async fn test_query_tcp() {
        do_test_query("127.0.0.1:10000", "host=localhost port=10000").await;
    }

    #[cfg(not(madsim))]
    #[tokio::test]
    async fn test_query_unix() {
        let port: u16 = 10000;
        let dir = tempfile::TempDir::new().unwrap();
        let sock = dir.path().join(format!(".s.PGSQL.{port}"));

        do_test_query(
            format!("unix:{}", sock.to_str().unwrap()),
            format!("host={} port={}", dir.path().to_str().unwrap(), port),
        )
        .await;
    }
}