use std::collections::HashMap;
use std::future::Future;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Instant;
use bytes::Bytes;
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
use parking_lot::Mutex;
use risingwave_common::types::DataType;
use risingwave_common::util::runtime::BackgroundShutdownRuntime;
use risingwave_common::util::tokio_util::sync::CancellationToken;
use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
use serde::Deserialize;
use thiserror_ext::AsReport;
use tokio::io::{AsyncRead, AsyncWrite};
use crate::error::{PsqlError, PsqlResult};
use crate::net::{AddressRef, Listener, TcpKeepalive};
use crate::pg_field_descriptor::PgFieldDescriptor;
use crate::pg_message::TransactionStatus;
use crate::pg_protocol::{PgProtocol, TlsConfig};
use crate::pg_response::{PgResponse, ValuesStream};
use crate::types::Format;
/// Boxed dynamic error that is `Send + Sync`; the error type used by the session traits and
/// by the free functions in this file.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;
/// First component of a [`SessionId`].
type ProcessId = i32;
/// Second component of a [`SessionId`].
type SecretKey = i32;
/// Identifier of a session, as a `(ProcessId, SecretKey)` pair.
///
/// NOTE(review): presumably mirrors the process id / secret key pair of the PostgreSQL
/// wire protocol's backend key data — confirm against the protocol code.
pub type SessionId = (ProcessId, SecretKey);
/// Creates sessions and exposes per-session lifecycle hooks.
///
/// `pg_serve` receives one `Arc` of the implementor and hands a clone of it to every accepted
/// connection, hence the `Send + Sync + 'static` bounds.
pub trait SessionManager: Send + Sync + 'static {
/// Concrete session type produced by this manager.
type Session: Session;
/// Creates a session from a database id and a user id.
///
/// NOTE(review): presumably a session that is not backed by a client connection — confirm
/// against the implementors.
fn create_dummy_session(
&self,
database_id: u32,
user_id: u32,
) -> Result<Arc<Self::Session>, BoxedError>;
/// Creates a session for `user_name` on `database`; `peer_addr` is the address of the
/// connecting client.
fn connect(
&self,
database: &str,
user_name: &str,
peer_addr: AddressRef,
) -> Result<Arc<Self::Session>, BoxedError>;
/// Requests cancellation of the queries belonging to the session identified by `session_id`.
/// The exact semantics are up to the implementor.
fn cancel_queries_in_session(&self, session_id: SessionId);
/// Requests cancellation of the creating jobs belonging to the session identified by
/// `session_id`. The exact semantics are up to the implementor.
fn cancel_creating_jobs_in_session(&self, session_id: SessionId);
/// Notifies the manager that `session` has ended.
fn end_session(&self, session: &Self::Session);
/// Shutdown hook awaited by `pg_serve` after the shutdown token has been cancelled and the
/// acceptor runtime has been dropped. The default implementation does nothing.
fn shutdown(&self) -> impl Future<Output = ()> + Send {
async {}
}
}
/// A single client session as seen by the protocol layer.
///
/// The statement-handling methods take `self: Arc<Self>`, so callers pass a cloned `Arc` of the
/// session for each call.
pub trait Session: Send + Sync {
/// Stream type carried by the [`PgResponse`] values this session returns.
type ValuesStream: ValuesStream;
/// Value produced by [`Session::parse`] and consumed by [`Session::bind`] and
/// [`Session::describe_statement`].
type PreparedStatement: Send + Clone + 'static;
/// Value produced by [`Session::bind`] and consumed by [`Session::execute`] and
/// [`Session::describe_portal`].
type Portal: Send + Clone + std::fmt::Display + 'static;
/// Runs one already-parsed statement and resolves to its response.
///
/// NOTE(review): `format` is presumably the encoding requested for the result values —
/// confirm against the caller.
fn run_one_query(
self: Arc<Self>,
stmt: Statement,
format: Format,
) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;
/// Prepares `sql` with the given parameter types and resolves to a prepared statement.
///
/// NOTE(review): a `None` statement presumably stands for an empty query, and a `None`
/// entry in `params_types` for an unspecified parameter type — confirm.
fn parse(
self: Arc<Self>,
sql: Option<Statement>,
params_types: Vec<Option<DataType>>,
) -> impl Future<Output = Result<Self::PreparedStatement, BoxedError>> + Send;
/// Returns the session's notices as strings.
///
/// NOTE(review): the name suggests pending notices are drained by this call — confirm.
fn take_notices(self: Arc<Self>) -> Vec<String>;
/// Binds parameter values and formats to a prepared statement, producing a portal.
fn bind(
self: Arc<Self>,
prepare_statement: Self::PreparedStatement,
params: Vec<Option<Bytes>>,
param_formats: Vec<Format>,
result_formats: Vec<Format>,
) -> Result<Self::Portal, BoxedError>;
/// Executes a portal and resolves to its response.
fn execute(
self: Arc<Self>,
portal: Self::Portal,
) -> impl Future<Output = Result<PgResponse<Self::ValuesStream>, BoxedError>> + Send;
/// Describes a prepared statement as a pair of data types and field descriptors.
///
/// NOTE(review): presumably the parameter types and the result columns — confirm.
fn describe_statement(
self: Arc<Self>,
prepare_statement: Self::PreparedStatement,
) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError>;
/// Describes a portal as a list of field descriptors.
fn describe_portal(
self: Arc<Self>,
portal: Self::Portal,
) -> Result<Vec<PgFieldDescriptor>, BoxedError>;
/// Returns the [`UserAuthenticator`] associated with this session.
fn user_authenticator(&self) -> &UserAuthenticator;
/// Returns this session's [`SessionId`].
fn id(&self) -> SessionId;
/// Sets the configuration entry `key` to `value` and returns a string on success.
///
/// NOTE(review): the returned string is presumably the resulting value — confirm.
fn set_config(&self, key: &str, value: String) -> Result<String, BoxedError>;
/// Returns the current [`TransactionStatus`] of the session.
fn transaction_status(&self) -> TransactionStatus;
/// Creates the execution context for `sql` and returns the guard that owns it.
fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard;
/// Returns an error when the idle-in-transaction check fails.
///
/// NOTE(review): presumably fails once the session has been idle inside a transaction for
/// longer than a configured timeout — confirm against the implementors.
fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()>;
}
/// State attached to one statement execution.
///
/// The `Drop` implementation of this type stores the drop time into `last_idle_instant`.
pub struct ExecContext {
/// The SQL text associated with this execution.
pub running_sql: Arc<str>,
/// An instant recorded by whoever constructs the context.
///
/// NOTE(review): presumably the time the execution started — confirm against the
/// implementors of `Session::init_exec_context`.
pub last_instant: Instant,
/// Shared slot that is overwritten with `Some(Instant::now())` when this context is dropped.
pub last_idle_instant: Arc<Mutex<Option<Instant>>>,
}
/// Owning handle to an [`ExecContext`].
///
/// The guard only keeps the `Arc` alive; the wrapped value is never read through the guard,
/// which is why the field carries `#[allow(dead_code)]`.
pub struct ExecContextGuard(#[allow(dead_code)] Arc<ExecContext>);

impl ExecContextGuard {
    /// Wraps `exec_context` in a guard.
    pub fn new(exec_context: Arc<ExecContext>) -> Self {
        ExecContextGuard(exec_context)
    }
}
/// Records the moment the context goes away as the latest idle instant.
impl Drop for ExecContext {
    fn drop(&mut self) {
        // Take the timestamp first, then lock the shared slot and overwrite it.
        let dropped_at = Instant::now();
        let mut last_idle = self.last_idle_instant.lock();
        *last_idle = Some(dropped_at);
    }
}
/// How the credentials presented by a client are checked; see `UserAuthenticator::authenticate`.
#[derive(Debug, Clone)]
pub enum UserAuthenticator {
/// No check: authentication always succeeds.
None,
/// The presented password must equal these bytes.
ClearText(Vec<u8>),
/// The presented password must equal `encrypted_password`.
Md5WithSalt {
/// Expected value of the presented password.
encrypted_password: Vec<u8>,
/// Salt bytes; `authenticate` itself does not read this field.
///
/// NOTE(review): presumably sent to the client by the protocol layer — confirm.
salt: [u8; 4],
},
/// The presented password is treated as a JWT. `authenticate` reads the `jwks_url` and
/// `issuer` entries of the map, and every remaining entry must match a string claim of the
/// token.
OAuth(HashMap<String, String>),
}
/// Deserialized form of the JSON document fetched from the JWKS URL in `validate_jwt`.
#[derive(Debug, Deserialize)]
struct Jwks {
/// The keys listed in the document.
keys: Vec<Jwk>,
}
/// One key of a [`Jwks`] document, as consumed by `validate_jwt`:
///
/// - `kid` is compared with the `kid` of the JWT header to select the key;
/// - `alg` is parsed as an `Algorithm` and must equal the algorithm of the JWT header;
/// - `n` and `e` are passed to `DecodingKey::from_rsa_components`.
#[derive(Debug, Deserialize)]
struct Jwk {
kid: String, alg: String, n: String, e: String, }
/// Validates `jwt` against the key set published at `jwks_url`.
///
/// Steps:
/// 1. decode the JWT header and require a `kid`;
/// 2. fetch the JWKS document from `jwks_url` and pick the key with the same `kid`;
/// 3. require the key's `alg` to equal the algorithm declared in the JWT header;
/// 4. verify the token with the key's RSA components, requiring the `exp` and `iss` claims
///    and the issuer `issuer`;
/// 5. require every entry of `metadata` to appear in the claims as a JSON string with the
///    same value.
///
/// Returns `Ok(true)` when every check passes; this function never returns `Ok(false)`.
///
/// # Errors
///
/// Returns an error if the header cannot be decoded or has no `kid`, if the JWKS document
/// cannot be fetched or parsed, if no key matches, if the algorithms differ, if verification
/// fails, or if a metadata entry does not match the claims.
async fn validate_jwt(
    jwt: &str,
    jwks_url: &str,
    issuer: &str,
    metadata: &HashMap<String, String>,
) -> Result<bool, BoxedError> {
    let header = decode_header(jwt)?;
    // Reject tokens without a key id before doing any network I/O: the token comes from an
    // unauthenticated client, so there is no point in fetching the JWKS for it.
    let kid = header.kid.ok_or("kid not found in jwt header")?;

    let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;
    let jwk = jwks
        .keys
        .into_iter()
        .find(|k| k.kid == kid)
        .ok_or("kid not found in jwks")?;

    // The algorithm announced by the token must be the one the key was published for.
    if Algorithm::from_str(&jwk.alg)? != header.alg {
        return Err("alg in jwt header does not match with alg in jwk".into());
    }

    let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
    let mut validation = Validation::new(header.alg);
    validation.set_issuer(&[issuer]);
    validation.set_required_spec_claims(&["exp", "iss"]);
    let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;

    // Only claims that are JSON strings equal to the expected value are accepted.
    if !metadata.iter().all(
        |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
    ) {
        return Err("metadata in jwt does not match with metadata declared with user".into());
    }
    Ok(true)
}
impl UserAuthenticator {
    /// Checks `password` against this authenticator.
    ///
    /// - `None`: always succeeds.
    /// - `ClearText`: `password` must equal the stored bytes.
    /// - `Md5WithSalt`: `password` must equal `encrypted_password`; the salt is not consulted
    ///   here.
    /// - `OAuth`: `password` is decoded (lossily) as UTF-8 and validated as a JWT by
    ///   `validate_jwt`, using the `jwks_url` and `issuer` entries of the metadata; all other
    ///   metadata entries are passed on as claims the token must carry.
    ///
    /// # Errors
    ///
    /// - [`PsqlError::PasswordError`] when the credentials do not match.
    /// - [`PsqlError::StartupError`] when the OAuth metadata lacks `jwks_url` or `issuer`, or
    ///   when JWT validation fails.
    pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> {
        // NOTE(review): the byte comparisons below use plain `==`; consider a constant-time
        // comparison if timing side channels are a concern.
        let success = match self {
            UserAuthenticator::None => true,
            UserAuthenticator::ClearText(text) => password == text,
            UserAuthenticator::Md5WithSalt {
                encrypted_password, ..
            } => encrypted_password == password,
            UserAuthenticator::OAuth(metadata) => {
                // Work on a copy: the two connection settings are removed so that only the
                // claims to be checked are left in the map.
                let mut metadata = metadata.clone();
                // Report incomplete metadata as an error instead of panicking the task that
                // serves the connection.
                let jwks_url = metadata.remove("jwks_url").ok_or_else(|| {
                    PsqlError::StartupError("jwks_url not found in oauth metadata".into())
                })?;
                let issuer = metadata.remove("issuer").ok_or_else(|| {
                    PsqlError::StartupError("issuer not found in oauth metadata".into())
                })?;
                validate_jwt(
                    &String::from_utf8_lossy(password),
                    &jwks_url,
                    &issuer,
                    &metadata,
                )
                .await
                .map_err(PsqlError::StartupError)?
            }
        };
        if !success {
            return Err(PsqlError::PasswordError);
        }
        Ok(())
    }
}
pub async fn pg_serve(
addr: &str,
tcp_keepalive: TcpKeepalive,
session_mgr: Arc<impl SessionManager>,
tls_config: Option<TlsConfig>,
redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
shutdown: CancellationToken,
) -> Result<(), BoxedError> {
let listener = Listener::bind(addr).await?;
tracing::info!(addr, "server started");
let acceptor_runtime = BackgroundShutdownRuntime::from({
let mut builder = tokio::runtime::Builder::new_multi_thread();
builder.worker_threads(1);
builder
.thread_name("rw-acceptor")
.enable_all()
.build()
.unwrap()
});
#[cfg(not(madsim))]
let worker_runtime = tokio::runtime::Handle::current();
#[cfg(madsim)]
let worker_runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();
let session_mgr_clone = session_mgr.clone();
let f = async move {
loop {
let conn_ret = listener.accept(&tcp_keepalive).await;
match conn_ret {
Ok((stream, peer_addr)) => {
tracing::info!(%peer_addr, "accept connection");
worker_runtime.spawn(handle_connection(
stream,
session_mgr_clone.clone(),
tls_config.clone(),
Arc::new(peer_addr),
redact_sql_option_keywords.clone(),
));
}
Err(e) => {
tracing::error!(error = %e.as_report(), "failed to accept connection",);
}
}
}
};
acceptor_runtime.spawn(f);
shutdown.cancelled().await;
drop(acceptor_runtime);
session_mgr.shutdown().await;
Ok(())
}
pub async fn handle_connection<S, SM>(
stream: S,
session_mgr: Arc<SM>,
tls_config: Option<TlsConfig>,
peer_addr: AddressRef,
redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
) where
S: AsyncWrite + AsyncRead + Unpin,
SM: SessionManager,
{
let mut pg_proto = PgProtocol::new(
stream,
session_mgr,
tls_config,
peer_addr,
redact_sql_option_keywords,
);
loop {
let msg = match pg_proto.read_message().await {
Ok(msg) => msg,
Err(e) => {
tracing::error!(error = %e.as_report(), "error when reading message");
break;
}
};
tracing::trace!("Received message: {:?}", msg);
let ret = pg_proto.process(msg).await;
if ret {
break;
}
}
}
#[cfg(test)]
mod tests {
use std::error::Error;
use std::sync::Arc;
use std::time::Instant;
use bytes::Bytes;
use futures::stream::BoxStream;
use futures::StreamExt;
use risingwave_common::types::DataType;
use risingwave_common::util::tokio_util::sync::CancellationToken;
use risingwave_sqlparser::ast::Statement;
use tokio_postgres::NoTls;
use crate::error::PsqlResult;
use crate::pg_field_descriptor::PgFieldDescriptor;
use crate::pg_message::TransactionStatus;
use crate::pg_response::{PgResponse, RowSetResult, StatementType};
use crate::pg_server::{
pg_serve, BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager,
UserAuthenticator,
};
use crate::types;
use crate::types::Row;
/// Stateless session manager used by the tests in this module.
struct MockSessionManager {}
/// Stateless session used by the tests in this module.
struct MockSession {}
/// Test double: `connect` hands out a fresh [`MockSession`]; the remaining operations are
/// either stubs that panic when called or no-ops.
impl SessionManager for MockSessionManager {
    type Session = MockSession;

    /// Not implemented for the mock.
    fn create_dummy_session(
        &self,
        _database_id: u32,
        _user_id: u32,
    ) -> Result<Arc<Self::Session>, BoxedError> {
        unimplemented!()
    }

    /// Accepts every connection attempt, ignoring all arguments.
    fn connect(
        &self,
        _database: &str,
        _user_name: &str,
        _peer_addr: crate::net::AddressRef,
    ) -> Result<Arc<Self::Session>, Box<dyn Error + Send + Sync>> {
        let session = Arc::new(MockSession {});
        Ok(session)
    }

    /// Not implemented for the mock.
    fn cancel_queries_in_session(&self, _session_id: SessionId) {
        todo!()
    }

    /// Not implemented for the mock.
    fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
        todo!()
    }

    /// Nothing to clean up.
    fn end_session(&self, _session: &Self::Session) {}
}
impl Session for MockSession {
type Portal = String;
type PreparedStatement = String;
type ValuesStream = BoxStream<'static, RowSetResult>;
async fn run_one_query(
self: Arc<Self>,
_stmt: Statement,
_format: types::Format,
) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, BoxedError> {
Ok(PgResponse::builder(StatementType::SELECT)
.values(
futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
.boxed(),
vec![
PgFieldDescriptor::new("".to_string(), 1043, -1);
1
],
)
.into())
}
async fn parse(
self: Arc<Self>,
_sql: Option<Statement>,
_params_types: Vec<Option<DataType>>,
) -> Result<String, BoxedError> {
Ok(String::new())
}
fn bind(
self: Arc<Self>,
_prepare_statement: String,
_params: Vec<Option<Bytes>>,
_param_formats: Vec<types::Format>,
_result_formats: Vec<types::Format>,
) -> Result<String, BoxedError> {
Ok(String::new())
}
async fn execute(
self: Arc<Self>,
_portal: String,
) -> Result<PgResponse<BoxStream<'static, RowSetResult>>, BoxedError> {
Ok(PgResponse::builder(StatementType::SELECT)
.values(
futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])])
.boxed(),
vec![
PgFieldDescriptor::new("".to_string(), 1043, -1);
1
],
)
.into())
}
fn describe_statement(
self: Arc<Self>,
_statement: String,
) -> Result<(Vec<DataType>, Vec<PgFieldDescriptor>), BoxedError> {
Ok((
vec![],
vec![PgFieldDescriptor::new("".to_string(), 1043, -1)],
))
}
fn describe_portal(
self: Arc<Self>,
_portal: String,
) -> Result<Vec<PgFieldDescriptor>, BoxedError> {
Ok(vec![PgFieldDescriptor::new("".to_string(), 1043, -1)])
}
fn user_authenticator(&self) -> &UserAuthenticator {
&UserAuthenticator::None
}
fn id(&self) -> SessionId {
(0, 0)
}
fn set_config(&self, _key: &str, _value: String) -> Result<String, BoxedError> {
Ok("".to_string())
}
fn take_notices(self: Arc<Self>) -> Vec<String> {
vec![]
}
fn transaction_status(&self) -> TransactionStatus {
TransactionStatus::Idle
}
fn init_exec_context(&self, sql: Arc<str>) -> ExecContextGuard {
let exec_context = Arc::new(ExecContext {
running_sql: sql,
last_instant: Instant::now(),
last_idle_instant: Default::default(),
});
ExecContextGuard::new(exec_context)
}
fn check_idle_in_transaction_timeout(&self) -> PsqlResult<()> {
Ok(())
}
}
/// Starts `pg_serve` with the mock session manager on `bind_addr`, connects to it with
/// `tokio_postgres` using `pg_config`, and checks the result sizes of one simple query and one
/// extended query.
async fn do_test_query(bind_addr: impl Into<String>, pg_config: impl Into<String>) {
    let bind_addr: String = bind_addr.into();
    let pg_config: String = pg_config.into();

    // Server side: runs in the background for the rest of the test; its token is never
    // cancelled.
    let session_mgr = Arc::new(MockSessionManager {});
    tokio::spawn(async move {
        pg_serve(
            &bind_addr,
            socket2::TcpKeepalive::new(),
            session_mgr,
            None,
            None,
            CancellationToken::new(),
        )
        .await
    });

    // Give the server a moment to bind before connecting.
    let startup_grace = std::time::Duration::from_millis(100);
    tokio::time::sleep(startup_grace).await;

    // Client side: the connection object has to be polled for the client to make progress.
    let (client, connection) = tokio_postgres::connect(&pg_config, NoTls).await.unwrap();
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("connection error: {}", e);
        }
    });

    // NOTE(review): the simple query presumably yields the row plus a command-complete
    // message, hence two entries — confirm against tokio_postgres.
    let simple_messages = client
        .simple_query("SELECT ''")
        .await
        .expect("Error executing query");
    assert_eq!(simple_messages.len(), 2);

    let extended_rows = client
        .query("SELECT ''", &[])
        .await
        .expect("Error executing query");
    assert_eq!(extended_rows.len(), 1);
}
// Runs the shared query scenario over TCP: the server binds 127.0.0.1:10000 and the client
// connects to localhost on the same port.
#[tokio::test]
async fn test_query_tcp() {
do_test_query("127.0.0.1:10000", "host=localhost port=10000").await;
}
// Runs the shared query scenario over a Unix socket placed in a temporary directory. The
// socket file is named `.s.PGSQL.<port>`, and the client is pointed at the directory and port.
#[cfg(not(madsim))]
#[tokio::test]
async fn test_query_unix() {
    let port: i16 = 10000;
    let socket_dir = tempfile::TempDir::new().unwrap();
    let socket_path = socket_dir.path().join(format!(".s.PGSQL.{port}"));

    let bind_addr = format!("unix:{}", socket_path.to_str().unwrap());
    let pg_config = format!(
        "host={} port={}",
        socket_dir.path().to_str().unwrap(),
        port
    );
    do_test_query(bind_addr, pg_config).await;
}
}