1use std::collections::HashMap;
16use std::fmt;
17
18use anyhow::{Context, anyhow};
19use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
20use postgres_openssl::MakeTlsConnector;
21use risingwave_common::bail;
22use risingwave_common::catalog::{ColumnDesc, ColumnId};
23use risingwave_common::types::{DataType, ScalarImpl, StructType};
24use sea_schema::postgres::def::{ColumnType as SeaType, TableDef, TableInfo};
25use sea_schema::postgres::discovery::SchemaDiscovery;
26use sea_schema::sea_query::{Alias, IntoIden};
27use serde::Deserialize;
28use sqlx::postgres::{PgConnectOptions, PgSslMode};
29use sqlx::{PgPool, Row};
30use thiserror_ext::AsReport;
31use tokio_postgres::types::Kind as PgKind;
32use tokio_postgres::{Client as PgClient, NoTls};
33
34#[cfg(not(madsim))]
35use super::maybe_tls_connector::MaybeMakeTlsConnector;
36use crate::error::ConnectorResult;
37
/// Lists the primary-key column names of the table identified by schema `$1` and table `$2`,
/// ordered by each column's position within the primary-key index.
///
/// Both identifiers go through `quote_ident` before the `regclass` cast, so mixed-case or
/// otherwise special names resolve to the exact relation instead of being case-folded (or
/// rejected) when the concatenated text is parsed as an identifier.
const DISCOVER_PRIMARY_KEY_QUERY: &str = r#"
    SELECT a.attname as column_name
    FROM pg_index i
    JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
    WHERE i.indrelid = (quote_ident($1) || '.' || quote_ident($2))::regclass
      AND i.indisprimary = true
    ORDER BY array_position(i.indkey, a.attnum)
"#;
48
/// SSL mode used when connecting to an upstream Postgres database.
///
/// With `rename_all = "lowercase"` each variant deserializes from its lowercased name
/// (`disabled`, `preferred`, `required`, `verifyca`, `verifyfull`); the alias on each variant
/// is accepted as well. The default is [`SslMode::Preferred`].
#[derive(Debug, Clone, PartialEq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum SslMode {
    /// Also accepted as `disable`.
    #[serde(alias = "disable")]
    Disabled,
    /// Also accepted as `prefer`. This is the default variant.
    #[serde(alias = "prefer")]
    #[default]
    Preferred,
    /// Also accepted as `require`.
    #[serde(alias = "require")]
    Required,
    /// Also accepted as `verify-ca`.
    #[serde(alias = "verify-ca")]
    VerifyCa,
    /// Also accepted as `verify-full`.
    #[serde(alias = "verify-full")]
    VerifyFull,
}
68
/// Column descriptors and primary-key column names discovered from an upstream Postgres table.
pub struct PostgresExternalTable {
    /// One descriptor per column discovered on the upstream table.
    column_descs: Vec<ColumnDesc>,
    /// Names of the upstream table's primary-key columns.
    pk_names: Vec<String>,
}
73
74impl PostgresExternalTable {
75 async fn discover_primary_key(
78 connection: &PgPool,
79 schema_name: &str,
80 table_name: &str,
81 ) -> ConnectorResult<Vec<String>> {
82 let rows = sqlx::query(DISCOVER_PRIMARY_KEY_QUERY)
83 .bind(schema_name)
84 .bind(table_name)
85 .fetch_all(connection)
86 .await
87 .context("Failed to discover primary key columns")?;
88
89 let pk_columns = rows
90 .into_iter()
91 .map(|row| row.get::<String, _>("column_name"))
92 .collect();
93
94 Ok(pk_columns)
95 }
96
97 async fn discover_pk_and_full_columns(
101 username: &str,
102 password: &str,
103 host: &str,
104 port: u16,
105 database: &str,
106 schema: &str,
107 table: &str,
108 ssl_mode: &SslMode,
109 ssl_root_cert: &Option<String>,
110 ) -> ConnectorResult<(Vec<sea_schema::postgres::def::ColumnInfo>, Vec<String>)> {
111 let mut options = PgConnectOptions::new()
112 .username(username)
113 .password(password)
114 .host(host)
115 .port(port)
116 .database(database)
117 .ssl_mode(match ssl_mode {
118 SslMode::Disabled => PgSslMode::Disable,
119 SslMode::Preferred => PgSslMode::Prefer,
120 SslMode::Required => PgSslMode::Require,
121 SslMode::VerifyCa => PgSslMode::VerifyCa,
122 SslMode::VerifyFull => PgSslMode::VerifyFull,
123 });
124
125 if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
126 && let Some(root_cert) = ssl_root_cert
127 {
128 options = options.ssl_root_cert(root_cert.as_str());
129 }
130
131 let connection = PgPool::connect_with(options).await?;
132
133 let schema_discovery = SchemaDiscovery::new(connection.clone(), schema);
135 let empty_map: HashMap<String, Vec<String>> = HashMap::new();
136 let columns = schema_discovery
137 .discover_columns(
138 Alias::new(schema).into_iden(),
139 Alias::new(table).into_iden(),
140 &empty_map,
141 )
142 .await?;
143
144 let pk_columns = Self::discover_primary_key(&connection, schema, table).await?;
146
147 Ok((columns, pk_columns))
148 }
149
150 async fn discover_schema(
151 username: &str,
152 password: &str,
153 host: &str,
154 port: u16,
155 database: &str,
156 schema: &str,
157 table: &str,
158 ssl_mode: &SslMode,
159 ssl_root_cert: &Option<String>,
160 ) -> ConnectorResult<TableDef> {
161 let mut options = PgConnectOptions::new()
162 .username(username)
163 .password(password)
164 .host(host)
165 .port(port)
166 .database(database)
167 .ssl_mode(match ssl_mode {
168 SslMode::Disabled => PgSslMode::Disable,
169 SslMode::Preferred => PgSslMode::Prefer,
170 SslMode::Required => PgSslMode::Require,
171 SslMode::VerifyCa => PgSslMode::VerifyCa,
172 SslMode::VerifyFull => PgSslMode::VerifyFull,
173 });
174
175 if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
176 && let Some(root_cert) = ssl_root_cert
177 {
178 options = options.ssl_root_cert(root_cert.as_str());
179 }
180
181 let connection = PgPool::connect_with(options).await?;
182 let schema_discovery = SchemaDiscovery::new(connection, schema);
183 let empty_map = HashMap::new();
185 let table_schema = schema_discovery
186 .discover_table(
187 TableInfo {
188 name: table.to_owned(),
189 of_type: None,
190 },
191 &empty_map,
192 )
193 .await?;
194 Ok(table_schema)
195 }
196
197 pub async fn connect(
198 username: &str,
199 password: &str,
200 host: &str,
201 port: u16,
202 database: &str,
203 schema: &str,
204 table: &str,
205 ssl_mode: &SslMode,
206 ssl_root_cert: &Option<String>,
207 is_append_only: bool,
208 ) -> ConnectorResult<Self> {
209 tracing::debug!("connect to postgres external table");
210
211 let (columns, pk_names) = Self::discover_pk_and_full_columns(
212 username,
213 password,
214 host,
215 port,
216 database,
217 schema,
218 table,
219 ssl_mode,
220 ssl_root_cert,
221 )
222 .await?;
223
224 let mut column_descs = vec![];
225 for col in &columns {
226 let rw_data_type = sea_type_to_rw_type(&col.col_type)?;
227 let column_desc = if let Some(ref default_expr) = col.default {
228 let val_text = default_expr
231 .0
232 .split("::")
233 .map(|s| s.trim_matches('\''))
234 .next()
235 .expect("default value expression");
236
237 match ScalarImpl::from_text(val_text, &rw_data_type) {
238 Ok(scalar) => ColumnDesc::named_with_default_value(
239 col.name.clone(),
240 ColumnId::placeholder(),
241 rw_data_type.clone(),
242 Some(scalar),
243 ),
244 Err(err) => {
245 tracing::warn!(error=%err.as_report(), "failed to parse postgres default value expression, only constant is supported");
246 ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
247 }
248 }
249 } else {
250 ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
251 };
252 column_descs.push(column_desc);
253 }
254
255 if !is_append_only && pk_names.is_empty() {
257 return Err(anyhow!(
258 "Postgres table should define the primary key for non-append-only tables"
259 )
260 .into());
261 }
262
263 Ok(Self {
264 column_descs,
265 pk_names,
266 })
267 }
268
269 pub async fn type_mapping(
271 username: &str,
272 password: &str,
273 host: &str,
274 port: u16,
275 database: &str,
276 schema: &str,
277 table: &str,
278 ssl_mode: &SslMode,
279 ssl_root_cert: &Option<String>,
280 is_append_only: bool,
281 ) -> ConnectorResult<HashMap<String, tokio_postgres::types::Type>> {
282 tracing::debug!("connect to postgres external table to get type mapping");
283 let table_schema = Self::discover_schema(
284 username,
285 password,
286 host,
287 port,
288 database,
289 schema,
290 table,
291 ssl_mode,
292 ssl_root_cert,
293 )
294 .await?;
295 let mut column_name_to_pg_type = HashMap::new();
296 for col in &table_schema.columns {
297 let pg_type = sea_type_to_pg_type(&col.col_type)?;
298 column_name_to_pg_type.insert(col.name.clone(), pg_type);
299 }
300 if !is_append_only && table_schema.primary_key_constraints.is_empty() {
301 return Err(anyhow!(
302 "Postgres table should define the primary key for non-append-only tables"
303 )
304 .into());
305 }
306 Ok(column_name_to_pg_type)
307 }
308
    /// Returns the column descriptors held by this table.
    pub fn column_descs(&self) -> &Vec<ColumnDesc> {
        &self.column_descs
    }
312
    /// Returns the primary-key column names held by this table.
    pub fn pk_names(&self) -> &Vec<String> {
        &self.pk_names
    }
316}
317
318impl fmt::Display for SslMode {
319 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
320 f.write_str(match self {
321 SslMode::Disabled => "disabled",
322 SslMode::Preferred => "preferred",
323 SslMode::Required => "required",
324 SslMode::VerifyCa => "verify-ca",
325 SslMode::VerifyFull => "verify-full",
326 })
327 }
328}
329
330pub async fn create_pg_client(
331 user: &str,
332 password: &str,
333 host: &str,
334 port: &str,
335 database: &str,
336 ssl_mode: &SslMode,
337 ssl_root_cert: &Option<String>,
338) -> anyhow::Result<PgClient> {
339 let mut pg_config = tokio_postgres::Config::new();
340 pg_config
341 .user(user)
342 .password(password)
343 .host(host)
344 .port(port.parse::<u16>().unwrap())
345 .dbname(database);
346
347 let (_verify_ca, verify_hostname) = match ssl_mode {
348 SslMode::VerifyCa => (true, false),
349 SslMode::VerifyFull => (true, true),
350 _ => (false, false),
351 };
352
353 #[cfg(not(madsim))]
354 let connector = match ssl_mode {
355 SslMode::Disabled => {
356 pg_config.ssl_mode(tokio_postgres::config::SslMode::Disable);
357 MaybeMakeTlsConnector::NoTls(NoTls)
358 }
359 SslMode::Preferred => {
360 pg_config.ssl_mode(tokio_postgres::config::SslMode::Prefer);
361 match SslConnector::builder(SslMethod::tls()) {
362 Ok(mut builder) => {
363 builder.set_verify(SslVerifyMode::NONE);
365 MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
366 }
367 Err(e) => {
368 tracing::warn!(error = %e.as_report(), "SSL connector error");
369 MaybeMakeTlsConnector::NoTls(NoTls)
370 }
371 }
372 }
373 SslMode::Required => {
374 pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
375 let mut builder = SslConnector::builder(SslMethod::tls())?;
376 builder.set_verify(SslVerifyMode::NONE);
378 MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
379 }
380
381 SslMode::VerifyCa | SslMode::VerifyFull => {
382 pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
383 let mut builder = SslConnector::builder(SslMethod::tls())?;
384 if let Some(ssl_root_cert) = ssl_root_cert {
385 builder.set_ca_file(ssl_root_cert).map_err(|e| {
386 anyhow!(format!("bad ssl root cert error: {}", e.to_report_string()))
387 })?;
388 }
389 let mut connector = MakeTlsConnector::new(builder.build());
390 if !verify_hostname {
391 connector.set_callback(|config, _| {
392 config.set_verify_hostname(false);
393 Ok(())
394 });
395 }
396 MaybeMakeTlsConnector::Tls(connector)
397 }
398 };
399 #[cfg(madsim)]
400 let connector = NoTls;
401
402 let (client, connection) = pg_config.connect(connector).await?;
403
404 tokio::spawn(async move {
405 if let Err(e) = connection.await {
406 tracing::error!(error = %e.as_report(), "postgres connection error");
407 }
408 });
409
410 Ok(client)
411}
412
413pub fn sea_type_to_rw_type(col_type: &SeaType) -> ConnectorResult<DataType> {
415 let dtype = match col_type {
416 SeaType::SmallInt | SeaType::SmallSerial => DataType::Int16,
417 SeaType::Integer | SeaType::Serial => DataType::Int32,
418 SeaType::BigInt | SeaType::BigSerial => DataType::Int64,
419 SeaType::Money | SeaType::Decimal(_) | SeaType::Numeric(_) => DataType::Decimal,
420 SeaType::Real => DataType::Float32,
421 SeaType::DoublePrecision => DataType::Float64,
422 SeaType::Varchar(_) | SeaType::Char(_) | SeaType::Text => DataType::Varchar,
423 SeaType::Bytea => DataType::Bytea,
424 SeaType::Timestamp(_) => DataType::Timestamp,
425 SeaType::TimestampWithTimeZone(_) => DataType::Timestamptz,
426 SeaType::Date => DataType::Date,
427 SeaType::Time(_) | SeaType::TimeWithTimeZone(_) => DataType::Time,
428 SeaType::Interval(_) => DataType::Interval,
429 SeaType::Boolean => DataType::Boolean,
430 SeaType::Point => DataType::Struct(StructType::new(vec![
431 ("x", DataType::Float32),
432 ("y", DataType::Float32),
433 ])),
434 SeaType::Uuid => DataType::Varchar,
435 SeaType::Xml => DataType::Varchar,
436 SeaType::Json => DataType::Jsonb,
437 SeaType::JsonBinary => DataType::Jsonb,
438 SeaType::Array(def) => {
439 let item_type = match def.col_type.as_ref() {
440 Some(ty) => sea_type_to_rw_type(ty.as_ref())?,
441 None => {
442 return Err(anyhow!("ARRAY type missing element type").into());
443 }
444 };
445
446 DataType::list(item_type)
447 }
448 SeaType::PgLsn => DataType::Int64,
449 SeaType::Cidr
450 | SeaType::Inet
451 | SeaType::MacAddr
452 | SeaType::MacAddr8
453 | SeaType::Int4Range
454 | SeaType::Int8Range
455 | SeaType::NumRange
456 | SeaType::TsRange
457 | SeaType::TsTzRange
458 | SeaType::DateRange
459 | SeaType::Enum(_) => DataType::Varchar,
460 SeaType::Line
461 | SeaType::Lseg
462 | SeaType::Box
463 | SeaType::Path
464 | SeaType::Polygon
465 | SeaType::Circle
466 | SeaType::Bit(_)
467 | SeaType::VarBit(_)
468 | SeaType::TsVector
469 | SeaType::TsQuery => {
470 bail!("{:?} type not supported", col_type);
471 }
472 SeaType::Unknown(name) => {
473 tracing::warn!("Unknown Postgres data type: {name}, map to varchar");
475 DataType::Varchar
476 }
477 };
478
479 Ok(dtype)
480}
481
482fn sea_type_to_pg_type(sea_type: &SeaType) -> ConnectorResult<tokio_postgres::types::Type> {
487 use tokio_postgres::types::Type as PgType;
488 match sea_type {
489 SeaType::SmallInt => Ok(PgType::INT2),
490 SeaType::Integer => Ok(PgType::INT4),
491 SeaType::BigInt => Ok(PgType::INT8),
492 SeaType::Decimal(_) => Ok(PgType::NUMERIC),
493 SeaType::Numeric(_) => Ok(PgType::NUMERIC),
494 SeaType::Real => Ok(PgType::FLOAT4),
495 SeaType::DoublePrecision => Ok(PgType::FLOAT8),
496 SeaType::Varchar(_) => Ok(PgType::VARCHAR),
497 SeaType::Char(_) => Ok(PgType::CHAR),
498 SeaType::Text => Ok(PgType::TEXT),
499 SeaType::Bytea => Ok(PgType::BYTEA),
500 SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP),
501 SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ),
502 SeaType::Date => Ok(PgType::DATE),
503 SeaType::Time(_) => Ok(PgType::TIME),
504 SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ),
505 SeaType::Interval(_) => Ok(PgType::INTERVAL),
506 SeaType::Boolean => Ok(PgType::BOOL),
507 SeaType::Point => Ok(PgType::POINT),
508 SeaType::Uuid => Ok(PgType::UUID),
509 SeaType::Json => Ok(PgType::JSON),
510 SeaType::JsonBinary => Ok(PgType::JSONB),
511 SeaType::Array(t) => {
512 let Some(t) = t.col_type.as_ref() else {
513 bail!("missing array type")
514 };
515 match t.as_ref() {
516 SeaType::SmallInt => Ok(PgType::INT2_ARRAY),
518 SeaType::Integer => Ok(PgType::INT4_ARRAY),
519 SeaType::BigInt => Ok(PgType::INT8_ARRAY),
520 SeaType::Decimal(_) => Ok(PgType::NUMERIC_ARRAY),
521 SeaType::Numeric(_) => Ok(PgType::NUMERIC_ARRAY),
522 SeaType::Real => Ok(PgType::FLOAT4_ARRAY),
523 SeaType::DoublePrecision => Ok(PgType::FLOAT8_ARRAY),
524 SeaType::Varchar(_) => Ok(PgType::VARCHAR_ARRAY),
525 SeaType::Char(_) => Ok(PgType::CHAR_ARRAY),
526 SeaType::Text => Ok(PgType::TEXT_ARRAY),
527 SeaType::Bytea => Ok(PgType::BYTEA_ARRAY),
528 SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP_ARRAY),
529 SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ_ARRAY),
530 SeaType::Date => Ok(PgType::DATE_ARRAY),
531 SeaType::Time(_) => Ok(PgType::TIME_ARRAY),
532 SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ_ARRAY),
533 SeaType::Interval(_) => Ok(PgType::INTERVAL_ARRAY),
534 SeaType::Boolean => Ok(PgType::BOOL_ARRAY),
535 SeaType::Point => Ok(PgType::POINT_ARRAY),
536 SeaType::Uuid => Ok(PgType::UUID_ARRAY),
537 SeaType::Json => Ok(PgType::JSON_ARRAY),
538 SeaType::JsonBinary => Ok(PgType::JSONB_ARRAY),
539 SeaType::Array(_) => bail!("nested array type is not supported"),
540 SeaType::Unknown(name) => {
541 Ok(PgType::new(
543 name.clone(),
544 0,
545 PgKind::Array(PgType::new(
546 name.clone(),
547 0,
548 PgKind::Enum(vec![]),
549 "".into(),
550 )),
551 "".into(),
552 ))
553 }
554 _ => bail!("unsupported array type: {:?}", t),
555 }
556 }
557 SeaType::Unknown(name) => {
558 Ok(PgType::new(
560 name.clone(),
561 0,
562 PgKind::Enum(vec![]),
563 "".into(),
564 ))
565 }
566 _ => bail!("unsupported type: {:?}", sea_type),
567 }
568}