risingwave_connector/connector_common/
postgres.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt;
17
18use anyhow::{Context, anyhow};
19use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
20use postgres_openssl::MakeTlsConnector;
21use risingwave_common::bail;
22use risingwave_common::catalog::{ColumnDesc, ColumnId};
23use risingwave_common::types::{DataType, ScalarImpl, StructType};
24use sea_schema::postgres::def::{ColumnType as SeaType, TableDef, TableInfo};
25use sea_schema::postgres::discovery::SchemaDiscovery;
26use sea_schema::sea_query::{Alias, IntoIden};
27use serde::Deserialize;
28use sqlx::postgres::{PgConnectOptions, PgSslMode};
29use sqlx::{PgPool, Row};
30use thiserror_ext::AsReport;
31use tokio_postgres::types::Kind as PgKind;
32use tokio_postgres::{Client as PgClient, NoTls};
33
34#[cfg(not(madsim))]
35use super::maybe_tls_connector::MaybeMakeTlsConnector;
36use crate::error::ConnectorResult;
37
/// SQL query to discover primary key columns directly from PostgreSQL system tables.
/// This bypasses querying `information_schema.table_constraints` to avoid permission issues.
///
/// Parameters: `$1` is the schema name and `$2` is the table name. Both are passed through
/// `quote_ident` before the concatenated name is cast to `regclass`. This ensures the names are
/// matched exactly as given, instead of being case-folded (e.g. `MyTable` -> `mytable`) or
/// mis-split on characters such as `.` by PostgreSQL's identifier parsing. Rows are ordered by
/// the column's position in the primary-key index.
const DISCOVER_PRIMARY_KEY_QUERY: &str = r#"
    SELECT a.attname as column_name
    FROM pg_index i
    JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
    WHERE i.indrelid = (quote_ident($1) || '.' || quote_ident($2))::regclass
      AND i.indisprimary = true
    ORDER BY array_position(i.indkey, a.attnum)
"#;
48
/// SSL mode used when connecting to an upstream PostgreSQL server.
///
/// Deserialized from the lowercase variant names (`disabled`, `preferred`, `required`,
/// `verifyca`, `verifyfull`). The spellings `disable`, `prefer`, `require`, `verify-ca` and
/// `verify-full` are accepted as aliases. Defaults to `Preferred`.
#[derive(Debug, Clone, PartialEq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum SslMode {
    /// Do not use TLS.
    #[serde(alias = "disable")]
    Disabled,
    /// Use TLS when possible; the server certificate is not verified.
    #[serde(alias = "prefer")]
    #[default]
    Preferred,
    /// Require TLS; the server certificate is not verified.
    #[serde(alias = "require")]
    Required,
    /// Verify that the server is trustworthy by checking the certificate chain
    /// up to the root certificate stored on the client.
    #[serde(alias = "verify-ca")]
    VerifyCa,
    /// Besides verifying the certificate, also verify that the server host name
    /// matches the name stored in the server certificate.
    #[serde(alias = "verify-full")]
    VerifyFull,
}
68
/// Schema of an upstream PostgreSQL table as seen by RisingWave.
///
/// Built by `PostgresExternalTable::connect`.
pub struct PostgresExternalTable {
    /// RisingWave column descriptors converted from the upstream table's columns.
    column_descs: Vec<ColumnDesc>,
    /// Primary-key column names, ordered by their position in the primary-key index.
    /// Empty when the upstream table has no primary key (only allowed for append-only tables).
    pk_names: Vec<String>,
}
73
74impl PostgresExternalTable {
75    /// Discover primary key columns directly from PostgreSQL system tables.
76    /// This bypasses querying `information_schema.table_constraints` to avoid requiring table owner permissions.
77    async fn discover_primary_key(
78        connection: &PgPool,
79        schema_name: &str,
80        table_name: &str,
81    ) -> ConnectorResult<Vec<String>> {
82        let rows = sqlx::query(DISCOVER_PRIMARY_KEY_QUERY)
83            .bind(schema_name)
84            .bind(table_name)
85            .fetch_all(connection)
86            .await
87            .context("Failed to discover primary key columns")?;
88
89        let pk_columns = rows
90            .into_iter()
91            .map(|row| row.get::<String, _>("column_name"))
92            .collect();
93
94        Ok(pk_columns)
95    }
96
97    /// Discover schema with workaround for primary key discovery
98    /// This method uses direct PostgreSQL system table queries for primary keys
99    /// to avoid permission issues when querying `information_schema.table_constraints`
100    async fn discover_pk_and_full_columns(
101        username: &str,
102        password: &str,
103        host: &str,
104        port: u16,
105        database: &str,
106        schema: &str,
107        table: &str,
108        ssl_mode: &SslMode,
109        ssl_root_cert: &Option<String>,
110    ) -> ConnectorResult<(Vec<sea_schema::postgres::def::ColumnInfo>, Vec<String>)> {
111        let mut options = PgConnectOptions::new()
112            .username(username)
113            .password(password)
114            .host(host)
115            .port(port)
116            .database(database)
117            .ssl_mode(match ssl_mode {
118                SslMode::Disabled => PgSslMode::Disable,
119                SslMode::Preferred => PgSslMode::Prefer,
120                SslMode::Required => PgSslMode::Require,
121                SslMode::VerifyCa => PgSslMode::VerifyCa,
122                SslMode::VerifyFull => PgSslMode::VerifyFull,
123            });
124
125        if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
126            && let Some(root_cert) = ssl_root_cert
127        {
128            options = options.ssl_root_cert(root_cert.as_str());
129        }
130
131        let connection = PgPool::connect_with(options).await?;
132
133        // Use sea-schema only for column discovery (no permission issues)
134        let schema_discovery = SchemaDiscovery::new(connection.clone(), schema);
135        let empty_map: HashMap<String, Vec<String>> = HashMap::new();
136        let columns = schema_discovery
137            .discover_columns(
138                Alias::new(schema).into_iden(),
139                Alias::new(table).into_iden(),
140                &empty_map,
141            )
142            .await?;
143
144        // Use direct system table query for primary key discovery
145        let pk_columns = Self::discover_primary_key(&connection, schema, table).await?;
146
147        Ok((columns, pk_columns))
148    }
149
150    async fn discover_schema(
151        username: &str,
152        password: &str,
153        host: &str,
154        port: u16,
155        database: &str,
156        schema: &str,
157        table: &str,
158        ssl_mode: &SslMode,
159        ssl_root_cert: &Option<String>,
160    ) -> ConnectorResult<TableDef> {
161        let mut options = PgConnectOptions::new()
162            .username(username)
163            .password(password)
164            .host(host)
165            .port(port)
166            .database(database)
167            .ssl_mode(match ssl_mode {
168                SslMode::Disabled => PgSslMode::Disable,
169                SslMode::Preferred => PgSslMode::Prefer,
170                SslMode::Required => PgSslMode::Require,
171                SslMode::VerifyCa => PgSslMode::VerifyCa,
172                SslMode::VerifyFull => PgSslMode::VerifyFull,
173            });
174
175        if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
176            && let Some(root_cert) = ssl_root_cert
177        {
178            options = options.ssl_root_cert(root_cert.as_str());
179        }
180
181        let connection = PgPool::connect_with(options).await?;
182        let schema_discovery = SchemaDiscovery::new(connection, schema);
183        // fetch column schema and primary key
184        let empty_map = HashMap::new();
185        let table_schema = schema_discovery
186            .discover_table(
187                TableInfo {
188                    name: table.to_owned(),
189                    of_type: None,
190                },
191                &empty_map,
192            )
193            .await?;
194        Ok(table_schema)
195    }
196
197    pub async fn connect(
198        username: &str,
199        password: &str,
200        host: &str,
201        port: u16,
202        database: &str,
203        schema: &str,
204        table: &str,
205        ssl_mode: &SslMode,
206        ssl_root_cert: &Option<String>,
207        is_append_only: bool,
208    ) -> ConnectorResult<Self> {
209        tracing::debug!("connect to postgres external table");
210
211        let (columns, pk_names) = Self::discover_pk_and_full_columns(
212            username,
213            password,
214            host,
215            port,
216            database,
217            schema,
218            table,
219            ssl_mode,
220            ssl_root_cert,
221        )
222        .await?;
223
224        let mut column_descs = vec![];
225        for col in &columns {
226            let rw_data_type = sea_type_to_rw_type(&col.col_type)?;
227            let column_desc = if let Some(ref default_expr) = col.default {
228                // parse the value of "column_default" field in information_schema.columns,
229                // non number data type will be stored as "'value'::type"
230                let val_text = default_expr
231                    .0
232                    .split("::")
233                    .map(|s| s.trim_matches('\''))
234                    .next()
235                    .expect("default value expression");
236
237                match ScalarImpl::from_text(val_text, &rw_data_type) {
238                    Ok(scalar) => ColumnDesc::named_with_default_value(
239                        col.name.clone(),
240                        ColumnId::placeholder(),
241                        rw_data_type.clone(),
242                        Some(scalar),
243                    ),
244                    Err(err) => {
245                        tracing::warn!(error=%err.as_report(), "failed to parse postgres default value expression, only constant is supported");
246                        ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
247                    }
248                }
249            } else {
250                ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
251            };
252            column_descs.push(column_desc);
253        }
254
255        // Check primary key existence using the directly discovered pk_names
256        if !is_append_only && pk_names.is_empty() {
257            return Err(anyhow!(
258                "Postgres table should define the primary key for non-append-only tables"
259            )
260            .into());
261        }
262
263        Ok(Self {
264            column_descs,
265            pk_names,
266        })
267    }
268
269    // return the mapping from column name to pg type, the pg type is used for writing data to postgres
270    pub async fn type_mapping(
271        username: &str,
272        password: &str,
273        host: &str,
274        port: u16,
275        database: &str,
276        schema: &str,
277        table: &str,
278        ssl_mode: &SslMode,
279        ssl_root_cert: &Option<String>,
280        is_append_only: bool,
281    ) -> ConnectorResult<HashMap<String, tokio_postgres::types::Type>> {
282        tracing::debug!("connect to postgres external table to get type mapping");
283        let table_schema = Self::discover_schema(
284            username,
285            password,
286            host,
287            port,
288            database,
289            schema,
290            table,
291            ssl_mode,
292            ssl_root_cert,
293        )
294        .await?;
295        let mut column_name_to_pg_type = HashMap::new();
296        for col in &table_schema.columns {
297            let pg_type = sea_type_to_pg_type(&col.col_type)?;
298            column_name_to_pg_type.insert(col.name.clone(), pg_type);
299        }
300        if !is_append_only && table_schema.primary_key_constraints.is_empty() {
301            return Err(anyhow!(
302                "Postgres table should define the primary key for non-append-only tables"
303            )
304            .into());
305        }
306        Ok(column_name_to_pg_type)
307    }
308
309    pub fn column_descs(&self) -> &Vec<ColumnDesc> {
310        &self.column_descs
311    }
312
313    pub fn pk_names(&self) -> &Vec<String> {
314        &self.pk_names
315    }
316}
317
318impl fmt::Display for SslMode {
319    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
320        f.write_str(match self {
321            SslMode::Disabled => "disabled",
322            SslMode::Preferred => "preferred",
323            SslMode::Required => "required",
324            SslMode::VerifyCa => "verify-ca",
325            SslMode::VerifyFull => "verify-full",
326        })
327    }
328}
329
330pub async fn create_pg_client(
331    user: &str,
332    password: &str,
333    host: &str,
334    port: &str,
335    database: &str,
336    ssl_mode: &SslMode,
337    ssl_root_cert: &Option<String>,
338) -> anyhow::Result<PgClient> {
339    let mut pg_config = tokio_postgres::Config::new();
340    pg_config
341        .user(user)
342        .password(password)
343        .host(host)
344        .port(port.parse::<u16>().unwrap())
345        .dbname(database);
346
347    let (_verify_ca, verify_hostname) = match ssl_mode {
348        SslMode::VerifyCa => (true, false),
349        SslMode::VerifyFull => (true, true),
350        _ => (false, false),
351    };
352
353    #[cfg(not(madsim))]
354    let connector = match ssl_mode {
355        SslMode::Disabled => {
356            pg_config.ssl_mode(tokio_postgres::config::SslMode::Disable);
357            MaybeMakeTlsConnector::NoTls(NoTls)
358        }
359        SslMode::Preferred => {
360            pg_config.ssl_mode(tokio_postgres::config::SslMode::Prefer);
361            match SslConnector::builder(SslMethod::tls()) {
362                Ok(mut builder) => {
363                    // disable certificate verification for `prefer`
364                    builder.set_verify(SslVerifyMode::NONE);
365                    MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
366                }
367                Err(e) => {
368                    tracing::warn!(error = %e.as_report(), "SSL connector error");
369                    MaybeMakeTlsConnector::NoTls(NoTls)
370                }
371            }
372        }
373        SslMode::Required => {
374            pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
375            let mut builder = SslConnector::builder(SslMethod::tls())?;
376            // disable certificate verification for `require`
377            builder.set_verify(SslVerifyMode::NONE);
378            MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
379        }
380
381        SslMode::VerifyCa | SslMode::VerifyFull => {
382            pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
383            let mut builder = SslConnector::builder(SslMethod::tls())?;
384            if let Some(ssl_root_cert) = ssl_root_cert {
385                builder.set_ca_file(ssl_root_cert).map_err(|e| {
386                    anyhow!(format!("bad ssl root cert error: {}", e.to_report_string()))
387                })?;
388            }
389            let mut connector = MakeTlsConnector::new(builder.build());
390            if !verify_hostname {
391                connector.set_callback(|config, _| {
392                    config.set_verify_hostname(false);
393                    Ok(())
394                });
395            }
396            MaybeMakeTlsConnector::Tls(connector)
397        }
398    };
399    #[cfg(madsim)]
400    let connector = NoTls;
401
402    let (client, connection) = pg_config.connect(connector).await?;
403
404    tokio::spawn(async move {
405        if let Err(e) = connection.await {
406            tracing::error!(error = %e.as_report(), "postgres connection error");
407        }
408    });
409
410    Ok(client)
411}
412
413// Used for both source and sink connector
414pub fn sea_type_to_rw_type(col_type: &SeaType) -> ConnectorResult<DataType> {
415    let dtype = match col_type {
416        SeaType::SmallInt | SeaType::SmallSerial => DataType::Int16,
417        SeaType::Integer | SeaType::Serial => DataType::Int32,
418        SeaType::BigInt | SeaType::BigSerial => DataType::Int64,
419        SeaType::Money | SeaType::Decimal(_) | SeaType::Numeric(_) => DataType::Decimal,
420        SeaType::Real => DataType::Float32,
421        SeaType::DoublePrecision => DataType::Float64,
422        SeaType::Varchar(_) | SeaType::Char(_) | SeaType::Text => DataType::Varchar,
423        SeaType::Bytea => DataType::Bytea,
424        SeaType::Timestamp(_) => DataType::Timestamp,
425        SeaType::TimestampWithTimeZone(_) => DataType::Timestamptz,
426        SeaType::Date => DataType::Date,
427        SeaType::Time(_) | SeaType::TimeWithTimeZone(_) => DataType::Time,
428        SeaType::Interval(_) => DataType::Interval,
429        SeaType::Boolean => DataType::Boolean,
430        SeaType::Point => DataType::Struct(StructType::new(vec![
431            ("x", DataType::Float32),
432            ("y", DataType::Float32),
433        ])),
434        SeaType::Uuid => DataType::Varchar,
435        SeaType::Xml => DataType::Varchar,
436        SeaType::Json => DataType::Jsonb,
437        SeaType::JsonBinary => DataType::Jsonb,
438        SeaType::Array(def) => {
439            let item_type = match def.col_type.as_ref() {
440                Some(ty) => sea_type_to_rw_type(ty.as_ref())?,
441                None => {
442                    return Err(anyhow!("ARRAY type missing element type").into());
443                }
444            };
445
446            DataType::list(item_type)
447        }
448        SeaType::PgLsn => DataType::Int64,
449        SeaType::Cidr
450        | SeaType::Inet
451        | SeaType::MacAddr
452        | SeaType::MacAddr8
453        | SeaType::Int4Range
454        | SeaType::Int8Range
455        | SeaType::NumRange
456        | SeaType::TsRange
457        | SeaType::TsTzRange
458        | SeaType::DateRange
459        | SeaType::Enum(_) => DataType::Varchar,
460        SeaType::Line
461        | SeaType::Lseg
462        | SeaType::Box
463        | SeaType::Path
464        | SeaType::Polygon
465        | SeaType::Circle
466        | SeaType::Bit(_)
467        | SeaType::VarBit(_)
468        | SeaType::TsVector
469        | SeaType::TsQuery => {
470            bail!("{:?} type not supported", col_type);
471        }
472        SeaType::Unknown(name) => {
473            // NOTES: user-defined enum type is classified as `Unknown`
474            tracing::warn!("Unknown Postgres data type: {name}, map to varchar");
475            DataType::Varchar
476        }
477    };
478
479    Ok(dtype)
480}
481
482// Used for sink connector
483// We use `sea-schema` for table schema discovery.
484// So we have to map `sea-schema` pg types
485// to `tokio-postgres` pg types (which we use for query binding).
486fn sea_type_to_pg_type(sea_type: &SeaType) -> ConnectorResult<tokio_postgres::types::Type> {
487    use tokio_postgres::types::Type as PgType;
488    match sea_type {
489        SeaType::SmallInt => Ok(PgType::INT2),
490        SeaType::Integer => Ok(PgType::INT4),
491        SeaType::BigInt => Ok(PgType::INT8),
492        SeaType::Decimal(_) => Ok(PgType::NUMERIC),
493        SeaType::Numeric(_) => Ok(PgType::NUMERIC),
494        SeaType::Real => Ok(PgType::FLOAT4),
495        SeaType::DoublePrecision => Ok(PgType::FLOAT8),
496        SeaType::Varchar(_) => Ok(PgType::VARCHAR),
497        SeaType::Char(_) => Ok(PgType::CHAR),
498        SeaType::Text => Ok(PgType::TEXT),
499        SeaType::Bytea => Ok(PgType::BYTEA),
500        SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP),
501        SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ),
502        SeaType::Date => Ok(PgType::DATE),
503        SeaType::Time(_) => Ok(PgType::TIME),
504        SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ),
505        SeaType::Interval(_) => Ok(PgType::INTERVAL),
506        SeaType::Boolean => Ok(PgType::BOOL),
507        SeaType::Point => Ok(PgType::POINT),
508        SeaType::Uuid => Ok(PgType::UUID),
509        SeaType::Json => Ok(PgType::JSON),
510        SeaType::JsonBinary => Ok(PgType::JSONB),
511        SeaType::Array(t) => {
512            let Some(t) = t.col_type.as_ref() else {
513                bail!("missing array type")
514            };
515            match t.as_ref() {
516                // RW only supports 1 level of nesting.
517                SeaType::SmallInt => Ok(PgType::INT2_ARRAY),
518                SeaType::Integer => Ok(PgType::INT4_ARRAY),
519                SeaType::BigInt => Ok(PgType::INT8_ARRAY),
520                SeaType::Decimal(_) => Ok(PgType::NUMERIC_ARRAY),
521                SeaType::Numeric(_) => Ok(PgType::NUMERIC_ARRAY),
522                SeaType::Real => Ok(PgType::FLOAT4_ARRAY),
523                SeaType::DoublePrecision => Ok(PgType::FLOAT8_ARRAY),
524                SeaType::Varchar(_) => Ok(PgType::VARCHAR_ARRAY),
525                SeaType::Char(_) => Ok(PgType::CHAR_ARRAY),
526                SeaType::Text => Ok(PgType::TEXT_ARRAY),
527                SeaType::Bytea => Ok(PgType::BYTEA_ARRAY),
528                SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP_ARRAY),
529                SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ_ARRAY),
530                SeaType::Date => Ok(PgType::DATE_ARRAY),
531                SeaType::Time(_) => Ok(PgType::TIME_ARRAY),
532                SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ_ARRAY),
533                SeaType::Interval(_) => Ok(PgType::INTERVAL_ARRAY),
534                SeaType::Boolean => Ok(PgType::BOOL_ARRAY),
535                SeaType::Point => Ok(PgType::POINT_ARRAY),
536                SeaType::Uuid => Ok(PgType::UUID_ARRAY),
537                SeaType::Json => Ok(PgType::JSON_ARRAY),
538                SeaType::JsonBinary => Ok(PgType::JSONB_ARRAY),
539                SeaType::Array(_) => bail!("nested array type is not supported"),
540                SeaType::Unknown(name) => {
541                    // Treat as enum type
542                    Ok(PgType::new(
543                        name.clone(),
544                        0,
545                        PgKind::Array(PgType::new(
546                            name.clone(),
547                            0,
548                            PgKind::Enum(vec![]),
549                            "".into(),
550                        )),
551                        "".into(),
552                    ))
553                }
554                _ => bail!("unsupported array type: {:?}", t),
555            }
556        }
557        SeaType::Unknown(name) => {
558            // Treat as enum type
559            Ok(PgType::new(
560                name.clone(),
561                0,
562                PgKind::Enum(vec![]),
563                "".into(),
564            ))
565        }
566        _ => bail!("unsupported type: {:?}", sea_type),
567    }
568}