risingwave_connector/connector_common/
postgres.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt;
17
18use anyhow::{Context, anyhow};
19use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
20use postgres_openssl::MakeTlsConnector;
21use risingwave_common::bail;
22use risingwave_common::catalog::{ColumnDesc, ColumnId};
23use risingwave_common::types::{DataType, ScalarImpl, StructType};
24use sea_schema::postgres::def::{ColumnType as SeaType, TableDef, TableInfo};
25use sea_schema::postgres::discovery::SchemaDiscovery;
26use sea_schema::sea_query::{Alias, IntoIden};
27use serde::Deserialize;
28use sqlx::postgres::{PgConnectOptions, PgSslMode};
29use sqlx::{PgPool, Row};
30use thiserror_ext::AsReport;
31use tokio_postgres::types::Kind as PgKind;
32use tokio_postgres::{Client as PgClient, NoTls};
33
34#[cfg(not(madsim))]
35use super::maybe_tls_connector::MaybeMakeTlsConnector;
36use crate::error::ConnectorResult;
37
/// SQL query that lists the primary key columns of a table, in key order, straight from
/// the PostgreSQL system catalogs (`pg_index` joined with `pg_attribute`).
///
/// This bypasses querying `information_schema.table_constraints` to avoid permission issues.
///
/// Bind parameters: `$1` is the schema name and `$2` is the table name, both given as
/// plain (unquoted) names. Each name goes through `quote_ident` before the `regclass`
/// cast, because the cast parses its input as an SQL identifier: without quoting,
/// mixed-case names would be folded to lowercase and names containing dots or other
/// special characters would be mis-parsed.
const DISCOVER_PRIMARY_KEY_QUERY: &str = r#"
    SELECT a.attname as column_name
    FROM pg_index i
    JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
    WHERE i.indrelid = (quote_ident($1) || '.' || quote_ident($2))::regclass
      AND i.indisprimary = true
    ORDER BY array_position(i.indkey, a.attnum)
"#;
48
/// TLS policy used when connecting to PostgreSQL.
///
/// Deserializes from the lowercased variant name (`rename_all = "lowercase"`), plus the
/// alias listed on each variant.
#[derive(Debug, Clone, PartialEq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SslMode {
    /// Connect without TLS. Also accepted as `disable`.
    #[serde(alias = "disable")]
    Disabled,
    /// Mapped to the driver's `prefer` mode; `create_pg_client` does not verify the
    /// server certificate in this mode. Also accepted as `prefer`.
    #[serde(alias = "prefer")]
    Preferred,
    /// Mapped to the driver's `require` mode; `create_pg_client` does not verify the
    /// server certificate in this mode. Also accepted as `require`.
    #[serde(alias = "require")]
    Required,
    /// Verify that the server is trustworthy by checking the certificate chain
    /// up to the root certificate stored on the client.
    #[serde(alias = "verify-ca")]
    VerifyCa,
    /// Besides verifying the certificate, also verify that the server host name
    /// matches the name stored in the server certificate.
    #[serde(alias = "verify-full")]
    VerifyFull,
}
67
68impl Default for SslMode {
69    fn default() -> Self {
70        Self::Preferred
71    }
72}
73
/// Column and primary key metadata of a PostgreSQL table, as discovered by
/// [`PostgresExternalTable::connect`].
pub struct PostgresExternalTable {
    /// One descriptor per discovered column, in discovery order.
    column_descs: Vec<ColumnDesc>,
    /// Names of the primary key columns, in the order returned by the primary key query.
    pk_names: Vec<String>,
}
78
79impl PostgresExternalTable {
80    /// Discover primary key columns directly from PostgreSQL system tables.
81    /// This bypasses querying `information_schema.table_constraints` to avoid requiring table owner permissions.
82    async fn discover_primary_key(
83        connection: &PgPool,
84        schema_name: &str,
85        table_name: &str,
86    ) -> ConnectorResult<Vec<String>> {
87        let rows = sqlx::query(DISCOVER_PRIMARY_KEY_QUERY)
88            .bind(schema_name)
89            .bind(table_name)
90            .fetch_all(connection)
91            .await
92            .context("Failed to discover primary key columns")?;
93
94        let pk_columns = rows
95            .into_iter()
96            .map(|row| row.get::<String, _>("column_name"))
97            .collect();
98
99        Ok(pk_columns)
100    }
101
102    /// Discover schema with workaround for primary key discovery
103    /// This method uses direct PostgreSQL system table queries for primary keys
104    /// to avoid permission issues when querying `information_schema.table_constraints`
105    async fn discover_pk_and_full_columns(
106        username: &str,
107        password: &str,
108        host: &str,
109        port: u16,
110        database: &str,
111        schema: &str,
112        table: &str,
113        ssl_mode: &SslMode,
114        ssl_root_cert: &Option<String>,
115    ) -> ConnectorResult<(Vec<sea_schema::postgres::def::ColumnInfo>, Vec<String>)> {
116        let mut options = PgConnectOptions::new()
117            .username(username)
118            .password(password)
119            .host(host)
120            .port(port)
121            .database(database)
122            .ssl_mode(match ssl_mode {
123                SslMode::Disabled => PgSslMode::Disable,
124                SslMode::Preferred => PgSslMode::Prefer,
125                SslMode::Required => PgSslMode::Require,
126                SslMode::VerifyCa => PgSslMode::VerifyCa,
127                SslMode::VerifyFull => PgSslMode::VerifyFull,
128            });
129
130        if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
131            && let Some(root_cert) = ssl_root_cert
132        {
133            options = options.ssl_root_cert(root_cert.as_str());
134        }
135
136        let connection = PgPool::connect_with(options).await?;
137
138        // Use sea-schema only for column discovery (no permission issues)
139        let schema_discovery = SchemaDiscovery::new(connection.clone(), schema);
140        let empty_map: HashMap<String, Vec<String>> = HashMap::new();
141        let columns = schema_discovery
142            .discover_columns(
143                Alias::new(schema).into_iden(),
144                Alias::new(table).into_iden(),
145                &empty_map,
146            )
147            .await?;
148
149        // Use direct system table query for primary key discovery
150        let pk_columns = Self::discover_primary_key(&connection, schema, table).await?;
151
152        Ok((columns, pk_columns))
153    }
154
155    async fn discover_schema(
156        username: &str,
157        password: &str,
158        host: &str,
159        port: u16,
160        database: &str,
161        schema: &str,
162        table: &str,
163        ssl_mode: &SslMode,
164        ssl_root_cert: &Option<String>,
165    ) -> ConnectorResult<TableDef> {
166        let mut options = PgConnectOptions::new()
167            .username(username)
168            .password(password)
169            .host(host)
170            .port(port)
171            .database(database)
172            .ssl_mode(match ssl_mode {
173                SslMode::Disabled => PgSslMode::Disable,
174                SslMode::Preferred => PgSslMode::Prefer,
175                SslMode::Required => PgSslMode::Require,
176                SslMode::VerifyCa => PgSslMode::VerifyCa,
177                SslMode::VerifyFull => PgSslMode::VerifyFull,
178            });
179
180        if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
181            && let Some(root_cert) = ssl_root_cert
182        {
183            options = options.ssl_root_cert(root_cert.as_str());
184        }
185
186        let connection = PgPool::connect_with(options).await?;
187        let schema_discovery = SchemaDiscovery::new(connection, schema);
188        // fetch column schema and primary key
189        let empty_map = HashMap::new();
190        let table_schema = schema_discovery
191            .discover_table(
192                TableInfo {
193                    name: table.to_owned(),
194                    of_type: None,
195                },
196                &empty_map,
197            )
198            .await?;
199        Ok(table_schema)
200    }
201
202    pub async fn connect(
203        username: &str,
204        password: &str,
205        host: &str,
206        port: u16,
207        database: &str,
208        schema: &str,
209        table: &str,
210        ssl_mode: &SslMode,
211        ssl_root_cert: &Option<String>,
212        is_append_only: bool,
213    ) -> ConnectorResult<Self> {
214        tracing::debug!("connect to postgres external table");
215
216        let (columns, pk_names) = Self::discover_pk_and_full_columns(
217            username,
218            password,
219            host,
220            port,
221            database,
222            schema,
223            table,
224            ssl_mode,
225            ssl_root_cert,
226        )
227        .await?;
228
229        let mut column_descs = vec![];
230        for col in &columns {
231            let rw_data_type = sea_type_to_rw_type(&col.col_type)?;
232            let column_desc = if let Some(ref default_expr) = col.default {
233                // parse the value of "column_default" field in information_schema.columns,
234                // non number data type will be stored as "'value'::type"
235                let val_text = default_expr
236                    .0
237                    .split("::")
238                    .map(|s| s.trim_matches('\''))
239                    .next()
240                    .expect("default value expression");
241
242                match ScalarImpl::from_text(val_text, &rw_data_type) {
243                    Ok(scalar) => ColumnDesc::named_with_default_value(
244                        col.name.clone(),
245                        ColumnId::placeholder(),
246                        rw_data_type.clone(),
247                        Some(scalar),
248                    ),
249                    Err(err) => {
250                        tracing::warn!(error=%err.as_report(), "failed to parse postgres default value expression, only constant is supported");
251                        ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
252                    }
253                }
254            } else {
255                ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
256            };
257            column_descs.push(column_desc);
258        }
259
260        // Check primary key existence using the directly discovered pk_names
261        if !is_append_only && pk_names.is_empty() {
262            return Err(anyhow!(
263                "Postgres table should define the primary key for non-append-only tables"
264            )
265            .into());
266        }
267
268        Ok(Self {
269            column_descs,
270            pk_names,
271        })
272    }
273
274    // return the mapping from column name to pg type, the pg type is used for writing data to postgres
275    pub async fn type_mapping(
276        username: &str,
277        password: &str,
278        host: &str,
279        port: u16,
280        database: &str,
281        schema: &str,
282        table: &str,
283        ssl_mode: &SslMode,
284        ssl_root_cert: &Option<String>,
285        is_append_only: bool,
286    ) -> ConnectorResult<HashMap<String, tokio_postgres::types::Type>> {
287        tracing::debug!("connect to postgres external table to get type mapping");
288        let table_schema = Self::discover_schema(
289            username,
290            password,
291            host,
292            port,
293            database,
294            schema,
295            table,
296            ssl_mode,
297            ssl_root_cert,
298        )
299        .await?;
300        let mut column_name_to_pg_type = HashMap::new();
301        for col in &table_schema.columns {
302            let pg_type = sea_type_to_pg_type(&col.col_type)?;
303            column_name_to_pg_type.insert(col.name.clone(), pg_type);
304        }
305        if !is_append_only && table_schema.primary_key_constraints.is_empty() {
306            return Err(anyhow!(
307                "Postgres table should define the primary key for non-append-only tables"
308            )
309            .into());
310        }
311        Ok(column_name_to_pg_type)
312    }
313
    /// Returns the column descriptors stored on this table.
    pub fn column_descs(&self) -> &Vec<ColumnDesc> {
        &self.column_descs
    }
317
    /// Returns the primary key column names stored on this table.
    pub fn pk_names(&self) -> &Vec<String> {
        &self.pk_names
    }
321}
322
323impl fmt::Display for SslMode {
324    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
325        f.write_str(match self {
326            SslMode::Disabled => "disabled",
327            SslMode::Preferred => "preferred",
328            SslMode::Required => "required",
329            SslMode::VerifyCa => "verify-ca",
330            SslMode::VerifyFull => "verify-full",
331        })
332    }
333}
334
335pub async fn create_pg_client(
336    user: &str,
337    password: &str,
338    host: &str,
339    port: &str,
340    database: &str,
341    ssl_mode: &SslMode,
342    ssl_root_cert: &Option<String>,
343) -> anyhow::Result<PgClient> {
344    let mut pg_config = tokio_postgres::Config::new();
345    pg_config
346        .user(user)
347        .password(password)
348        .host(host)
349        .port(port.parse::<u16>().unwrap())
350        .dbname(database);
351
352    let (_verify_ca, verify_hostname) = match ssl_mode {
353        SslMode::VerifyCa => (true, false),
354        SslMode::VerifyFull => (true, true),
355        _ => (false, false),
356    };
357
358    #[cfg(not(madsim))]
359    let connector = match ssl_mode {
360        SslMode::Disabled => {
361            pg_config.ssl_mode(tokio_postgres::config::SslMode::Disable);
362            MaybeMakeTlsConnector::NoTls(NoTls)
363        }
364        SslMode::Preferred => {
365            pg_config.ssl_mode(tokio_postgres::config::SslMode::Prefer);
366            match SslConnector::builder(SslMethod::tls()) {
367                Ok(mut builder) => {
368                    // disable certificate verification for `prefer`
369                    builder.set_verify(SslVerifyMode::NONE);
370                    MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
371                }
372                Err(e) => {
373                    tracing::warn!(error = %e.as_report(), "SSL connector error");
374                    MaybeMakeTlsConnector::NoTls(NoTls)
375                }
376            }
377        }
378        SslMode::Required => {
379            pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
380            let mut builder = SslConnector::builder(SslMethod::tls())?;
381            // disable certificate verification for `require`
382            builder.set_verify(SslVerifyMode::NONE);
383            MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
384        }
385
386        SslMode::VerifyCa | SslMode::VerifyFull => {
387            pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
388            let mut builder = SslConnector::builder(SslMethod::tls())?;
389            if let Some(ssl_root_cert) = ssl_root_cert {
390                builder.set_ca_file(ssl_root_cert).map_err(|e| {
391                    anyhow!(format!("bad ssl root cert error: {}", e.to_report_string()))
392                })?;
393            }
394            let mut connector = MakeTlsConnector::new(builder.build());
395            if !verify_hostname {
396                connector.set_callback(|config, _| {
397                    config.set_verify_hostname(false);
398                    Ok(())
399                });
400            }
401            MaybeMakeTlsConnector::Tls(connector)
402        }
403    };
404    #[cfg(madsim)]
405    let connector = NoTls;
406
407    let (client, connection) = pg_config.connect(connector).await?;
408
409    tokio::spawn(async move {
410        if let Err(e) = connection.await {
411            tracing::error!(error = %e.as_report(), "postgres connection error");
412        }
413    });
414
415    Ok(client)
416}
417
418// Used for both source and sink connector
419pub fn sea_type_to_rw_type(col_type: &SeaType) -> ConnectorResult<DataType> {
420    let dtype = match col_type {
421        SeaType::SmallInt | SeaType::SmallSerial => DataType::Int16,
422        SeaType::Integer | SeaType::Serial => DataType::Int32,
423        SeaType::BigInt | SeaType::BigSerial => DataType::Int64,
424        SeaType::Money | SeaType::Decimal(_) | SeaType::Numeric(_) => DataType::Decimal,
425        SeaType::Real => DataType::Float32,
426        SeaType::DoublePrecision => DataType::Float64,
427        SeaType::Varchar(_) | SeaType::Char(_) | SeaType::Text => DataType::Varchar,
428        SeaType::Bytea => DataType::Bytea,
429        SeaType::Timestamp(_) => DataType::Timestamp,
430        SeaType::TimestampWithTimeZone(_) => DataType::Timestamptz,
431        SeaType::Date => DataType::Date,
432        SeaType::Time(_) | SeaType::TimeWithTimeZone(_) => DataType::Time,
433        SeaType::Interval(_) => DataType::Interval,
434        SeaType::Boolean => DataType::Boolean,
435        SeaType::Point => DataType::Struct(StructType::new(vec![
436            ("x", DataType::Float32),
437            ("y", DataType::Float32),
438        ])),
439        SeaType::Uuid => DataType::Varchar,
440        SeaType::Xml => DataType::Varchar,
441        SeaType::Json => DataType::Jsonb,
442        SeaType::JsonBinary => DataType::Jsonb,
443        SeaType::Array(def) => {
444            let item_type = match def.col_type.as_ref() {
445                Some(ty) => sea_type_to_rw_type(ty.as_ref())?,
446                None => {
447                    return Err(anyhow!("ARRAY type missing element type").into());
448                }
449            };
450
451            DataType::List(Box::new(item_type))
452        }
453        SeaType::PgLsn => DataType::Int64,
454        SeaType::Cidr
455        | SeaType::Inet
456        | SeaType::MacAddr
457        | SeaType::MacAddr8
458        | SeaType::Int4Range
459        | SeaType::Int8Range
460        | SeaType::NumRange
461        | SeaType::TsRange
462        | SeaType::TsTzRange
463        | SeaType::DateRange
464        | SeaType::Enum(_) => DataType::Varchar,
465        SeaType::Line
466        | SeaType::Lseg
467        | SeaType::Box
468        | SeaType::Path
469        | SeaType::Polygon
470        | SeaType::Circle
471        | SeaType::Bit(_)
472        | SeaType::VarBit(_)
473        | SeaType::TsVector
474        | SeaType::TsQuery => {
475            bail!("{:?} type not supported", col_type);
476        }
477        SeaType::Unknown(name) => {
478            // NOTES: user-defined enum type is classified as `Unknown`
479            tracing::warn!("Unknown Postgres data type: {name}, map to varchar");
480            DataType::Varchar
481        }
482    };
483
484    Ok(dtype)
485}
486
487// Used for sink connector
488// We use `sea-schema` for table schema discovery.
489// So we have to map `sea-schema` pg types
490// to `tokio-postgres` pg types (which we use for query binding).
491fn sea_type_to_pg_type(sea_type: &SeaType) -> ConnectorResult<tokio_postgres::types::Type> {
492    use tokio_postgres::types::Type as PgType;
493    match sea_type {
494        SeaType::SmallInt => Ok(PgType::INT2),
495        SeaType::Integer => Ok(PgType::INT4),
496        SeaType::BigInt => Ok(PgType::INT8),
497        SeaType::Decimal(_) => Ok(PgType::NUMERIC),
498        SeaType::Numeric(_) => Ok(PgType::NUMERIC),
499        SeaType::Real => Ok(PgType::FLOAT4),
500        SeaType::DoublePrecision => Ok(PgType::FLOAT8),
501        SeaType::Varchar(_) => Ok(PgType::VARCHAR),
502        SeaType::Char(_) => Ok(PgType::CHAR),
503        SeaType::Text => Ok(PgType::TEXT),
504        SeaType::Bytea => Ok(PgType::BYTEA),
505        SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP),
506        SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ),
507        SeaType::Date => Ok(PgType::DATE),
508        SeaType::Time(_) => Ok(PgType::TIME),
509        SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ),
510        SeaType::Interval(_) => Ok(PgType::INTERVAL),
511        SeaType::Boolean => Ok(PgType::BOOL),
512        SeaType::Point => Ok(PgType::POINT),
513        SeaType::Uuid => Ok(PgType::UUID),
514        SeaType::Json => Ok(PgType::JSON),
515        SeaType::JsonBinary => Ok(PgType::JSONB),
516        SeaType::Array(t) => {
517            let Some(t) = t.col_type.as_ref() else {
518                bail!("missing array type")
519            };
520            match t.as_ref() {
521                // RW only supports 1 level of nesting.
522                SeaType::SmallInt => Ok(PgType::INT2_ARRAY),
523                SeaType::Integer => Ok(PgType::INT4_ARRAY),
524                SeaType::BigInt => Ok(PgType::INT8_ARRAY),
525                SeaType::Decimal(_) => Ok(PgType::NUMERIC_ARRAY),
526                SeaType::Numeric(_) => Ok(PgType::NUMERIC_ARRAY),
527                SeaType::Real => Ok(PgType::FLOAT4_ARRAY),
528                SeaType::DoublePrecision => Ok(PgType::FLOAT8_ARRAY),
529                SeaType::Varchar(_) => Ok(PgType::VARCHAR_ARRAY),
530                SeaType::Char(_) => Ok(PgType::CHAR_ARRAY),
531                SeaType::Text => Ok(PgType::TEXT_ARRAY),
532                SeaType::Bytea => Ok(PgType::BYTEA_ARRAY),
533                SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP_ARRAY),
534                SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ_ARRAY),
535                SeaType::Date => Ok(PgType::DATE_ARRAY),
536                SeaType::Time(_) => Ok(PgType::TIME_ARRAY),
537                SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ_ARRAY),
538                SeaType::Interval(_) => Ok(PgType::INTERVAL_ARRAY),
539                SeaType::Boolean => Ok(PgType::BOOL_ARRAY),
540                SeaType::Point => Ok(PgType::POINT_ARRAY),
541                SeaType::Uuid => Ok(PgType::UUID_ARRAY),
542                SeaType::Json => Ok(PgType::JSON_ARRAY),
543                SeaType::JsonBinary => Ok(PgType::JSONB_ARRAY),
544                SeaType::Array(_) => bail!("nested array type is not supported"),
545                SeaType::Unknown(name) => {
546                    // Treat as enum type
547                    Ok(PgType::new(
548                        name.clone(),
549                        0,
550                        PgKind::Array(PgType::new(
551                            name.clone(),
552                            0,
553                            PgKind::Enum(vec![]),
554                            "".into(),
555                        )),
556                        "".into(),
557                    ))
558                }
559                _ => bail!("unsupported array type: {:?}", t),
560            }
561        }
562        SeaType::Unknown(name) => {
563            // Treat as enum type
564            Ok(PgType::new(
565                name.clone(),
566                0,
567                PgKind::Enum(vec![]),
568                "".into(),
569            ))
570        }
571        _ => bail!("unsupported type: {:?}", sea_type),
572    }
573}