1use std::collections::HashMap;
16use std::fmt;
17
18use anyhow::anyhow;
19use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
20use postgres_openssl::MakeTlsConnector;
21use risingwave_common::bail;
22use risingwave_common::catalog::{ColumnDesc, ColumnId};
23use risingwave_common::types::{DataType, ScalarImpl, StructType};
24use sea_schema::postgres::def::{ColumnType as SeaType, TableDef, TableInfo};
25use sea_schema::postgres::discovery::SchemaDiscovery;
26use serde_derive::Deserialize;
27use sqlx::PgPool;
28use sqlx::postgres::{PgConnectOptions, PgSslMode};
29use thiserror_ext::AsReport;
30use tokio_postgres::types::Kind as PgKind;
31use tokio_postgres::{Client as PgClient, NoTls};
32
33#[cfg(not(madsim))]
34use super::maybe_tls_connector::MaybeMakeTlsConnector;
35use crate::error::ConnectorResult;
36
37#[derive(Debug, Clone, PartialEq, Deserialize)]
38#[serde(rename_all = "lowercase")]
39pub enum SslMode {
40 #[serde(alias = "disable")]
41 Disabled,
42 #[serde(alias = "prefer")]
43 Preferred,
44 #[serde(alias = "require")]
45 Required,
46 #[serde(alias = "verify-ca")]
49 VerifyCa,
50 #[serde(alias = "verify-full")]
53 VerifyFull,
54}
55
56impl Default for SslMode {
57 fn default() -> Self {
58 Self::Preferred
59 }
60}
61
62pub struct PostgresExternalTable {
63 column_descs: Vec<ColumnDesc>,
64 pk_names: Vec<String>,
65}
66
67impl PostgresExternalTable {
68 async fn discover_schema(
69 username: &str,
70 password: &str,
71 host: &str,
72 port: u16,
73 database: &str,
74 schema: &str,
75 table: &str,
76 ssl_mode: &SslMode,
77 ssl_root_cert: &Option<String>,
78 ) -> ConnectorResult<TableDef> {
79 let mut options = PgConnectOptions::new()
80 .username(username)
81 .password(password)
82 .host(host)
83 .port(port)
84 .database(database)
85 .ssl_mode(match ssl_mode {
86 SslMode::Disabled => PgSslMode::Disable,
87 SslMode::Preferred => PgSslMode::Prefer,
88 SslMode::Required => PgSslMode::Require,
89 SslMode::VerifyCa => PgSslMode::VerifyCa,
90 SslMode::VerifyFull => PgSslMode::VerifyFull,
91 });
92
93 if *ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull {
94 if let Some(root_cert) = ssl_root_cert {
95 options = options.ssl_root_cert(root_cert.as_str());
96 }
97 }
98
99 let connection = PgPool::connect_with(options).await?;
100 let schema_discovery = SchemaDiscovery::new(connection, schema);
101 let empty_map = HashMap::new();
103 let table_schema = schema_discovery
104 .discover_table(
105 TableInfo {
106 name: table.to_owned(),
107 of_type: None,
108 },
109 &empty_map,
110 )
111 .await?;
112 Ok(table_schema)
113 }
114
115 pub async fn connect(
116 username: &str,
117 password: &str,
118 host: &str,
119 port: u16,
120 database: &str,
121 schema: &str,
122 table: &str,
123 ssl_mode: &SslMode,
124 ssl_root_cert: &Option<String>,
125 is_append_only: bool,
126 ) -> ConnectorResult<Self> {
127 tracing::debug!("connect to postgres external table");
128 let table_schema = Self::discover_schema(
129 username,
130 password,
131 host,
132 port,
133 database,
134 schema,
135 table,
136 ssl_mode,
137 ssl_root_cert,
138 )
139 .await?;
140 let mut column_descs = vec![];
141 for col in &table_schema.columns {
142 let rw_data_type = sea_type_to_rw_type(&col.col_type)?;
143 let column_desc = if let Some(ref default_expr) = col.default {
144 let val_text = default_expr
147 .0
148 .split("::")
149 .map(|s| s.trim_matches('\''))
150 .next()
151 .expect("default value expression");
152
153 match ScalarImpl::from_text(val_text, &rw_data_type) {
154 Ok(scalar) => ColumnDesc::named_with_default_value(
155 col.name.clone(),
156 ColumnId::placeholder(),
157 rw_data_type.clone(),
158 Some(scalar),
159 ),
160 Err(err) => {
161 tracing::warn!(error=%err.as_report(), "failed to parse postgres default value expression, only constant is supported");
162 ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
163 }
164 }
165 } else {
166 ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
167 };
168 column_descs.push(column_desc);
169 }
170
171 if !is_append_only && table_schema.primary_key_constraints.is_empty() {
172 return Err(anyhow!(
173 "Postgres table should define the primary key for non-append-only tables"
174 )
175 .into());
176 }
177 let mut pk_names = vec![];
178 table_schema.primary_key_constraints.iter().for_each(|pk| {
179 pk_names.extend(pk.columns.clone());
180 });
181
182 Ok(Self {
183 column_descs,
184 pk_names,
185 })
186 }
187
188 pub async fn type_mapping(
190 username: &str,
191 password: &str,
192 host: &str,
193 port: u16,
194 database: &str,
195 schema: &str,
196 table: &str,
197 ssl_mode: &SslMode,
198 ssl_root_cert: &Option<String>,
199 is_append_only: bool,
200 ) -> ConnectorResult<HashMap<String, tokio_postgres::types::Type>> {
201 tracing::debug!("connect to postgres external table to get type mapping");
202 let table_schema = Self::discover_schema(
203 username,
204 password,
205 host,
206 port,
207 database,
208 schema,
209 table,
210 ssl_mode,
211 ssl_root_cert,
212 )
213 .await?;
214 let mut column_name_to_pg_type = HashMap::new();
215 for col in &table_schema.columns {
216 let pg_type = sea_type_to_pg_type(&col.col_type)?;
217 column_name_to_pg_type.insert(col.name.clone(), pg_type);
218 }
219 if !is_append_only && table_schema.primary_key_constraints.is_empty() {
220 return Err(anyhow!(
221 "Postgres table should define the primary key for non-append-only tables"
222 )
223 .into());
224 }
225 Ok(column_name_to_pg_type)
226 }
227
228 pub fn column_descs(&self) -> &Vec<ColumnDesc> {
229 &self.column_descs
230 }
231
232 pub fn pk_names(&self) -> &Vec<String> {
233 &self.pk_names
234 }
235}
236
237impl fmt::Display for SslMode {
238 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
239 f.write_str(match self {
240 SslMode::Disabled => "disabled",
241 SslMode::Preferred => "preferred",
242 SslMode::Required => "required",
243 SslMode::VerifyCa => "verify-ca",
244 SslMode::VerifyFull => "verify-full",
245 })
246 }
247}
248
249pub async fn create_pg_client(
250 user: &str,
251 password: &str,
252 host: &str,
253 port: &str,
254 database: &str,
255 ssl_mode: &SslMode,
256 ssl_root_cert: &Option<String>,
257) -> anyhow::Result<PgClient> {
258 let mut pg_config = tokio_postgres::Config::new();
259 pg_config
260 .user(user)
261 .password(password)
262 .host(host)
263 .port(port.parse::<u16>().unwrap())
264 .dbname(database);
265
266 let (_verify_ca, verify_hostname) = match ssl_mode {
267 SslMode::VerifyCa => (true, false),
268 SslMode::VerifyFull => (true, true),
269 _ => (false, false),
270 };
271
272 #[cfg(not(madsim))]
273 let connector = match ssl_mode {
274 SslMode::Disabled => {
275 pg_config.ssl_mode(tokio_postgres::config::SslMode::Disable);
276 MaybeMakeTlsConnector::NoTls(NoTls)
277 }
278 SslMode::Preferred => {
279 pg_config.ssl_mode(tokio_postgres::config::SslMode::Prefer);
280 match SslConnector::builder(SslMethod::tls()) {
281 Ok(mut builder) => {
282 builder.set_verify(SslVerifyMode::NONE);
284 MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
285 }
286 Err(e) => {
287 tracing::warn!(error = %e.as_report(), "SSL connector error");
288 MaybeMakeTlsConnector::NoTls(NoTls)
289 }
290 }
291 }
292 SslMode::Required => {
293 pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
294 let mut builder = SslConnector::builder(SslMethod::tls())?;
295 builder.set_verify(SslVerifyMode::NONE);
297 MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
298 }
299
300 SslMode::VerifyCa | SslMode::VerifyFull => {
301 pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
302 let mut builder = SslConnector::builder(SslMethod::tls())?;
303 if let Some(ssl_root_cert) = ssl_root_cert {
304 builder.set_ca_file(ssl_root_cert).map_err(|e| {
305 anyhow!(format!("bad ssl root cert error: {}", e.to_report_string()))
306 })?;
307 }
308 let mut connector = MakeTlsConnector::new(builder.build());
309 if !verify_hostname {
310 connector.set_callback(|config, _| {
311 config.set_verify_hostname(false);
312 Ok(())
313 });
314 }
315 MaybeMakeTlsConnector::Tls(connector)
316 }
317 };
318 #[cfg(madsim)]
319 let connector = NoTls;
320
321 let (client, connection) = pg_config.connect(connector).await?;
322
323 tokio::spawn(async move {
324 if let Err(e) = connection.await {
325 tracing::error!(error = %e.as_report(), "postgres connection error");
326 }
327 });
328
329 Ok(client)
330}
331
332pub fn sea_type_to_rw_type(col_type: &SeaType) -> ConnectorResult<DataType> {
334 let dtype = match col_type {
335 SeaType::SmallInt | SeaType::SmallSerial => DataType::Int16,
336 SeaType::Integer | SeaType::Serial => DataType::Int32,
337 SeaType::BigInt | SeaType::BigSerial => DataType::Int64,
338 SeaType::Money | SeaType::Decimal(_) | SeaType::Numeric(_) => DataType::Decimal,
339 SeaType::Real => DataType::Float32,
340 SeaType::DoublePrecision => DataType::Float64,
341 SeaType::Varchar(_) | SeaType::Char(_) | SeaType::Text => DataType::Varchar,
342 SeaType::Bytea => DataType::Bytea,
343 SeaType::Timestamp(_) => DataType::Timestamp,
344 SeaType::TimestampWithTimeZone(_) => DataType::Timestamptz,
345 SeaType::Date => DataType::Date,
346 SeaType::Time(_) | SeaType::TimeWithTimeZone(_) => DataType::Time,
347 SeaType::Interval(_) => DataType::Interval,
348 SeaType::Boolean => DataType::Boolean,
349 SeaType::Point => DataType::Struct(StructType::new(vec![
350 ("x", DataType::Float32),
351 ("y", DataType::Float32),
352 ])),
353 SeaType::Uuid => DataType::Varchar,
354 SeaType::Xml => DataType::Varchar,
355 SeaType::Json => DataType::Jsonb,
356 SeaType::JsonBinary => DataType::Jsonb,
357 SeaType::Array(def) => {
358 let item_type = match def.col_type.as_ref() {
359 Some(ty) => sea_type_to_rw_type(ty.as_ref())?,
360 None => {
361 return Err(anyhow!("ARRAY type missing element type").into());
362 }
363 };
364
365 DataType::List(Box::new(item_type))
366 }
367 SeaType::PgLsn => DataType::Int64,
368 SeaType::Cidr
369 | SeaType::Inet
370 | SeaType::MacAddr
371 | SeaType::MacAddr8
372 | SeaType::Int4Range
373 | SeaType::Int8Range
374 | SeaType::NumRange
375 | SeaType::TsRange
376 | SeaType::TsTzRange
377 | SeaType::DateRange
378 | SeaType::Enum(_) => DataType::Varchar,
379 SeaType::Line
380 | SeaType::Lseg
381 | SeaType::Box
382 | SeaType::Path
383 | SeaType::Polygon
384 | SeaType::Circle
385 | SeaType::Bit(_)
386 | SeaType::VarBit(_)
387 | SeaType::TsVector
388 | SeaType::TsQuery => {
389 bail!("{:?} type not supported", col_type);
390 }
391 SeaType::Unknown(name) => {
392 tracing::warn!("Unknown Postgres data type: {name}, map to varchar");
394 DataType::Varchar
395 }
396 };
397
398 Ok(dtype)
399}
400
401fn sea_type_to_pg_type(sea_type: &SeaType) -> ConnectorResult<tokio_postgres::types::Type> {
406 use tokio_postgres::types::Type as PgType;
407 match sea_type {
408 SeaType::SmallInt => Ok(PgType::INT2),
409 SeaType::Integer => Ok(PgType::INT4),
410 SeaType::BigInt => Ok(PgType::INT8),
411 SeaType::Decimal(_) => Ok(PgType::NUMERIC),
412 SeaType::Numeric(_) => Ok(PgType::NUMERIC),
413 SeaType::Real => Ok(PgType::FLOAT4),
414 SeaType::DoublePrecision => Ok(PgType::FLOAT8),
415 SeaType::Varchar(_) => Ok(PgType::VARCHAR),
416 SeaType::Char(_) => Ok(PgType::CHAR),
417 SeaType::Text => Ok(PgType::TEXT),
418 SeaType::Bytea => Ok(PgType::BYTEA),
419 SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP),
420 SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ),
421 SeaType::Date => Ok(PgType::DATE),
422 SeaType::Time(_) => Ok(PgType::TIME),
423 SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ),
424 SeaType::Interval(_) => Ok(PgType::INTERVAL),
425 SeaType::Boolean => Ok(PgType::BOOL),
426 SeaType::Point => Ok(PgType::POINT),
427 SeaType::Uuid => Ok(PgType::UUID),
428 SeaType::Json => Ok(PgType::JSON),
429 SeaType::JsonBinary => Ok(PgType::JSONB),
430 SeaType::Array(t) => {
431 let Some(t) = t.col_type.as_ref() else {
432 bail!("missing array type")
433 };
434 match t.as_ref() {
435 SeaType::SmallInt => Ok(PgType::INT2_ARRAY),
437 SeaType::Integer => Ok(PgType::INT4_ARRAY),
438 SeaType::BigInt => Ok(PgType::INT8_ARRAY),
439 SeaType::Decimal(_) => Ok(PgType::NUMERIC_ARRAY),
440 SeaType::Numeric(_) => Ok(PgType::NUMERIC_ARRAY),
441 SeaType::Real => Ok(PgType::FLOAT4_ARRAY),
442 SeaType::DoublePrecision => Ok(PgType::FLOAT8_ARRAY),
443 SeaType::Varchar(_) => Ok(PgType::VARCHAR_ARRAY),
444 SeaType::Char(_) => Ok(PgType::CHAR_ARRAY),
445 SeaType::Text => Ok(PgType::TEXT_ARRAY),
446 SeaType::Bytea => Ok(PgType::BYTEA_ARRAY),
447 SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP_ARRAY),
448 SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ_ARRAY),
449 SeaType::Date => Ok(PgType::DATE_ARRAY),
450 SeaType::Time(_) => Ok(PgType::TIME_ARRAY),
451 SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ_ARRAY),
452 SeaType::Interval(_) => Ok(PgType::INTERVAL_ARRAY),
453 SeaType::Boolean => Ok(PgType::BOOL_ARRAY),
454 SeaType::Point => Ok(PgType::POINT_ARRAY),
455 SeaType::Uuid => Ok(PgType::UUID_ARRAY),
456 SeaType::Json => Ok(PgType::JSON_ARRAY),
457 SeaType::JsonBinary => Ok(PgType::JSONB_ARRAY),
458 SeaType::Array(_) => bail!("nested array type is not supported"),
459 SeaType::Unknown(name) => {
460 Ok(PgType::new(
462 name.clone(),
463 0,
464 PgKind::Array(PgType::new(
465 name.clone(),
466 0,
467 PgKind::Enum(vec![]),
468 "".into(),
469 )),
470 "".into(),
471 ))
472 }
473 _ => bail!("unsupported array type: {:?}", t),
474 }
475 }
476 SeaType::Unknown(name) => {
477 Ok(PgType::new(
479 name.clone(),
480 0,
481 PgKind::Enum(vec![]),
482 "".into(),
483 ))
484 }
485 _ => bail!("unsupported type: {:?}", sea_type),
486 }
487}