1use std::collections::HashMap;
16use std::fmt;
17
18use anyhow::{Context, anyhow};
19use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
20use postgres_openssl::MakeTlsConnector;
21use risingwave_common::bail;
22use risingwave_common::catalog::{ColumnDesc, ColumnId};
23use risingwave_common::types::{DataType, ScalarImpl, StructType};
24use sea_schema::postgres::def::{ColumnType as SeaType, TableDef, TableInfo};
25use sea_schema::postgres::discovery::SchemaDiscovery;
26use sea_schema::sea_query::{Alias, IntoIden};
27use serde::Deserialize;
28use sqlx::postgres::{PgConnectOptions, PgSslMode};
29use sqlx::{PgPool, Row};
30use thiserror_ext::AsReport;
31use tokio_postgres::types::Kind as PgKind;
32use tokio_postgres::{Client as PgClient, NoTls};
33
34#[cfg(not(madsim))]
35use super::maybe_tls_connector::MaybeMakeTlsConnector;
36use crate::error::ConnectorResult;
37use crate::sink::postgres::TcpKeepaliveConfig;
38
/// Lists the primary-key column names of a single table, ordered by their
/// position inside the primary-key index.
///
/// `$1` is the schema name and `$2` is the table name. Both are passed through
/// `quote_ident` before the `regclass` cast so that mixed-case names and names
/// containing special characters resolve to the exact relation, instead of
/// being case-folded or rejected by the identifier parser.
const DISCOVER_PRIMARY_KEY_QUERY: &str = r#"
    SELECT a.attname as column_name
    FROM pg_index i
    JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
    WHERE i.indrelid = (quote_ident($1) || '.' || quote_ident($2))::regclass
      AND i.indisprimary = true
    ORDER BY array_position(i.indkey, a.attnum)
"#;
49
/// TLS policy used when connecting to an upstream PostgreSQL server.
///
/// Deserializes from the lowercase variant name (`rename_all = "lowercase"`)
/// or from the alias declared on each variant. `Preferred` is the default.
#[derive(Debug, Clone, PartialEq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum SslMode {
    /// Connect without TLS. Also accepted as `disable`.
    #[serde(alias = "disable")]
    Disabled,
    /// Ask for TLS in "prefer" mode without verifying the server certificate.
    /// Also accepted as `prefer`.
    #[serde(alias = "prefer")]
    #[default]
    Preferred,
    /// Require TLS but do not verify the server certificate.
    /// Also accepted as `require`.
    #[serde(alias = "require")]
    Required,
    /// Require TLS and verify the server certificate, skipping the hostname
    /// check. Also accepted as `verify-ca`.
    #[serde(alias = "verify-ca")]
    VerifyCa,
    /// Require TLS and verify both the server certificate and its hostname.
    /// Also accepted as `verify-full`.
    #[serde(alias = "verify-full")]
    VerifyFull,
}
69
/// Column and primary-key metadata discovered from an upstream PostgreSQL table.
pub struct PostgresExternalTable {
    /// One descriptor per discovered upstream column, built in `connect`.
    column_descs: Vec<ColumnDesc>,
    /// Names of the primary-key columns as returned by primary-key discovery.
    pk_names: Vec<String>,
}
74
75impl PostgresExternalTable {
76 async fn discover_primary_key(
79 connection: &PgPool,
80 schema_name: &str,
81 table_name: &str,
82 ) -> ConnectorResult<Vec<String>> {
83 let rows = sqlx::query(DISCOVER_PRIMARY_KEY_QUERY)
84 .bind(schema_name)
85 .bind(table_name)
86 .fetch_all(connection)
87 .await
88 .context("Failed to discover primary key columns")?;
89
90 let pk_columns = rows
91 .into_iter()
92 .map(|row| row.get::<String, _>("column_name"))
93 .collect();
94
95 Ok(pk_columns)
96 }
97
98 async fn discover_pk_and_full_columns(
102 username: &str,
103 password: &str,
104 host: &str,
105 port: u16,
106 database: &str,
107 schema: &str,
108 table: &str,
109 ssl_mode: &SslMode,
110 ssl_root_cert: &Option<String>,
111 ) -> ConnectorResult<(Vec<sea_schema::postgres::def::ColumnInfo>, Vec<String>)> {
112 let mut options = PgConnectOptions::new()
113 .username(username)
114 .password(password)
115 .host(host)
116 .port(port)
117 .database(database)
118 .ssl_mode(match ssl_mode {
119 SslMode::Disabled => PgSslMode::Disable,
120 SslMode::Preferred => PgSslMode::Prefer,
121 SslMode::Required => PgSslMode::Require,
122 SslMode::VerifyCa => PgSslMode::VerifyCa,
123 SslMode::VerifyFull => PgSslMode::VerifyFull,
124 });
125
126 if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
127 && let Some(root_cert) = ssl_root_cert
128 {
129 options = options.ssl_root_cert(root_cert.as_str());
130 }
131
132 let connection = PgPool::connect_with(options).await?;
133
134 let schema_discovery = SchemaDiscovery::new(connection.clone(), schema);
136 let empty_map: HashMap<String, Vec<String>> = HashMap::new();
137 let columns = schema_discovery
138 .discover_columns(
139 Alias::new(schema).into_iden(),
140 Alias::new(table).into_iden(),
141 &empty_map,
142 )
143 .await?;
144
145 let pk_columns = Self::discover_primary_key(&connection, schema, table).await?;
147
148 Ok((columns, pk_columns))
149 }
150
151 async fn discover_schema(
152 username: &str,
153 password: &str,
154 host: &str,
155 port: u16,
156 database: &str,
157 schema: &str,
158 table: &str,
159 ssl_mode: &SslMode,
160 ssl_root_cert: &Option<String>,
161 ) -> ConnectorResult<TableDef> {
162 let mut options = PgConnectOptions::new()
163 .username(username)
164 .password(password)
165 .host(host)
166 .port(port)
167 .database(database)
168 .ssl_mode(match ssl_mode {
169 SslMode::Disabled => PgSslMode::Disable,
170 SslMode::Preferred => PgSslMode::Prefer,
171 SslMode::Required => PgSslMode::Require,
172 SslMode::VerifyCa => PgSslMode::VerifyCa,
173 SslMode::VerifyFull => PgSslMode::VerifyFull,
174 });
175
176 if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
177 && let Some(root_cert) = ssl_root_cert
178 {
179 options = options.ssl_root_cert(root_cert.as_str());
180 }
181
182 let connection = PgPool::connect_with(options).await?;
183 let schema_discovery = SchemaDiscovery::new(connection, schema);
184 let empty_map = HashMap::new();
186 let table_schema = schema_discovery
187 .discover_table(
188 TableInfo {
189 name: table.to_owned(),
190 of_type: None,
191 },
192 &empty_map,
193 )
194 .await?;
195 Ok(table_schema)
196 }
197
198 pub async fn connect(
199 username: &str,
200 password: &str,
201 host: &str,
202 port: u16,
203 database: &str,
204 schema: &str,
205 table: &str,
206 ssl_mode: &SslMode,
207 ssl_root_cert: &Option<String>,
208 is_append_only: bool,
209 ) -> ConnectorResult<Self> {
210 tracing::debug!("connect to postgres external table");
211
212 let (columns, pk_names) = Self::discover_pk_and_full_columns(
213 username,
214 password,
215 host,
216 port,
217 database,
218 schema,
219 table,
220 ssl_mode,
221 ssl_root_cert,
222 )
223 .await?;
224
225 let mut column_descs = vec![];
226 for col in &columns {
227 let rw_data_type = sea_type_to_rw_type(&col.col_type)?;
228 let column_desc = if let Some(ref default_expr) = col.default {
229 let val_text = default_expr
232 .0
233 .split("::")
234 .map(|s| s.trim_matches('\''))
235 .next()
236 .expect("default value expression");
237
238 match ScalarImpl::from_text(val_text, &rw_data_type) {
239 Ok(scalar) => ColumnDesc::named_with_default_value(
240 col.name.clone(),
241 ColumnId::placeholder(),
242 rw_data_type.clone(),
243 Some(scalar),
244 ),
245 Err(err) => {
246 tracing::warn!(error=%err.as_report(), "failed to parse postgres default value expression, only constant is supported");
247 ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
248 }
249 }
250 } else {
251 ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
252 };
253 column_descs.push(column_desc);
254 }
255
256 if !is_append_only && pk_names.is_empty() {
258 return Err(anyhow!(
259 "Postgres table should define the primary key for non-append-only tables"
260 )
261 .into());
262 }
263
264 Ok(Self {
265 column_descs,
266 pk_names,
267 })
268 }
269
270 pub async fn type_mapping(
272 username: &str,
273 password: &str,
274 host: &str,
275 port: u16,
276 database: &str,
277 schema: &str,
278 table: &str,
279 ssl_mode: &SslMode,
280 ssl_root_cert: &Option<String>,
281 is_append_only: bool,
282 ) -> ConnectorResult<HashMap<String, tokio_postgres::types::Type>> {
283 tracing::debug!("connect to postgres external table to get type mapping");
284 let table_schema = Self::discover_schema(
285 username,
286 password,
287 host,
288 port,
289 database,
290 schema,
291 table,
292 ssl_mode,
293 ssl_root_cert,
294 )
295 .await?;
296 let mut column_name_to_pg_type = HashMap::new();
297 for col in &table_schema.columns {
298 let pg_type = sea_type_to_pg_type(&col.col_type)?;
299 column_name_to_pg_type.insert(col.name.clone(), pg_type);
300 }
301 if !is_append_only && table_schema.primary_key_constraints.is_empty() {
302 return Err(anyhow!(
303 "Postgres table should define the primary key for non-append-only tables"
304 )
305 .into());
306 }
307 Ok(column_name_to_pg_type)
308 }
309
310 pub fn column_descs(&self) -> &Vec<ColumnDesc> {
311 &self.column_descs
312 }
313
314 pub fn pk_names(&self) -> &Vec<String> {
315 &self.pk_names
316 }
317}
318
319impl fmt::Display for SslMode {
320 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
321 f.write_str(match self {
322 SslMode::Disabled => "disabled",
323 SslMode::Preferred => "preferred",
324 SslMode::Required => "required",
325 SslMode::VerifyCa => "verify-ca",
326 SslMode::VerifyFull => "verify-full",
327 })
328 }
329}
330
331impl std::str::FromStr for SslMode {
332 type Err = serde_json::Error;
333
334 fn from_str(s: &str) -> Result<Self, Self::Err> {
335 serde_json::from_value(serde_json::Value::String(s.to_owned()))
336 }
337}
338
339pub async fn create_pg_client(
340 user: &str,
341 password: &str,
342 host: &str,
343 port: &str,
344 database: &str,
345 ssl_mode: &SslMode,
346 ssl_root_cert: &Option<String>,
347 tcp_keepalive: Option<TcpKeepaliveConfig>,
348) -> anyhow::Result<PgClient> {
349 let mut pg_config = tokio_postgres::Config::new();
350 pg_config
351 .user(user)
352 .password(password)
353 .host(host)
354 .port(port.parse::<u16>().unwrap())
355 .dbname(database);
356
357 if let Some(keepalive) = tcp_keepalive {
359 pg_config.keepalives(true);
360 pg_config.keepalives_idle(std::time::Duration::from_secs(
361 keepalive.tcp_keepalive_idle as u64,
362 ));
363 #[cfg(not(target_os = "windows"))]
364 {
365 pg_config.keepalives_interval(std::time::Duration::from_secs(
366 keepalive.tcp_keepalive_interval as u64,
367 ));
368 pg_config.keepalives_retries(keepalive.tcp_keepalive_count);
369 }
370 tracing::info!(
371 "TCP keepalive enabled: idle={}s, interval={}s, retries={}",
372 keepalive.tcp_keepalive_idle,
373 keepalive.tcp_keepalive_interval,
374 keepalive.tcp_keepalive_count
375 );
376 }
377
378 let (_verify_ca, verify_hostname) = match ssl_mode {
379 SslMode::VerifyCa => (true, false),
380 SslMode::VerifyFull => (true, true),
381 _ => (false, false),
382 };
383
384 #[cfg(not(madsim))]
385 let connector = match ssl_mode {
386 SslMode::Disabled => {
387 pg_config.ssl_mode(tokio_postgres::config::SslMode::Disable);
388 MaybeMakeTlsConnector::NoTls(NoTls)
389 }
390 SslMode::Preferred => {
391 pg_config.ssl_mode(tokio_postgres::config::SslMode::Prefer);
392 match SslConnector::builder(SslMethod::tls()) {
393 Ok(mut builder) => {
394 builder.set_verify(SslVerifyMode::NONE);
396 MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
397 }
398 Err(e) => {
399 tracing::warn!(error = %e.as_report(), "SSL connector error");
400 MaybeMakeTlsConnector::NoTls(NoTls)
401 }
402 }
403 }
404 SslMode::Required => {
405 pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
406 let mut builder = SslConnector::builder(SslMethod::tls())?;
407 builder.set_verify(SslVerifyMode::NONE);
409 MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
410 }
411
412 SslMode::VerifyCa | SslMode::VerifyFull => {
413 pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
414 let mut builder = SslConnector::builder(SslMethod::tls())?;
415 if let Some(ssl_root_cert) = ssl_root_cert {
416 builder.set_ca_file(ssl_root_cert).map_err(|e| {
417 anyhow!(format!("bad ssl root cert error: {}", e.to_report_string()))
418 })?;
419 }
420 let mut connector = MakeTlsConnector::new(builder.build());
421 if !verify_hostname {
422 connector.set_callback(|config, _| {
423 config.set_verify_hostname(false);
424 Ok(())
425 });
426 }
427 MaybeMakeTlsConnector::Tls(connector)
428 }
429 };
430 #[cfg(madsim)]
431 let connector = NoTls;
432
433 let (client, connection) = pg_config.connect(connector).await?;
434
435 tokio::spawn(async move {
436 if let Err(e) = connection.await {
437 tracing::error!(error = %e.as_report(), "postgres connection error");
438 }
439 });
440
441 Ok(client)
442}
443
444pub fn sea_type_to_rw_type(col_type: &SeaType) -> ConnectorResult<DataType> {
446 let dtype = match col_type {
447 SeaType::SmallInt | SeaType::SmallSerial => DataType::Int16,
448 SeaType::Integer | SeaType::Serial => DataType::Int32,
449 SeaType::BigInt | SeaType::BigSerial => DataType::Int64,
450 SeaType::Money | SeaType::Decimal(_) | SeaType::Numeric(_) => DataType::Decimal,
451 SeaType::Real => DataType::Float32,
452 SeaType::DoublePrecision => DataType::Float64,
453 SeaType::Varchar(_) | SeaType::Char(_) | SeaType::Text => DataType::Varchar,
454 SeaType::Bytea => DataType::Bytea,
455 SeaType::Timestamp(_) => DataType::Timestamp,
456 SeaType::TimestampWithTimeZone(_) => DataType::Timestamptz,
457 SeaType::Date => DataType::Date,
458 SeaType::Time(_) | SeaType::TimeWithTimeZone(_) => DataType::Time,
459 SeaType::Interval(_) => DataType::Interval,
460 SeaType::Boolean => DataType::Boolean,
461 SeaType::Point => DataType::Struct(StructType::new(vec![
462 ("x", DataType::Float32),
463 ("y", DataType::Float32),
464 ])),
465 SeaType::Uuid => DataType::Varchar,
466 SeaType::Xml => DataType::Varchar,
467 SeaType::Json => DataType::Jsonb,
468 SeaType::JsonBinary => DataType::Jsonb,
469 SeaType::Array(def) => {
470 let item_type = match def.col_type.as_ref() {
471 Some(ty) => sea_type_to_rw_type(ty.as_ref())?,
472 None => {
473 return Err(anyhow!("ARRAY type missing element type").into());
474 }
475 };
476
477 DataType::list(item_type)
478 }
479 SeaType::PgLsn => DataType::Int64,
480 SeaType::Cidr
481 | SeaType::Inet
482 | SeaType::MacAddr
483 | SeaType::MacAddr8
484 | SeaType::Int4Range
485 | SeaType::Int8Range
486 | SeaType::NumRange
487 | SeaType::TsRange
488 | SeaType::TsTzRange
489 | SeaType::DateRange
490 | SeaType::Enum(_) => DataType::Varchar,
491 SeaType::Line
492 | SeaType::Lseg
493 | SeaType::Box
494 | SeaType::Path
495 | SeaType::Polygon
496 | SeaType::Circle
497 | SeaType::Bit(_)
498 | SeaType::VarBit(_)
499 | SeaType::TsVector
500 | SeaType::TsQuery => {
501 bail!("{:?} type not supported", col_type);
502 }
503 SeaType::Unknown(name) => {
504 tracing::warn!("Unknown Postgres data type: {name}, map to varchar");
506 DataType::Varchar
507 }
508 };
509
510 Ok(dtype)
511}
512
513fn sea_type_to_pg_type(sea_type: &SeaType) -> ConnectorResult<tokio_postgres::types::Type> {
518 use tokio_postgres::types::Type as PgType;
519 match sea_type {
520 SeaType::SmallInt => Ok(PgType::INT2),
521 SeaType::Integer => Ok(PgType::INT4),
522 SeaType::BigInt => Ok(PgType::INT8),
523 SeaType::Decimal(_) => Ok(PgType::NUMERIC),
524 SeaType::Numeric(_) => Ok(PgType::NUMERIC),
525 SeaType::Real => Ok(PgType::FLOAT4),
526 SeaType::DoublePrecision => Ok(PgType::FLOAT8),
527 SeaType::Varchar(_) => Ok(PgType::VARCHAR),
528 SeaType::Char(_) => Ok(PgType::CHAR),
529 SeaType::Text => Ok(PgType::TEXT),
530 SeaType::Bytea => Ok(PgType::BYTEA),
531 SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP),
532 SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ),
533 SeaType::Date => Ok(PgType::DATE),
534 SeaType::Time(_) => Ok(PgType::TIME),
535 SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ),
536 SeaType::Interval(_) => Ok(PgType::INTERVAL),
537 SeaType::Boolean => Ok(PgType::BOOL),
538 SeaType::Point => Ok(PgType::POINT),
539 SeaType::Uuid => Ok(PgType::UUID),
540 SeaType::Json => Ok(PgType::JSON),
541 SeaType::JsonBinary => Ok(PgType::JSONB),
542 SeaType::Array(t) => {
543 let Some(t) = t.col_type.as_ref() else {
544 bail!("missing array type")
545 };
546 match t.as_ref() {
547 SeaType::SmallInt => Ok(PgType::INT2_ARRAY),
549 SeaType::Integer => Ok(PgType::INT4_ARRAY),
550 SeaType::BigInt => Ok(PgType::INT8_ARRAY),
551 SeaType::Decimal(_) => Ok(PgType::NUMERIC_ARRAY),
552 SeaType::Numeric(_) => Ok(PgType::NUMERIC_ARRAY),
553 SeaType::Real => Ok(PgType::FLOAT4_ARRAY),
554 SeaType::DoublePrecision => Ok(PgType::FLOAT8_ARRAY),
555 SeaType::Varchar(_) => Ok(PgType::VARCHAR_ARRAY),
556 SeaType::Char(_) => Ok(PgType::CHAR_ARRAY),
557 SeaType::Text => Ok(PgType::TEXT_ARRAY),
558 SeaType::Bytea => Ok(PgType::BYTEA_ARRAY),
559 SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP_ARRAY),
560 SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ_ARRAY),
561 SeaType::Date => Ok(PgType::DATE_ARRAY),
562 SeaType::Time(_) => Ok(PgType::TIME_ARRAY),
563 SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ_ARRAY),
564 SeaType::Interval(_) => Ok(PgType::INTERVAL_ARRAY),
565 SeaType::Boolean => Ok(PgType::BOOL_ARRAY),
566 SeaType::Point => Ok(PgType::POINT_ARRAY),
567 SeaType::Uuid => Ok(PgType::UUID_ARRAY),
568 SeaType::Json => Ok(PgType::JSON_ARRAY),
569 SeaType::JsonBinary => Ok(PgType::JSONB_ARRAY),
570 SeaType::Array(_) => bail!("nested array type is not supported"),
571 SeaType::Unknown(name) => {
572 Ok(PgType::new(
574 name.clone(),
575 0,
576 PgKind::Array(PgType::new(
577 name.clone(),
578 0,
579 PgKind::Enum(vec![]),
580 "".into(),
581 )),
582 "".into(),
583 ))
584 }
585 _ => bail!("unsupported array type: {:?}", t),
586 }
587 }
588 SeaType::Unknown(name) => {
589 Ok(PgType::new(
591 name.clone(),
592 0,
593 PgKind::Enum(vec![]),
594 "".into(),
595 ))
596 }
597 _ => bail!("unsupported type: {:?}", sea_type),
598 }
599}