1use std::collections::HashMap;
16use std::fmt;
17
18use anyhow::{Context, anyhow};
19use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
20use postgres_openssl::MakeTlsConnector;
21use risingwave_common::bail;
22use risingwave_common::catalog::{ColumnDesc, ColumnId};
23use risingwave_common::types::{DataType, ScalarImpl, StructType};
24use sea_schema::postgres::def::{ColumnType as SeaType, TableDef, TableInfo};
25use sea_schema::postgres::discovery::SchemaDiscovery;
26use sea_schema::sea_query::{Alias, IntoIden};
27use serde::Deserialize;
28use sqlx::postgres::{PgConnectOptions, PgSslMode};
29use sqlx::{PgPool, Row};
30use thiserror_ext::AsReport;
31use tokio_postgres::types::Kind as PgKind;
32use tokio_postgres::{Client as PgClient, NoTls};
33
34#[cfg(not(madsim))]
35use super::maybe_tls_connector::MaybeMakeTlsConnector;
36use crate::error::ConnectorResult;
37
/// Query returning the primary key column names of a table, ordered by their
/// position inside the primary key index.
///
/// `$1` is the schema name and `$2` is the table name. Both are passed through
/// `quote_ident` before the `regclass` cast so that mixed-case or otherwise
/// special identifiers resolve to the exact relation instead of being
/// case-folded when the text is parsed as a relation name.
const DISCOVER_PRIMARY_KEY_QUERY: &str = r#"
    SELECT a.attname as column_name
    FROM pg_index i
    JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
    WHERE i.indrelid = (quote_ident($1) || '.' || quote_ident($2))::regclass
      AND i.indisprimary = true
    ORDER BY array_position(i.indkey, a.attnum)
"#;
48
49#[derive(Debug, Clone, PartialEq, Deserialize)]
50#[serde(rename_all = "lowercase")]
51pub enum SslMode {
52 #[serde(alias = "disable")]
53 Disabled,
54 #[serde(alias = "prefer")]
55 Preferred,
56 #[serde(alias = "require")]
57 Required,
58 #[serde(alias = "verify-ca")]
61 VerifyCa,
62 #[serde(alias = "verify-full")]
65 VerifyFull,
66}
67
68impl Default for SslMode {
69 fn default() -> Self {
70 Self::Preferred
71 }
72}
73
74pub struct PostgresExternalTable {
75 column_descs: Vec<ColumnDesc>,
76 pk_names: Vec<String>,
77}
78
79impl PostgresExternalTable {
80 async fn discover_primary_key(
83 connection: &PgPool,
84 schema_name: &str,
85 table_name: &str,
86 ) -> ConnectorResult<Vec<String>> {
87 let rows = sqlx::query(DISCOVER_PRIMARY_KEY_QUERY)
88 .bind(schema_name)
89 .bind(table_name)
90 .fetch_all(connection)
91 .await
92 .context("Failed to discover primary key columns")?;
93
94 let pk_columns = rows
95 .into_iter()
96 .map(|row| row.get::<String, _>("column_name"))
97 .collect();
98
99 Ok(pk_columns)
100 }
101
102 async fn discover_pk_and_full_columns(
106 username: &str,
107 password: &str,
108 host: &str,
109 port: u16,
110 database: &str,
111 schema: &str,
112 table: &str,
113 ssl_mode: &SslMode,
114 ssl_root_cert: &Option<String>,
115 ) -> ConnectorResult<(Vec<sea_schema::postgres::def::ColumnInfo>, Vec<String>)> {
116 let mut options = PgConnectOptions::new()
117 .username(username)
118 .password(password)
119 .host(host)
120 .port(port)
121 .database(database)
122 .ssl_mode(match ssl_mode {
123 SslMode::Disabled => PgSslMode::Disable,
124 SslMode::Preferred => PgSslMode::Prefer,
125 SslMode::Required => PgSslMode::Require,
126 SslMode::VerifyCa => PgSslMode::VerifyCa,
127 SslMode::VerifyFull => PgSslMode::VerifyFull,
128 });
129
130 if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
131 && let Some(root_cert) = ssl_root_cert
132 {
133 options = options.ssl_root_cert(root_cert.as_str());
134 }
135
136 let connection = PgPool::connect_with(options).await?;
137
138 let schema_discovery = SchemaDiscovery::new(connection.clone(), schema);
140 let empty_map: HashMap<String, Vec<String>> = HashMap::new();
141 let columns = schema_discovery
142 .discover_columns(
143 Alias::new(schema).into_iden(),
144 Alias::new(table).into_iden(),
145 &empty_map,
146 )
147 .await?;
148
149 let pk_columns = Self::discover_primary_key(&connection, schema, table).await?;
151
152 Ok((columns, pk_columns))
153 }
154
155 async fn discover_schema(
156 username: &str,
157 password: &str,
158 host: &str,
159 port: u16,
160 database: &str,
161 schema: &str,
162 table: &str,
163 ssl_mode: &SslMode,
164 ssl_root_cert: &Option<String>,
165 ) -> ConnectorResult<TableDef> {
166 let mut options = PgConnectOptions::new()
167 .username(username)
168 .password(password)
169 .host(host)
170 .port(port)
171 .database(database)
172 .ssl_mode(match ssl_mode {
173 SslMode::Disabled => PgSslMode::Disable,
174 SslMode::Preferred => PgSslMode::Prefer,
175 SslMode::Required => PgSslMode::Require,
176 SslMode::VerifyCa => PgSslMode::VerifyCa,
177 SslMode::VerifyFull => PgSslMode::VerifyFull,
178 });
179
180 if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
181 && let Some(root_cert) = ssl_root_cert
182 {
183 options = options.ssl_root_cert(root_cert.as_str());
184 }
185
186 let connection = PgPool::connect_with(options).await?;
187 let schema_discovery = SchemaDiscovery::new(connection, schema);
188 let empty_map = HashMap::new();
190 let table_schema = schema_discovery
191 .discover_table(
192 TableInfo {
193 name: table.to_owned(),
194 of_type: None,
195 },
196 &empty_map,
197 )
198 .await?;
199 Ok(table_schema)
200 }
201
202 pub async fn connect(
203 username: &str,
204 password: &str,
205 host: &str,
206 port: u16,
207 database: &str,
208 schema: &str,
209 table: &str,
210 ssl_mode: &SslMode,
211 ssl_root_cert: &Option<String>,
212 is_append_only: bool,
213 ) -> ConnectorResult<Self> {
214 tracing::debug!("connect to postgres external table");
215
216 let (columns, pk_names) = Self::discover_pk_and_full_columns(
217 username,
218 password,
219 host,
220 port,
221 database,
222 schema,
223 table,
224 ssl_mode,
225 ssl_root_cert,
226 )
227 .await?;
228
229 let mut column_descs = vec![];
230 for col in &columns {
231 let rw_data_type = sea_type_to_rw_type(&col.col_type)?;
232 let column_desc = if let Some(ref default_expr) = col.default {
233 let val_text = default_expr
236 .0
237 .split("::")
238 .map(|s| s.trim_matches('\''))
239 .next()
240 .expect("default value expression");
241
242 match ScalarImpl::from_text(val_text, &rw_data_type) {
243 Ok(scalar) => ColumnDesc::named_with_default_value(
244 col.name.clone(),
245 ColumnId::placeholder(),
246 rw_data_type.clone(),
247 Some(scalar),
248 ),
249 Err(err) => {
250 tracing::warn!(error=%err.as_report(), "failed to parse postgres default value expression, only constant is supported");
251 ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
252 }
253 }
254 } else {
255 ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
256 };
257 column_descs.push(column_desc);
258 }
259
260 if !is_append_only && pk_names.is_empty() {
262 return Err(anyhow!(
263 "Postgres table should define the primary key for non-append-only tables"
264 )
265 .into());
266 }
267
268 Ok(Self {
269 column_descs,
270 pk_names,
271 })
272 }
273
274 pub async fn type_mapping(
276 username: &str,
277 password: &str,
278 host: &str,
279 port: u16,
280 database: &str,
281 schema: &str,
282 table: &str,
283 ssl_mode: &SslMode,
284 ssl_root_cert: &Option<String>,
285 is_append_only: bool,
286 ) -> ConnectorResult<HashMap<String, tokio_postgres::types::Type>> {
287 tracing::debug!("connect to postgres external table to get type mapping");
288 let table_schema = Self::discover_schema(
289 username,
290 password,
291 host,
292 port,
293 database,
294 schema,
295 table,
296 ssl_mode,
297 ssl_root_cert,
298 )
299 .await?;
300 let mut column_name_to_pg_type = HashMap::new();
301 for col in &table_schema.columns {
302 let pg_type = sea_type_to_pg_type(&col.col_type)?;
303 column_name_to_pg_type.insert(col.name.clone(), pg_type);
304 }
305 if !is_append_only && table_schema.primary_key_constraints.is_empty() {
306 return Err(anyhow!(
307 "Postgres table should define the primary key for non-append-only tables"
308 )
309 .into());
310 }
311 Ok(column_name_to_pg_type)
312 }
313
314 pub fn column_descs(&self) -> &Vec<ColumnDesc> {
315 &self.column_descs
316 }
317
318 pub fn pk_names(&self) -> &Vec<String> {
319 &self.pk_names
320 }
321}
322
323impl fmt::Display for SslMode {
324 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
325 f.write_str(match self {
326 SslMode::Disabled => "disabled",
327 SslMode::Preferred => "preferred",
328 SslMode::Required => "required",
329 SslMode::VerifyCa => "verify-ca",
330 SslMode::VerifyFull => "verify-full",
331 })
332 }
333}
334
335pub async fn create_pg_client(
336 user: &str,
337 password: &str,
338 host: &str,
339 port: &str,
340 database: &str,
341 ssl_mode: &SslMode,
342 ssl_root_cert: &Option<String>,
343) -> anyhow::Result<PgClient> {
344 let mut pg_config = tokio_postgres::Config::new();
345 pg_config
346 .user(user)
347 .password(password)
348 .host(host)
349 .port(port.parse::<u16>().unwrap())
350 .dbname(database);
351
352 let (_verify_ca, verify_hostname) = match ssl_mode {
353 SslMode::VerifyCa => (true, false),
354 SslMode::VerifyFull => (true, true),
355 _ => (false, false),
356 };
357
358 #[cfg(not(madsim))]
359 let connector = match ssl_mode {
360 SslMode::Disabled => {
361 pg_config.ssl_mode(tokio_postgres::config::SslMode::Disable);
362 MaybeMakeTlsConnector::NoTls(NoTls)
363 }
364 SslMode::Preferred => {
365 pg_config.ssl_mode(tokio_postgres::config::SslMode::Prefer);
366 match SslConnector::builder(SslMethod::tls()) {
367 Ok(mut builder) => {
368 builder.set_verify(SslVerifyMode::NONE);
370 MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
371 }
372 Err(e) => {
373 tracing::warn!(error = %e.as_report(), "SSL connector error");
374 MaybeMakeTlsConnector::NoTls(NoTls)
375 }
376 }
377 }
378 SslMode::Required => {
379 pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
380 let mut builder = SslConnector::builder(SslMethod::tls())?;
381 builder.set_verify(SslVerifyMode::NONE);
383 MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
384 }
385
386 SslMode::VerifyCa | SslMode::VerifyFull => {
387 pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
388 let mut builder = SslConnector::builder(SslMethod::tls())?;
389 if let Some(ssl_root_cert) = ssl_root_cert {
390 builder.set_ca_file(ssl_root_cert).map_err(|e| {
391 anyhow!(format!("bad ssl root cert error: {}", e.to_report_string()))
392 })?;
393 }
394 let mut connector = MakeTlsConnector::new(builder.build());
395 if !verify_hostname {
396 connector.set_callback(|config, _| {
397 config.set_verify_hostname(false);
398 Ok(())
399 });
400 }
401 MaybeMakeTlsConnector::Tls(connector)
402 }
403 };
404 #[cfg(madsim)]
405 let connector = NoTls;
406
407 let (client, connection) = pg_config.connect(connector).await?;
408
409 tokio::spawn(async move {
410 if let Err(e) = connection.await {
411 tracing::error!(error = %e.as_report(), "postgres connection error");
412 }
413 });
414
415 Ok(client)
416}
417
418pub fn sea_type_to_rw_type(col_type: &SeaType) -> ConnectorResult<DataType> {
420 let dtype = match col_type {
421 SeaType::SmallInt | SeaType::SmallSerial => DataType::Int16,
422 SeaType::Integer | SeaType::Serial => DataType::Int32,
423 SeaType::BigInt | SeaType::BigSerial => DataType::Int64,
424 SeaType::Money | SeaType::Decimal(_) | SeaType::Numeric(_) => DataType::Decimal,
425 SeaType::Real => DataType::Float32,
426 SeaType::DoublePrecision => DataType::Float64,
427 SeaType::Varchar(_) | SeaType::Char(_) | SeaType::Text => DataType::Varchar,
428 SeaType::Bytea => DataType::Bytea,
429 SeaType::Timestamp(_) => DataType::Timestamp,
430 SeaType::TimestampWithTimeZone(_) => DataType::Timestamptz,
431 SeaType::Date => DataType::Date,
432 SeaType::Time(_) | SeaType::TimeWithTimeZone(_) => DataType::Time,
433 SeaType::Interval(_) => DataType::Interval,
434 SeaType::Boolean => DataType::Boolean,
435 SeaType::Point => DataType::Struct(StructType::new(vec![
436 ("x", DataType::Float32),
437 ("y", DataType::Float32),
438 ])),
439 SeaType::Uuid => DataType::Varchar,
440 SeaType::Xml => DataType::Varchar,
441 SeaType::Json => DataType::Jsonb,
442 SeaType::JsonBinary => DataType::Jsonb,
443 SeaType::Array(def) => {
444 let item_type = match def.col_type.as_ref() {
445 Some(ty) => sea_type_to_rw_type(ty.as_ref())?,
446 None => {
447 return Err(anyhow!("ARRAY type missing element type").into());
448 }
449 };
450
451 DataType::List(Box::new(item_type))
452 }
453 SeaType::PgLsn => DataType::Int64,
454 SeaType::Cidr
455 | SeaType::Inet
456 | SeaType::MacAddr
457 | SeaType::MacAddr8
458 | SeaType::Int4Range
459 | SeaType::Int8Range
460 | SeaType::NumRange
461 | SeaType::TsRange
462 | SeaType::TsTzRange
463 | SeaType::DateRange
464 | SeaType::Enum(_) => DataType::Varchar,
465 SeaType::Line
466 | SeaType::Lseg
467 | SeaType::Box
468 | SeaType::Path
469 | SeaType::Polygon
470 | SeaType::Circle
471 | SeaType::Bit(_)
472 | SeaType::VarBit(_)
473 | SeaType::TsVector
474 | SeaType::TsQuery => {
475 bail!("{:?} type not supported", col_type);
476 }
477 SeaType::Unknown(name) => {
478 tracing::warn!("Unknown Postgres data type: {name}, map to varchar");
480 DataType::Varchar
481 }
482 };
483
484 Ok(dtype)
485}
486
487fn sea_type_to_pg_type(sea_type: &SeaType) -> ConnectorResult<tokio_postgres::types::Type> {
492 use tokio_postgres::types::Type as PgType;
493 match sea_type {
494 SeaType::SmallInt => Ok(PgType::INT2),
495 SeaType::Integer => Ok(PgType::INT4),
496 SeaType::BigInt => Ok(PgType::INT8),
497 SeaType::Decimal(_) => Ok(PgType::NUMERIC),
498 SeaType::Numeric(_) => Ok(PgType::NUMERIC),
499 SeaType::Real => Ok(PgType::FLOAT4),
500 SeaType::DoublePrecision => Ok(PgType::FLOAT8),
501 SeaType::Varchar(_) => Ok(PgType::VARCHAR),
502 SeaType::Char(_) => Ok(PgType::CHAR),
503 SeaType::Text => Ok(PgType::TEXT),
504 SeaType::Bytea => Ok(PgType::BYTEA),
505 SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP),
506 SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ),
507 SeaType::Date => Ok(PgType::DATE),
508 SeaType::Time(_) => Ok(PgType::TIME),
509 SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ),
510 SeaType::Interval(_) => Ok(PgType::INTERVAL),
511 SeaType::Boolean => Ok(PgType::BOOL),
512 SeaType::Point => Ok(PgType::POINT),
513 SeaType::Uuid => Ok(PgType::UUID),
514 SeaType::Json => Ok(PgType::JSON),
515 SeaType::JsonBinary => Ok(PgType::JSONB),
516 SeaType::Array(t) => {
517 let Some(t) = t.col_type.as_ref() else {
518 bail!("missing array type")
519 };
520 match t.as_ref() {
521 SeaType::SmallInt => Ok(PgType::INT2_ARRAY),
523 SeaType::Integer => Ok(PgType::INT4_ARRAY),
524 SeaType::BigInt => Ok(PgType::INT8_ARRAY),
525 SeaType::Decimal(_) => Ok(PgType::NUMERIC_ARRAY),
526 SeaType::Numeric(_) => Ok(PgType::NUMERIC_ARRAY),
527 SeaType::Real => Ok(PgType::FLOAT4_ARRAY),
528 SeaType::DoublePrecision => Ok(PgType::FLOAT8_ARRAY),
529 SeaType::Varchar(_) => Ok(PgType::VARCHAR_ARRAY),
530 SeaType::Char(_) => Ok(PgType::CHAR_ARRAY),
531 SeaType::Text => Ok(PgType::TEXT_ARRAY),
532 SeaType::Bytea => Ok(PgType::BYTEA_ARRAY),
533 SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP_ARRAY),
534 SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ_ARRAY),
535 SeaType::Date => Ok(PgType::DATE_ARRAY),
536 SeaType::Time(_) => Ok(PgType::TIME_ARRAY),
537 SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ_ARRAY),
538 SeaType::Interval(_) => Ok(PgType::INTERVAL_ARRAY),
539 SeaType::Boolean => Ok(PgType::BOOL_ARRAY),
540 SeaType::Point => Ok(PgType::POINT_ARRAY),
541 SeaType::Uuid => Ok(PgType::UUID_ARRAY),
542 SeaType::Json => Ok(PgType::JSON_ARRAY),
543 SeaType::JsonBinary => Ok(PgType::JSONB_ARRAY),
544 SeaType::Array(_) => bail!("nested array type is not supported"),
545 SeaType::Unknown(name) => {
546 Ok(PgType::new(
548 name.clone(),
549 0,
550 PgKind::Array(PgType::new(
551 name.clone(),
552 0,
553 PgKind::Enum(vec![]),
554 "".into(),
555 )),
556 "".into(),
557 ))
558 }
559 _ => bail!("unsupported array type: {:?}", t),
560 }
561 }
562 SeaType::Unknown(name) => {
563 Ok(PgType::new(
565 name.clone(),
566 0,
567 PgKind::Enum(vec![]),
568 "".into(),
569 ))
570 }
571 _ => bail!("unsupported type: {:?}", sea_type),
572 }
573}