1use std::collections::HashMap;
16use std::future::Future;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::time::Instant;
20
21use bytes::Bytes;
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
23use parking_lot::Mutex;
24use risingwave_common::types::DataType;
25use risingwave_common::util::runtime::BackgroundShutdownRuntime;
26use risingwave_common::util::tokio_util::sync::CancellationToken;
27use risingwave_sqlparser::ast::Statement;
28use serde::Deserialize;
29use thiserror_ext::AsReport;
30
31use crate::error::{PsqlError, PsqlResult};
32use crate::net::{AddressRef, Listener, TcpKeepalive};
33use crate::pg_field_descriptor::PgFieldDescriptor;
34use crate::pg_message::TransactionStatus;
35use crate::pg_protocol::{ConnectionContext, PgByteStream, PgProtocol};
36use crate::pg_response::{PgResponse, ValuesStream};
37use crate::types::Format;
38
/// Type-erased error that can be sent and shared across threads; used as the
/// error type of the server-facing traits in this module.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;
/// First component of a `SessionId`.
type ProcessId = i32;
/// Second component of a `SessionId`.
type SecretKey = i32;
/// Identifies a session as a `(ProcessId, SecretKey)` pair.
// NOTE(review): presumably mirrors the PostgreSQL `BackendKeyData` pair used
// for cancel requests — confirm against the protocol implementation.
pub type SessionId = (ProcessId, SecretKey);
43
/// Server-side entry point for creating and managing [`Session`]s.
///
/// Implementations must be `Send + Sync + 'static` because the manager is
/// shared (via `Arc`) between the accept loop and every connection task.
pub trait SessionManager: Send + Sync + 'static {
    /// Concrete session type handed out by this manager.
    type Session: Session;

    /// Creates a session for the given database id and user id without a peer
    /// connection.
    // NOTE(review): the exact purpose of a "dummy" session is not visible in
    // this file — confirm with the implementors.
    fn create_dummy_session(
        &self,
        database_id: u32,
        user_id: u32,
    ) -> Result<Arc<Self::Session>, BoxedError>;

    /// Creates a session for a client at `peer_addr` connecting to `database`
    /// as `user_name`.
    fn connect(
        &self,
        database: &str,
        user_name: &str,
        peer_addr: AddressRef,
    ) -> Result<Arc<Self::Session>, BoxedError>;

    /// Requests cancellation of queries belonging to the session identified by
    /// `session_id`.
    fn cancel_queries_in_session(&self, session_id: SessionId);

    /// Requests cancellation of in-progress creating jobs belonging to the
    /// session identified by `session_id`.
    fn cancel_creating_jobs_in_session(&self, session_id: SessionId);

    /// Notifies the manager that `session` has ended.
    fn end_session(&self, session: &Self::Session);

    /// Hook awaited by [`pg_serve`] after the acceptor has been stopped.
    ///
    /// The default implementation does nothing.
    fn shutdown(&self) -> impl Future<Output = ()> + Send {
        async {}
    }
}
75
/// A single client session as seen by the protocol layer.
///
/// Most methods take `self: Arc<Self>` so that the returned futures can own a
/// handle to the session.
pub trait Session: Send + Sync {
    /// Stream of result rows carried by a [`PgResponse`].
    type ValuesStream: ValuesStream;
    /// Handle produced by [`Session::parse`] and consumed by [`Session::bind`].
    type PreparedStatement: Send + Clone + 'static;
    /// Handle produced by [`Session::bind`] and consumed by [`Session::execute`].
    type Portal: Send + Clone + std::fmt::Display + 'static;

    /// Runs one already-parsed statement and returns its response, with
    /// results encoded in `format`.
    fn run_one_query(
        self: Arc<Self>,
        stmt: Statement,
        format: Format,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;

    /// Prepares a statement with the given (possibly unspecified) parameter
    /// types. `sql` is `None` when there is no statement to prepare.
    fn parse(
        self: Arc<Self>,
        sql: Option<Statement>,
        params_types: Vec<Option<DataType>>,
    ) -> impl Future<Output = Result<Self::PreparedStatement, BoxedError>> + Send;

    /// Resolves to the next notice message for this session.
    // NOTE(review): cancel-safety expectations are not visible here — confirm
    // with the protocol loop that polls this future.
    fn next_notice(self: &Arc<Self>) -> impl Future<Output = String> + Send;

    /// Binds parameter values and formats to a prepared statement, yielding a
    /// portal that can be executed or described.
    fn bind(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
        params: Vec<Option<Bytes>>,
        param_formats: Vec<Format>,
        result_formats: Vec<Format>,
    ) -> Result<Self::Portal, BoxedError>;

    /// Executes a bound portal and returns its response.
    fn execute(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;

    /// Describes a prepared statement as `(parameter types, result columns)`.
    fn describe_statement(
        self: Arc<Self>,
        prepare_statement: Self::PreparedStatement,
    ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError>;

    /// Describes the result columns of a bound portal.
    fn describe_portal(
        self: Arc<Self>,
        portal: Self::Portal,
    ) -> Result<Vec<PgFieldDescriptor>, BoxedError>;

    /// Returns the authenticator used to check credentials for this session.
    fn user_authenticator(&self) -> &UserAuthenticator;

    /// Returns this session's identifier.
    fn id(&self) -> SessionId;

    /// Sets the session configuration entry `key` to `value`.
    // NOTE(review): the meaning of the returned `String` is not visible here
    // (the mock returns an empty string) — confirm with implementors.
    fn set_config(&self, key: &str, value: String) -> Result<String, BoxedError>;

    /// Returns the session's current transaction status.
    fn transaction_status(&self) -> TransactionStatus;

    /// Creates the execution context for `sql` and returns a guard for it.
    fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard;

    /// Checks the idle-in-transaction timeout, returning an error when the
    /// check fails.
    fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()>;
}
137
/// Bookkeeping for one statement execution.
///
/// When the value is dropped, its `Drop` impl stores the current time into
/// `last_idle_instant`.
pub struct ExecContext {
    /// SQL text associated with this execution.
    pub running_sql: Arc<str>,
    /// An instant recorded for this execution.
    // NOTE(review): presumably the time execution started (the test mock sets
    // it to `Instant::now()` at creation) — confirm with real implementors.
    pub last_instant: Instant,
    /// Shared slot that receives `Some(Instant::now())` when this context is
    /// dropped; `None` until then unless another holder writes to it.
    pub last_idle_instant: Arc<Mutex<Option<Instant>>>,
}
147
148pub struct ExecContextGuard(#[allow(dead_code)] Arc<ExecContext>);
151
152impl ExecContextGuard {
153 pub fn new(exec_context: Arc<ExecContext>) -> Self {
154 Self(exec_context)
155 }
156}
157
158impl Drop for ExecContext {
159 fn drop(&mut self) {
160 *self.last_idle_instant.lock() = Some(Instant::now());
161 }
162}
163
/// How a user's credentials are verified by [`UserAuthenticator::authenticate`].
#[derive(Debug, Clone)]
pub enum UserAuthenticator {
    /// No authentication: every attempt succeeds.
    None,
    /// The supplied password bytes must equal the stored bytes.
    ClearText(Vec<u8>),
    /// The supplied bytes must equal `encrypted_password`.
    Md5WithSalt {
        /// Expected value compared byte-for-byte against the supplied password.
        encrypted_password: Vec<u8>,
        /// Salt associated with this challenge; not read by `authenticate`.
        salt: [u8; 4],
    },
    /// OAuth: the supplied password is treated as a JWT.
    ///
    /// The map's `jwks_url` and `issuer` entries configure validation; every
    /// other entry must match a string claim of the same name in the token.
    OAuth(HashMap<String, String>),
}
177
/// JSON Web Key Set as deserialized from the JWKS endpoint response.
#[derive(Debug, Deserialize)]
struct Jwks {
    /// Keys published by the endpoint; searched by `kid` in `validate_jwt`.
    keys: Vec<Jwk>,
}
185
/// A single key of a [`Jwks`]. All four fields are required for
/// deserialization to succeed.
#[derive(Debug, Deserialize)]
struct Jwk {
    /// Key id, matched against the `kid` in the JWT header.
    kid: String,
    /// Algorithm name, parsed with `Algorithm::from_str` and compared with the
    /// algorithm in the JWT header.
    alg: String,
    /// RSA modulus component, passed to `DecodingKey::from_rsa_components`.
    n: String,
    /// RSA exponent component, passed to `DecodingKey::from_rsa_components`.
    e: String,
}
195
196async fn validate_jwt(
197 jwt: &str,
198 jwks_url: &str,
199 issuer: &str,
200 metadata: &HashMap<String, String>,
201) -> Result<bool, BoxedError> {
202 let header = decode_header(jwt)?;
203 let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;
204
205 let kid = header.kid.ok_or("kid not found in jwt header")?;
207 let jwk = jwks
208 .keys
209 .into_iter()
210 .find(|k| k.kid == kid)
211 .ok_or("kid not found in jwks")?;
212
213 if Algorithm::from_str(&jwk.alg)? != header.alg {
215 return Err("alg in jwt header does not match with alg in jwk".into());
216 }
217
218 let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
220 let mut validation = Validation::new(header.alg);
221 validation.set_issuer(&[issuer]);
222 validation.set_required_spec_claims(&["exp", "iss"]);
223 let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;
224
225 if !metadata.iter().all(
227 |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
228 ) {
229 return Err("metadata in jwt does not match with metadata declared with user".into());
230 }
231 Ok(true)
232}
233
234impl UserAuthenticator {
235 pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> {
236 let success = match self {
237 UserAuthenticator::None => true,
238 UserAuthenticator::ClearText(text) => password == text,
239 UserAuthenticator::Md5WithSalt {
240 encrypted_password, ..
241 } => encrypted_password == password,
242 UserAuthenticator::OAuth(metadata) => {
243 let mut metadata = metadata.clone();
244 let jwks_url = metadata.remove("jwks_url").unwrap();
245 let issuer = metadata.remove("issuer").unwrap();
246 validate_jwt(
247 &String::from_utf8_lossy(password),
248 &jwks_url,
249 &issuer,
250 &metadata,
251 )
252 .await
253 .map_err(PsqlError::StartupError)?
254 }
255 };
256 if !success {
257 return Err(PsqlError::PasswordError);
258 }
259 Ok(())
260 }
261}
262
/// Binds a listener on `addr` and serves connections until `shutdown` is
/// cancelled.
///
/// Accepting runs on a dedicated single-worker runtime (threads named
/// `rw-acceptor`); each accepted connection is handled by
/// [`handle_connection`] spawned on the worker runtime. After `shutdown`
/// fires, the acceptor runtime is dropped and `session_mgr.shutdown()` is
/// awaited before returning `Ok(())`.
///
/// # Errors
///
/// Returns an error if binding the listener fails.
///
/// # Panics
///
/// Panics if a Tokio runtime cannot be built.
pub async fn pg_serve(
    addr: &str,
    tcp_keepalive: TcpKeepalive,
    session_mgr: Arc<impl SessionManager>,
    context: ConnectionContext,
    shutdown: CancellationToken,
) -> Result<(), BoxedError> {
    let listener = Listener::bind(addr).await?;
    tracing::info!(addr, "server started");

    // Dedicated runtime for the accept loop, wrapped so that it can be dropped
    // from within this async function.
    // NOTE(review): presumably `BackgroundShutdownRuntime` shuts the runtime
    // down in the background on drop — confirm in `risingwave_common`.
    let acceptor_runtime = BackgroundShutdownRuntime::from({
        let mut builder = tokio::runtime::Builder::new_multi_thread();
        builder.worker_threads(1);
        builder
            .thread_name("rw-acceptor")
            .enable_all()
            .build()
            .unwrap()
    });

    // Connections are handled on the caller's runtime in normal builds; under
    // `madsim` a fresh multi-thread runtime is built instead.
    #[cfg(not(madsim))]
    let worker_runtime = tokio::runtime::Handle::current();
    #[cfg(madsim)]
    let worker_runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();
    let session_mgr_clone = session_mgr.clone();
    let f = async move {
        loop {
            let conn_ret = listener.accept(&tcp_keepalive).await;
            match conn_ret {
                Ok((stream, peer_addr)) => {
                    tracing::info!(%peer_addr, "accept connection");
                    // One task per connection; the task is detached.
                    worker_runtime.spawn(handle_connection(
                        stream,
                        session_mgr_clone.clone(),
                        Arc::new(peer_addr),
                        context.clone(),
                    ));
                }

                // Accept failures are logged and the loop keeps going.
                // NOTE(review): there is no back-off here, so a persistent
                // accept error would be retried immediately — confirm this is
                // acceptable.
                Err(e) => {
                    tracing::error!(error = %e.as_report(), "failed to accept connection",);
                }
            }
        }
    };
    acceptor_runtime.spawn(f);

    // Wait for the shutdown signal.
    shutdown.cancelled().await;

    // Stop accepting new connections by dropping the acceptor runtime.
    drop(acceptor_runtime);
    // Then let the session manager perform its own shutdown work.
    session_mgr.shutdown().await;

    Ok(())
}
323
/// Serves a single client connection.
///
/// Builds a [`PgProtocol`] over `stream` for the peer at `peer_addr` and runs
/// it to completion; the future resolves when `run` returns.
pub async fn handle_connection<S, SM>(
    stream: S,
    session_mgr: Arc<SM>,
    peer_addr: AddressRef,
    context: ConnectionContext,
) where
    S: PgByteStream,
    SM: SessionManager,
{
    PgProtocol::new(stream, session_mgr, peer_addr, context)
        .run()
        .await;
}
337#[cfg(test)]
338mod tests {
339 use std::error::Error;
340 use std::sync::Arc;
341 use std::time::Instant;
342
343 use bytes::Bytes;
344 use futures::StreamExt;
345 use futures::stream::BoxStream;
346 use risingwave_common::types::DataType;
347 use risingwave_common::util::tokio_util::sync::CancellationToken;
348 use risingwave_sqlparser::ast::Statement;
349 use tokio_postgres::NoTls;
350
351 use crate::error::PsqlResult;
352 use crate::memory_manager::MessageMemoryManager;
353 use crate::pg_field_descriptor::PgFieldDescriptor;
354 use crate::pg_message::TransactionStatus;
355 use crate::pg_protocol::ConnectionContext;
356 use crate::pg_response::{PgResponse, RowSetResult, StatementType};
357 use crate::pg_server::{
358 BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
359 UserAuthenticator, pg_serve,
360 };
361 use crate::types;
362 use crate::types::Row;
363
/// Stateless `SessionManager` stub used by the tests in this module.
struct MockSessionManager {}
/// Stateless `Session` stub used by the tests in this module.
struct MockSession {}
366
367 impl SessionManager for MockSessionManager {
368 type Session = MockSession;
369
370 fn create_dummy_session(
371 &self,
372 _database_id: u32,
373 _user_name: u32,
374 ) -> Result<Arc<Self::Session>, BoxedError> {
375 unimplemented!()
376 }
377
378 fn connect(
379 &self,
380 _database: &str,
381 _user_name: &str,
382 _peer_addr: crate::net::AddressRef,
383 ) -> Result<Arc<Self::Session>, Box<dyn Error + Send + Sync>> {
384 Ok(Arc::new(MockSession {}))
385 }
386
387 fn cancel_queries_in_session(&self, _session_id: SessionId) {
388 todo!()
389 }
390
391 fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
392 todo!()
393 }
394
395 fn end_session(&self, _session: &Self::Session) {}
396 }
397
398 impl Session for MockSession {
399 type Portal = String;
400 type PreparedStatement = String;
401 type ValuesStream = BoxStream<'static, RowSetResult>;
402
403 async fn run_one_query(
404 self: Arc<Self>,
405 _stmt: Statement,
406 _format: types::Format,
407 ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, BoxedError> {
408 Ok(PgResponse::builder(StatementType::SELECT)
409 .values(
410 futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
411 .boxed(),
412 vec![
413 PgFieldDescriptor::new("".to_owned(), 1043, -1);
416 1
417 ],
418 )
419 .into())
420 }
421
422 async fn parse(
423 self: Arc<Self>,
424 _sql: Option<Statement>,
425 _params_types: Vec<Option<DataType>>,
426 ) -> Result<String, BoxedError> {
427 Ok(String::new())
428 }
429
430 fn bind(
431 self: Arc<Self>,
432 _prepare_statement: String,
433 _params: Vec<Option<Bytes>>,
434 _param_formats: Vec<types::Format>,
435 _result_formats: Vec<types::Format>,
436 ) -> Result<String, BoxedError> {
437 Ok(String::new())
438 }
439
440 async fn execute(
441 self: Arc<Self>,
442 _portal: String,
443 ) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, BoxedError> {
444 Ok(PgResponse::builder(StatementType::SELECT)
445 .values(
446 futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
447 .boxed(),
448 vec![
449 PgFieldDescriptor::new("".to_owned(), 1043, -1);
452 1
453 ],
454 )
455 .into())
456 }
457
458 fn describe_statement(
459 self: Arc<Self>,
460 _statement: String,
461 ) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError> {
462 Ok((
463 vec![],
464 vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)],
465 ))
466 }
467
468 fn describe_portal(
469 self: Arc<Self>,
470 _portal: String,
471 ) -> Result<Vec<PgFieldDescriptor>, BoxedError> {
472 Ok(vec![PgFieldDescriptor::new("".to_owned(), 1043, -1)])
473 }
474
475 fn user_authenticator(&self) -> &UserAuthenticator {
476 &UserAuthenticator::None
477 }
478
479 fn id(&self) -> SessionId {
480 (0, 0)
481 }
482
483 fn set_config(&self, _key: &str, _value: String) -> Result<String, BoxedError> {
484 Ok("".to_owned())
485 }
486
487 async fn next_notice(self: &Arc<Self>) -> String {
488 std::future::pending().await
489 }
490
491 fn transaction_status(&self) -> TransactionStatus {
492 TransactionStatus::Idle
493 }
494
495 fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
496 let exec_context = Arc::new(ExecContext {
497 running_sql: sql,
498 last_instant: Instant::now(),
499 last_idle_instant: Default::default(),
500 });
501 ExecContextGuard::new(exec_context)
502 }
503
504 fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
505 Ok(())
506 }
507 }
508
509 async fn do_test_query(bind_addr: impl Into<String>, pg_config: impl Into<String>) {
510 let bind_addr = bind_addr.into();
511 let pg_config = pg_config.into();
512
513 let session_mgr = MockSessionManager {};
514 tokio::spawn(async move {
515 pg_serve(
516 &bind_addr,
517 socket2::TcpKeepalive::new(),
518 Arc::new(session_mgr),
519 ConnectionContext {
520 tls_config: None,
521 redact_sql_option_keywords: None,
522 message_memory_manager: MessageMemoryManager::new(u64::MAX, u64::MAX, u64::MAX)
523 .into(),
524 },
525 CancellationToken::new(), )
527 .await
528 });
529 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
531
532 let (client, connection) = tokio_postgres::connect(&pg_config, NoTls).await.unwrap();
534
535 tokio::spawn(async move {
538 if let Err(e) = connection.await {
539 eprintln!("connection error: {}", e);
540 }
541 });
542
543 let rows = client
544 .simple_query("SELECT ''")
545 .await
546 .expect("Error executing query");
547 assert_eq!(rows.len(), 2);
549
550 let rows = client
551 .query("SELECT ''", &[])
552 .await
553 .expect("Error executing query");
554 assert_eq!(rows.len(), 1);
555 }
556
/// Runs the query round-trip over TCP on `127.0.0.1:10000`.
#[tokio::test]
async fn test_query_tcp() {
    do_test_query("127.0.0.1:10000", "host=localhost port=10000").await;
}
561
562 #[cfg(not(madsim))]
563 #[tokio::test]
564 async fn test_query_unix() {
565 let port: i16 = 10000;
566 let dir = tempfile::TempDir::new().unwrap();
567 let sock = dir.path().join(format!(".s.PGSQL.{port}"));
568
569 do_test_query(
570 format!("unix:{}", sock.to_str().unwrap()),
571 format!("host={} port={}", dir.path().to_str().unwrap(), port),
572 )
573 .await;
574 }
575}