risingwave_connector/connector_common/
postgres.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt;
17
18use anyhow::{Context, anyhow};
19use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
20use postgres_openssl::MakeTlsConnector;
21use risingwave_common::bail;
22use risingwave_common::catalog::{ColumnDesc, ColumnId};
23use risingwave_common::types::{DataType, ScalarImpl, StructType};
24use sea_schema::postgres::def::{ColumnType as SeaType, TableDef, TableInfo};
25use sea_schema::postgres::discovery::SchemaDiscovery;
26use sea_schema::sea_query::{Alias, IntoIden};
27use serde::Deserialize;
28use sqlx::postgres::{PgConnectOptions, PgSslMode};
29use sqlx::{PgPool, Row};
30use thiserror_ext::AsReport;
31use tokio_postgres::types::Kind as PgKind;
32use tokio_postgres::{Client as PgClient, NoTls};
33
34#[cfg(not(madsim))]
35use super::maybe_tls_connector::MaybeMakeTlsConnector;
36use crate::error::ConnectorResult;
37use crate::sink::postgres::TcpKeepaliveConfig;
38
/// SQL query to discover primary key columns directly from PostgreSQL system tables.
/// This bypasses querying `information_schema.table_constraints` to avoid permission issues.
///
/// Bind parameters: `$1` is the schema name and `$2` is the table name, both given as raw
/// (unquoted) identifiers. Each part is passed through `quote_ident` before the `regclass`
/// cast so that mixed-case names, names with special characters and reserved words resolve
/// to the intended relation instead of being case-folded or rejected by the identifier parser.
///
/// Returns one row per primary key column (`column_name`), ordered by the column's position
/// inside the primary key index.
const DISCOVER_PRIMARY_KEY_QUERY: &str = r#"
    SELECT a.attname as column_name
    FROM pg_index i
    JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
    WHERE i.indrelid = (quote_ident($1) || '.' || quote_ident($2))::regclass
      AND i.indisprimary = true
    ORDER BY array_position(i.indkey, a.attnum)
"#;
49
/// TLS policy used when connecting to PostgreSQL.
///
/// Variants are deserialized from their lowercase names (`rename_all = "lowercase"`), and
/// each variant additionally accepts the alias declared on it. The default is
/// [`SslMode::Preferred`].
#[derive(Debug, Clone, PartialEq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum SslMode {
    /// Do not use TLS. Also accepted as `disable`.
    #[serde(alias = "disable")]
    Disabled,
    /// Default mode. Also accepted as `prefer`.
    #[serde(alias = "prefer")]
    #[default]
    Preferred,
    /// Require TLS. Also accepted as `require`.
    #[serde(alias = "require")]
    Required,
    /// Verify that the server is trustworthy by checking the certificate chain
    /// up to the root certificate stored on the client.
    #[serde(alias = "verify-ca")]
    VerifyCa,
    /// Besides verifying the certificate chain, also verify that the server host name
    /// matches the name stored in the server certificate.
    #[serde(alias = "verify-full")]
    VerifyFull,
}
69
/// Column and primary key metadata discovered from an upstream PostgreSQL table.
pub struct PostgresExternalTable {
    /// One descriptor per discovered column, in the order returned by column discovery.
    column_descs: Vec<ColumnDesc>,
    /// Names of the primary key columns, as returned by the primary key discovery query.
    pk_names: Vec<String>,
}
74
75impl PostgresExternalTable {
76    /// Discover primary key columns directly from PostgreSQL system tables.
77    /// This bypasses querying `information_schema.table_constraints` to avoid requiring table owner permissions.
78    async fn discover_primary_key(
79        connection: &PgPool,
80        schema_name: &str,
81        table_name: &str,
82    ) -> ConnectorResult<Vec<String>> {
83        let rows = sqlx::query(DISCOVER_PRIMARY_KEY_QUERY)
84            .bind(schema_name)
85            .bind(table_name)
86            .fetch_all(connection)
87            .await
88            .context("Failed to discover primary key columns")?;
89
90        let pk_columns = rows
91            .into_iter()
92            .map(|row| row.get::<String, _>("column_name"))
93            .collect();
94
95        Ok(pk_columns)
96    }
97
98    /// Discover schema with workaround for primary key discovery
99    /// This method uses direct PostgreSQL system table queries for primary keys
100    /// to avoid permission issues when querying `information_schema.table_constraints`
101    async fn discover_pk_and_full_columns(
102        username: &str,
103        password: &str,
104        host: &str,
105        port: u16,
106        database: &str,
107        schema: &str,
108        table: &str,
109        ssl_mode: &SslMode,
110        ssl_root_cert: &Option<String>,
111    ) -> ConnectorResult<(Vec<sea_schema::postgres::def::ColumnInfo>, Vec<String>)> {
112        let mut options = PgConnectOptions::new()
113            .username(username)
114            .password(password)
115            .host(host)
116            .port(port)
117            .database(database)
118            .ssl_mode(match ssl_mode {
119                SslMode::Disabled => PgSslMode::Disable,
120                SslMode::Preferred => PgSslMode::Prefer,
121                SslMode::Required => PgSslMode::Require,
122                SslMode::VerifyCa => PgSslMode::VerifyCa,
123                SslMode::VerifyFull => PgSslMode::VerifyFull,
124            });
125
126        if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
127            && let Some(root_cert) = ssl_root_cert
128        {
129            options = options.ssl_root_cert(root_cert.as_str());
130        }
131
132        let connection = PgPool::connect_with(options).await?;
133
134        // Use sea-schema only for column discovery (no permission issues)
135        let schema_discovery = SchemaDiscovery::new(connection.clone(), schema);
136        let empty_map: HashMap<String, Vec<String>> = HashMap::new();
137        let columns = schema_discovery
138            .discover_columns(
139                Alias::new(schema).into_iden(),
140                Alias::new(table).into_iden(),
141                &empty_map,
142            )
143            .await?;
144
145        // Use direct system table query for primary key discovery
146        let pk_columns = Self::discover_primary_key(&connection, schema, table).await?;
147
148        Ok((columns, pk_columns))
149    }
150
151    async fn discover_schema(
152        username: &str,
153        password: &str,
154        host: &str,
155        port: u16,
156        database: &str,
157        schema: &str,
158        table: &str,
159        ssl_mode: &SslMode,
160        ssl_root_cert: &Option<String>,
161    ) -> ConnectorResult<TableDef> {
162        let mut options = PgConnectOptions::new()
163            .username(username)
164            .password(password)
165            .host(host)
166            .port(port)
167            .database(database)
168            .ssl_mode(match ssl_mode {
169                SslMode::Disabled => PgSslMode::Disable,
170                SslMode::Preferred => PgSslMode::Prefer,
171                SslMode::Required => PgSslMode::Require,
172                SslMode::VerifyCa => PgSslMode::VerifyCa,
173                SslMode::VerifyFull => PgSslMode::VerifyFull,
174            });
175
176        if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
177            && let Some(root_cert) = ssl_root_cert
178        {
179            options = options.ssl_root_cert(root_cert.as_str());
180        }
181
182        let connection = PgPool::connect_with(options).await?;
183        let schema_discovery = SchemaDiscovery::new(connection, schema);
184        // fetch column schema and primary key
185        let empty_map = HashMap::new();
186        let table_schema = schema_discovery
187            .discover_table(
188                TableInfo {
189                    name: table.to_owned(),
190                    of_type: None,
191                },
192                &empty_map,
193            )
194            .await?;
195        Ok(table_schema)
196    }
197
198    pub async fn connect(
199        username: &str,
200        password: &str,
201        host: &str,
202        port: u16,
203        database: &str,
204        schema: &str,
205        table: &str,
206        ssl_mode: &SslMode,
207        ssl_root_cert: &Option<String>,
208        is_append_only: bool,
209    ) -> ConnectorResult<Self> {
210        tracing::debug!("connect to postgres external table");
211
212        let (columns, pk_names) = Self::discover_pk_and_full_columns(
213            username,
214            password,
215            host,
216            port,
217            database,
218            schema,
219            table,
220            ssl_mode,
221            ssl_root_cert,
222        )
223        .await?;
224
225        let mut column_descs = vec![];
226        for col in &columns {
227            let rw_data_type = sea_type_to_rw_type(&col.col_type)?;
228            let column_desc = if let Some(ref default_expr) = col.default {
229                // parse the value of "column_default" field in information_schema.columns,
230                // non number data type will be stored as "'value'::type"
231                let val_text = default_expr
232                    .0
233                    .split("::")
234                    .map(|s| s.trim_matches('\''))
235                    .next()
236                    .expect("default value expression");
237
238                match ScalarImpl::from_text(val_text, &rw_data_type) {
239                    Ok(scalar) => ColumnDesc::named_with_default_value(
240                        col.name.clone(),
241                        ColumnId::placeholder(),
242                        rw_data_type.clone(),
243                        Some(scalar),
244                    ),
245                    Err(err) => {
246                        tracing::warn!(error=%err.as_report(), "failed to parse postgres default value expression, only constant is supported");
247                        ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
248                    }
249                }
250            } else {
251                ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
252            };
253            column_descs.push(column_desc);
254        }
255
256        // Check primary key existence using the directly discovered pk_names
257        if !is_append_only && pk_names.is_empty() {
258            return Err(anyhow!(
259                "Postgres table should define the primary key for non-append-only tables"
260            )
261            .into());
262        }
263
264        Ok(Self {
265            column_descs,
266            pk_names,
267        })
268    }
269
270    // return the mapping from column name to pg type, the pg type is used for writing data to postgres
271    pub async fn type_mapping(
272        username: &str,
273        password: &str,
274        host: &str,
275        port: u16,
276        database: &str,
277        schema: &str,
278        table: &str,
279        ssl_mode: &SslMode,
280        ssl_root_cert: &Option<String>,
281        is_append_only: bool,
282    ) -> ConnectorResult<HashMap<String, tokio_postgres::types::Type>> {
283        tracing::debug!("connect to postgres external table to get type mapping");
284        let table_schema = Self::discover_schema(
285            username,
286            password,
287            host,
288            port,
289            database,
290            schema,
291            table,
292            ssl_mode,
293            ssl_root_cert,
294        )
295        .await?;
296        let mut column_name_to_pg_type = HashMap::new();
297        for col in &table_schema.columns {
298            let pg_type = sea_type_to_pg_type(&col.col_type)?;
299            column_name_to_pg_type.insert(col.name.clone(), pg_type);
300        }
301        if !is_append_only && table_schema.primary_key_constraints.is_empty() {
302            return Err(anyhow!(
303                "Postgres table should define the primary key for non-append-only tables"
304            )
305            .into());
306        }
307        Ok(column_name_to_pg_type)
308    }
309
310    pub fn column_descs(&self) -> &Vec<ColumnDesc> {
311        &self.column_descs
312    }
313
314    pub fn pk_names(&self) -> &Vec<String> {
315        &self.pk_names
316    }
317}
318
319impl fmt::Display for SslMode {
320    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
321        f.write_str(match self {
322            SslMode::Disabled => "disabled",
323            SslMode::Preferred => "preferred",
324            SslMode::Required => "required",
325            SslMode::VerifyCa => "verify-ca",
326            SslMode::VerifyFull => "verify-full",
327        })
328    }
329}
330
331impl std::str::FromStr for SslMode {
332    type Err = serde_json::Error;
333
334    fn from_str(s: &str) -> Result<Self, Self::Err> {
335        serde_json::from_value(serde_json::Value::String(s.to_owned()))
336    }
337}
338
339pub async fn create_pg_client(
340    user: &str,
341    password: &str,
342    host: &str,
343    port: &str,
344    database: &str,
345    ssl_mode: &SslMode,
346    ssl_root_cert: &Option<String>,
347    tcp_keepalive: Option<TcpKeepaliveConfig>,
348) -> anyhow::Result<PgClient> {
349    let mut pg_config = tokio_postgres::Config::new();
350    pg_config
351        .user(user)
352        .password(password)
353        .host(host)
354        .port(port.parse::<u16>().unwrap())
355        .dbname(database);
356
357    // Configure TCP keepalive if provided
358    if let Some(keepalive) = tcp_keepalive {
359        pg_config.keepalives(true);
360        pg_config.keepalives_idle(std::time::Duration::from_secs(
361            keepalive.tcp_keepalive_idle as u64,
362        ));
363        #[cfg(not(target_os = "windows"))]
364        {
365            pg_config.keepalives_interval(std::time::Duration::from_secs(
366                keepalive.tcp_keepalive_interval as u64,
367            ));
368            pg_config.keepalives_retries(keepalive.tcp_keepalive_count);
369        }
370        tracing::info!(
371            "TCP keepalive enabled: idle={}s, interval={}s, retries={}",
372            keepalive.tcp_keepalive_idle,
373            keepalive.tcp_keepalive_interval,
374            keepalive.tcp_keepalive_count
375        );
376    }
377
378    let (_verify_ca, verify_hostname) = match ssl_mode {
379        SslMode::VerifyCa => (true, false),
380        SslMode::VerifyFull => (true, true),
381        _ => (false, false),
382    };
383
384    #[cfg(not(madsim))]
385    let connector = match ssl_mode {
386        SslMode::Disabled => {
387            pg_config.ssl_mode(tokio_postgres::config::SslMode::Disable);
388            MaybeMakeTlsConnector::NoTls(NoTls)
389        }
390        SslMode::Preferred => {
391            pg_config.ssl_mode(tokio_postgres::config::SslMode::Prefer);
392            match SslConnector::builder(SslMethod::tls()) {
393                Ok(mut builder) => {
394                    // disable certificate verification for `prefer`
395                    builder.set_verify(SslVerifyMode::NONE);
396                    MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
397                }
398                Err(e) => {
399                    tracing::warn!(error = %e.as_report(), "SSL connector error");
400                    MaybeMakeTlsConnector::NoTls(NoTls)
401                }
402            }
403        }
404        SslMode::Required => {
405            pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
406            let mut builder = SslConnector::builder(SslMethod::tls())?;
407            // disable certificate verification for `require`
408            builder.set_verify(SslVerifyMode::NONE);
409            MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
410        }
411
412        SslMode::VerifyCa | SslMode::VerifyFull => {
413            pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
414            let mut builder = SslConnector::builder(SslMethod::tls())?;
415            if let Some(ssl_root_cert) = ssl_root_cert {
416                builder.set_ca_file(ssl_root_cert).map_err(|e| {
417                    anyhow!(format!("bad ssl root cert error: {}", e.to_report_string()))
418                })?;
419            }
420            let mut connector = MakeTlsConnector::new(builder.build());
421            if !verify_hostname {
422                connector.set_callback(|config, _| {
423                    config.set_verify_hostname(false);
424                    Ok(())
425                });
426            }
427            MaybeMakeTlsConnector::Tls(connector)
428        }
429    };
430    #[cfg(madsim)]
431    let connector = NoTls;
432
433    let (client, connection) = pg_config.connect(connector).await?;
434
435    tokio::spawn(async move {
436        if let Err(e) = connection.await {
437            tracing::error!(error = %e.as_report(), "postgres connection error");
438        }
439    });
440
441    Ok(client)
442}
443
444// Used for both source and sink connector
445pub fn sea_type_to_rw_type(col_type: &SeaType) -> ConnectorResult<DataType> {
446    let dtype = match col_type {
447        SeaType::SmallInt | SeaType::SmallSerial => DataType::Int16,
448        SeaType::Integer | SeaType::Serial => DataType::Int32,
449        SeaType::BigInt | SeaType::BigSerial => DataType::Int64,
450        SeaType::Money | SeaType::Decimal(_) | SeaType::Numeric(_) => DataType::Decimal,
451        SeaType::Real => DataType::Float32,
452        SeaType::DoublePrecision => DataType::Float64,
453        SeaType::Varchar(_) | SeaType::Char(_) | SeaType::Text => DataType::Varchar,
454        SeaType::Bytea => DataType::Bytea,
455        SeaType::Timestamp(_) => DataType::Timestamp,
456        SeaType::TimestampWithTimeZone(_) => DataType::Timestamptz,
457        SeaType::Date => DataType::Date,
458        SeaType::Time(_) | SeaType::TimeWithTimeZone(_) => DataType::Time,
459        SeaType::Interval(_) => DataType::Interval,
460        SeaType::Boolean => DataType::Boolean,
461        SeaType::Point => DataType::Struct(StructType::new(vec![
462            ("x", DataType::Float32),
463            ("y", DataType::Float32),
464        ])),
465        SeaType::Uuid => DataType::Varchar,
466        SeaType::Xml => DataType::Varchar,
467        SeaType::Json => DataType::Jsonb,
468        SeaType::JsonBinary => DataType::Jsonb,
469        SeaType::Array(def) => {
470            let item_type = match def.col_type.as_ref() {
471                Some(ty) => sea_type_to_rw_type(ty.as_ref())?,
472                None => {
473                    return Err(anyhow!("ARRAY type missing element type").into());
474                }
475            };
476
477            DataType::list(item_type)
478        }
479        SeaType::PgLsn => DataType::Int64,
480        SeaType::Cidr
481        | SeaType::Inet
482        | SeaType::MacAddr
483        | SeaType::MacAddr8
484        | SeaType::Int4Range
485        | SeaType::Int8Range
486        | SeaType::NumRange
487        | SeaType::TsRange
488        | SeaType::TsTzRange
489        | SeaType::DateRange
490        | SeaType::Enum(_) => DataType::Varchar,
491        SeaType::Line
492        | SeaType::Lseg
493        | SeaType::Box
494        | SeaType::Path
495        | SeaType::Polygon
496        | SeaType::Circle
497        | SeaType::Bit(_)
498        | SeaType::VarBit(_)
499        | SeaType::TsVector
500        | SeaType::TsQuery => {
501            bail!("{:?} type not supported", col_type);
502        }
503        SeaType::Unknown(name) => {
504            // NOTES: user-defined enum type is classified as `Unknown`
505            tracing::warn!("Unknown Postgres data type: {name}, map to varchar");
506            DataType::Varchar
507        }
508    };
509
510    Ok(dtype)
511}
512
513// Used for sink connector
514// We use `sea-schema` for table schema discovery.
515// So we have to map `sea-schema` pg types
516// to `tokio-postgres` pg types (which we use for query binding).
517fn sea_type_to_pg_type(sea_type: &SeaType) -> ConnectorResult<tokio_postgres::types::Type> {
518    use tokio_postgres::types::Type as PgType;
519    match sea_type {
520        SeaType::SmallInt => Ok(PgType::INT2),
521        SeaType::Integer => Ok(PgType::INT4),
522        SeaType::BigInt => Ok(PgType::INT8),
523        SeaType::Decimal(_) => Ok(PgType::NUMERIC),
524        SeaType::Numeric(_) => Ok(PgType::NUMERIC),
525        SeaType::Real => Ok(PgType::FLOAT4),
526        SeaType::DoublePrecision => Ok(PgType::FLOAT8),
527        SeaType::Varchar(_) => Ok(PgType::VARCHAR),
528        SeaType::Char(_) => Ok(PgType::CHAR),
529        SeaType::Text => Ok(PgType::TEXT),
530        SeaType::Bytea => Ok(PgType::BYTEA),
531        SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP),
532        SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ),
533        SeaType::Date => Ok(PgType::DATE),
534        SeaType::Time(_) => Ok(PgType::TIME),
535        SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ),
536        SeaType::Interval(_) => Ok(PgType::INTERVAL),
537        SeaType::Boolean => Ok(PgType::BOOL),
538        SeaType::Point => Ok(PgType::POINT),
539        SeaType::Uuid => Ok(PgType::UUID),
540        SeaType::Json => Ok(PgType::JSON),
541        SeaType::JsonBinary => Ok(PgType::JSONB),
542        SeaType::Array(t) => {
543            let Some(t) = t.col_type.as_ref() else {
544                bail!("missing array type")
545            };
546            match t.as_ref() {
547                // RW only supports 1 level of nesting.
548                SeaType::SmallInt => Ok(PgType::INT2_ARRAY),
549                SeaType::Integer => Ok(PgType::INT4_ARRAY),
550                SeaType::BigInt => Ok(PgType::INT8_ARRAY),
551                SeaType::Decimal(_) => Ok(PgType::NUMERIC_ARRAY),
552                SeaType::Numeric(_) => Ok(PgType::NUMERIC_ARRAY),
553                SeaType::Real => Ok(PgType::FLOAT4_ARRAY),
554                SeaType::DoublePrecision => Ok(PgType::FLOAT8_ARRAY),
555                SeaType::Varchar(_) => Ok(PgType::VARCHAR_ARRAY),
556                SeaType::Char(_) => Ok(PgType::CHAR_ARRAY),
557                SeaType::Text => Ok(PgType::TEXT_ARRAY),
558                SeaType::Bytea => Ok(PgType::BYTEA_ARRAY),
559                SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP_ARRAY),
560                SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ_ARRAY),
561                SeaType::Date => Ok(PgType::DATE_ARRAY),
562                SeaType::Time(_) => Ok(PgType::TIME_ARRAY),
563                SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ_ARRAY),
564                SeaType::Interval(_) => Ok(PgType::INTERVAL_ARRAY),
565                SeaType::Boolean => Ok(PgType::BOOL_ARRAY),
566                SeaType::Point => Ok(PgType::POINT_ARRAY),
567                SeaType::Uuid => Ok(PgType::UUID_ARRAY),
568                SeaType::Json => Ok(PgType::JSON_ARRAY),
569                SeaType::JsonBinary => Ok(PgType::JSONB_ARRAY),
570                SeaType::Array(_) => bail!("nested array type is not supported"),
571                SeaType::Unknown(name) => {
572                    // Treat as enum type
573                    Ok(PgType::new(
574                        name.clone(),
575                        0,
576                        PgKind::Array(PgType::new(
577                            name.clone(),
578                            0,
579                            PgKind::Enum(vec![]),
580                            "".into(),
581                        )),
582                        "".into(),
583                    ))
584                }
585                _ => bail!("unsupported array type: {:?}", t),
586            }
587        }
588        SeaType::Unknown(name) => {
589            // Treat as enum type
590            Ok(PgType::new(
591                name.clone(),
592                0,
593                PgKind::Enum(vec![]),
594                "".into(),
595            ))
596        }
597        _ => bail!("unsupported type: {:?}", sea_type),
598    }
599}