risingwave_connector/connector_common/
postgres.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt;
17
18use anyhow::anyhow;
19use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
20use postgres_openssl::MakeTlsConnector;
21use risingwave_common::bail;
22use risingwave_common::catalog::{ColumnDesc, ColumnId};
23use risingwave_common::types::{DataType, ScalarImpl, StructType};
24use sea_schema::postgres::def::{ColumnType as SeaType, TableDef, TableInfo};
25use sea_schema::postgres::discovery::SchemaDiscovery;
26use serde_derive::Deserialize;
27use sqlx::PgPool;
28use sqlx::postgres::{PgConnectOptions, PgSslMode};
29use thiserror_ext::AsReport;
30use tokio_postgres::types::Kind as PgKind;
31use tokio_postgres::{Client as PgClient, NoTls};
32
33#[cfg(not(madsim))]
34use super::maybe_tls_connector::MaybeMakeTlsConnector;
35use crate::error::ConnectorResult;
36
37#[derive(Debug, Clone, PartialEq, Deserialize)]
38#[serde(rename_all = "lowercase")]
39pub enum SslMode {
40    #[serde(alias = "disable")]
41    Disabled,
42    #[serde(alias = "prefer")]
43    Preferred,
44    #[serde(alias = "require")]
45    Required,
46    /// verify that the server is trustworthy by checking the certificate chain
47    /// up to the root certificate stored on the client.
48    #[serde(alias = "verify-ca")]
49    VerifyCa,
50    /// Besides verify the certificate, will also verify that the serverhost name
51    /// matches the name stored in the server certificate.
52    #[serde(alias = "verify-full")]
53    VerifyFull,
54}
55
56impl Default for SslMode {
57    fn default() -> Self {
58        Self::Preferred
59    }
60}
61
62pub struct PostgresExternalTable {
63    column_descs: Vec<ColumnDesc>,
64    pk_names: Vec<String>,
65}
66
67impl PostgresExternalTable {
68    async fn discover_schema(
69        username: &str,
70        password: &str,
71        host: &str,
72        port: u16,
73        database: &str,
74        schema: &str,
75        table: &str,
76        ssl_mode: &SslMode,
77        ssl_root_cert: &Option<String>,
78    ) -> ConnectorResult<TableDef> {
79        let mut options = PgConnectOptions::new()
80            .username(username)
81            .password(password)
82            .host(host)
83            .port(port)
84            .database(database)
85            .ssl_mode(match ssl_mode {
86                SslMode::Disabled => PgSslMode::Disable,
87                SslMode::Preferred => PgSslMode::Prefer,
88                SslMode::Required => PgSslMode::Require,
89                SslMode::VerifyCa => PgSslMode::VerifyCa,
90                SslMode::VerifyFull => PgSslMode::VerifyFull,
91            });
92
93        if *ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull {
94            if let Some(root_cert) = ssl_root_cert {
95                options = options.ssl_root_cert(root_cert.as_str());
96            }
97        }
98
99        let connection = PgPool::connect_with(options).await?;
100        let schema_discovery = SchemaDiscovery::new(connection, schema);
101        // fetch column schema and primary key
102        let empty_map = HashMap::new();
103        let table_schema = schema_discovery
104            .discover_table(
105                TableInfo {
106                    name: table.to_owned(),
107                    of_type: None,
108                },
109                &empty_map,
110            )
111            .await?;
112        Ok(table_schema)
113    }
114
115    pub async fn connect(
116        username: &str,
117        password: &str,
118        host: &str,
119        port: u16,
120        database: &str,
121        schema: &str,
122        table: &str,
123        ssl_mode: &SslMode,
124        ssl_root_cert: &Option<String>,
125        is_append_only: bool,
126    ) -> ConnectorResult<Self> {
127        tracing::debug!("connect to postgres external table");
128        let table_schema = Self::discover_schema(
129            username,
130            password,
131            host,
132            port,
133            database,
134            schema,
135            table,
136            ssl_mode,
137            ssl_root_cert,
138        )
139        .await?;
140        let mut column_descs = vec![];
141        for col in &table_schema.columns {
142            let rw_data_type = sea_type_to_rw_type(&col.col_type)?;
143            let column_desc = if let Some(ref default_expr) = col.default {
144                // parse the value of "column_default" field in information_schema.columns,
145                // non number data type will be stored as "'value'::type"
146                let val_text = default_expr
147                    .0
148                    .split("::")
149                    .map(|s| s.trim_matches('\''))
150                    .next()
151                    .expect("default value expression");
152
153                match ScalarImpl::from_text(val_text, &rw_data_type) {
154                    Ok(scalar) => ColumnDesc::named_with_default_value(
155                        col.name.clone(),
156                        ColumnId::placeholder(),
157                        rw_data_type.clone(),
158                        Some(scalar),
159                    ),
160                    Err(err) => {
161                        tracing::warn!(error=%err.as_report(), "failed to parse postgres default value expression, only constant is supported");
162                        ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
163                    }
164                }
165            } else {
166                ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
167            };
168            column_descs.push(column_desc);
169        }
170
171        if !is_append_only && table_schema.primary_key_constraints.is_empty() {
172            return Err(anyhow!(
173                "Postgres table should define the primary key for non-append-only tables"
174            )
175            .into());
176        }
177        let mut pk_names = vec![];
178        table_schema.primary_key_constraints.iter().for_each(|pk| {
179            pk_names.extend(pk.columns.clone());
180        });
181
182        Ok(Self {
183            column_descs,
184            pk_names,
185        })
186    }
187
188    // return the mapping from column name to pg type, the pg type is used for writing data to postgres
189    pub async fn type_mapping(
190        username: &str,
191        password: &str,
192        host: &str,
193        port: u16,
194        database: &str,
195        schema: &str,
196        table: &str,
197        ssl_mode: &SslMode,
198        ssl_root_cert: &Option<String>,
199        is_append_only: bool,
200    ) -> ConnectorResult<HashMap<String, tokio_postgres::types::Type>> {
201        tracing::debug!("connect to postgres external table to get type mapping");
202        let table_schema = Self::discover_schema(
203            username,
204            password,
205            host,
206            port,
207            database,
208            schema,
209            table,
210            ssl_mode,
211            ssl_root_cert,
212        )
213        .await?;
214        let mut column_name_to_pg_type = HashMap::new();
215        for col in &table_schema.columns {
216            let pg_type = sea_type_to_pg_type(&col.col_type)?;
217            column_name_to_pg_type.insert(col.name.clone(), pg_type);
218        }
219        if !is_append_only && table_schema.primary_key_constraints.is_empty() {
220            return Err(anyhow!(
221                "Postgres table should define the primary key for non-append-only tables"
222            )
223            .into());
224        }
225        Ok(column_name_to_pg_type)
226    }
227
228    pub fn column_descs(&self) -> &Vec<ColumnDesc> {
229        &self.column_descs
230    }
231
232    pub fn pk_names(&self) -> &Vec<String> {
233        &self.pk_names
234    }
235}
236
237impl fmt::Display for SslMode {
238    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
239        f.write_str(match self {
240            SslMode::Disabled => "disabled",
241            SslMode::Preferred => "preferred",
242            SslMode::Required => "required",
243            SslMode::VerifyCa => "verify-ca",
244            SslMode::VerifyFull => "verify-full",
245        })
246    }
247}
248
249pub async fn create_pg_client(
250    user: &str,
251    password: &str,
252    host: &str,
253    port: &str,
254    database: &str,
255    ssl_mode: &SslMode,
256    ssl_root_cert: &Option<String>,
257) -> anyhow::Result<PgClient> {
258    let mut pg_config = tokio_postgres::Config::new();
259    pg_config
260        .user(user)
261        .password(password)
262        .host(host)
263        .port(port.parse::<u16>().unwrap())
264        .dbname(database);
265
266    let (_verify_ca, verify_hostname) = match ssl_mode {
267        SslMode::VerifyCa => (true, false),
268        SslMode::VerifyFull => (true, true),
269        _ => (false, false),
270    };
271
272    #[cfg(not(madsim))]
273    let connector = match ssl_mode {
274        SslMode::Disabled => {
275            pg_config.ssl_mode(tokio_postgres::config::SslMode::Disable);
276            MaybeMakeTlsConnector::NoTls(NoTls)
277        }
278        SslMode::Preferred => {
279            pg_config.ssl_mode(tokio_postgres::config::SslMode::Prefer);
280            match SslConnector::builder(SslMethod::tls()) {
281                Ok(mut builder) => {
282                    // disable certificate verification for `prefer`
283                    builder.set_verify(SslVerifyMode::NONE);
284                    MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
285                }
286                Err(e) => {
287                    tracing::warn!(error = %e.as_report(), "SSL connector error");
288                    MaybeMakeTlsConnector::NoTls(NoTls)
289                }
290            }
291        }
292        SslMode::Required => {
293            pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
294            let mut builder = SslConnector::builder(SslMethod::tls())?;
295            // disable certificate verification for `require`
296            builder.set_verify(SslVerifyMode::NONE);
297            MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
298        }
299
300        SslMode::VerifyCa | SslMode::VerifyFull => {
301            pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
302            let mut builder = SslConnector::builder(SslMethod::tls())?;
303            if let Some(ssl_root_cert) = ssl_root_cert {
304                builder.set_ca_file(ssl_root_cert).map_err(|e| {
305                    anyhow!(format!("bad ssl root cert error: {}", e.to_report_string()))
306                })?;
307            }
308            let mut connector = MakeTlsConnector::new(builder.build());
309            if !verify_hostname {
310                connector.set_callback(|config, _| {
311                    config.set_verify_hostname(false);
312                    Ok(())
313                });
314            }
315            MaybeMakeTlsConnector::Tls(connector)
316        }
317    };
318    #[cfg(madsim)]
319    let connector = NoTls;
320
321    let (client, connection) = pg_config.connect(connector).await?;
322
323    tokio::spawn(async move {
324        if let Err(e) = connection.await {
325            tracing::error!(error = %e.as_report(), "postgres connection error");
326        }
327    });
328
329    Ok(client)
330}
331
332// Used for both source and sink connector
333pub fn sea_type_to_rw_type(col_type: &SeaType) -> ConnectorResult<DataType> {
334    let dtype = match col_type {
335        SeaType::SmallInt | SeaType::SmallSerial => DataType::Int16,
336        SeaType::Integer | SeaType::Serial => DataType::Int32,
337        SeaType::BigInt | SeaType::BigSerial => DataType::Int64,
338        SeaType::Money | SeaType::Decimal(_) | SeaType::Numeric(_) => DataType::Decimal,
339        SeaType::Real => DataType::Float32,
340        SeaType::DoublePrecision => DataType::Float64,
341        SeaType::Varchar(_) | SeaType::Char(_) | SeaType::Text => DataType::Varchar,
342        SeaType::Bytea => DataType::Bytea,
343        SeaType::Timestamp(_) => DataType::Timestamp,
344        SeaType::TimestampWithTimeZone(_) => DataType::Timestamptz,
345        SeaType::Date => DataType::Date,
346        SeaType::Time(_) | SeaType::TimeWithTimeZone(_) => DataType::Time,
347        SeaType::Interval(_) => DataType::Interval,
348        SeaType::Boolean => DataType::Boolean,
349        SeaType::Point => DataType::Struct(StructType::new(vec![
350            ("x", DataType::Float32),
351            ("y", DataType::Float32),
352        ])),
353        SeaType::Uuid => DataType::Varchar,
354        SeaType::Xml => DataType::Varchar,
355        SeaType::Json => DataType::Jsonb,
356        SeaType::JsonBinary => DataType::Jsonb,
357        SeaType::Array(def) => {
358            let item_type = match def.col_type.as_ref() {
359                Some(ty) => sea_type_to_rw_type(ty.as_ref())?,
360                None => {
361                    return Err(anyhow!("ARRAY type missing element type").into());
362                }
363            };
364
365            DataType::List(Box::new(item_type))
366        }
367        SeaType::PgLsn => DataType::Int64,
368        SeaType::Cidr
369        | SeaType::Inet
370        | SeaType::MacAddr
371        | SeaType::MacAddr8
372        | SeaType::Int4Range
373        | SeaType::Int8Range
374        | SeaType::NumRange
375        | SeaType::TsRange
376        | SeaType::TsTzRange
377        | SeaType::DateRange
378        | SeaType::Enum(_) => DataType::Varchar,
379        SeaType::Line
380        | SeaType::Lseg
381        | SeaType::Box
382        | SeaType::Path
383        | SeaType::Polygon
384        | SeaType::Circle
385        | SeaType::Bit(_)
386        | SeaType::VarBit(_)
387        | SeaType::TsVector
388        | SeaType::TsQuery => {
389            bail!("{:?} type not supported", col_type);
390        }
391        SeaType::Unknown(name) => {
392            // NOTES: user-defined enum type is classified as `Unknown`
393            tracing::warn!("Unknown Postgres data type: {name}, map to varchar");
394            DataType::Varchar
395        }
396    };
397
398    Ok(dtype)
399}
400
401// Used for sink connector
402// We use `sea-schema` for table schema discovery.
403// So we have to map `sea-schema` pg types
404// to `tokio-postgres` pg types (which we use for query binding).
405fn sea_type_to_pg_type(sea_type: &SeaType) -> ConnectorResult<tokio_postgres::types::Type> {
406    use tokio_postgres::types::Type as PgType;
407    match sea_type {
408        SeaType::SmallInt => Ok(PgType::INT2),
409        SeaType::Integer => Ok(PgType::INT4),
410        SeaType::BigInt => Ok(PgType::INT8),
411        SeaType::Decimal(_) => Ok(PgType::NUMERIC),
412        SeaType::Numeric(_) => Ok(PgType::NUMERIC),
413        SeaType::Real => Ok(PgType::FLOAT4),
414        SeaType::DoublePrecision => Ok(PgType::FLOAT8),
415        SeaType::Varchar(_) => Ok(PgType::VARCHAR),
416        SeaType::Char(_) => Ok(PgType::CHAR),
417        SeaType::Text => Ok(PgType::TEXT),
418        SeaType::Bytea => Ok(PgType::BYTEA),
419        SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP),
420        SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ),
421        SeaType::Date => Ok(PgType::DATE),
422        SeaType::Time(_) => Ok(PgType::TIME),
423        SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ),
424        SeaType::Interval(_) => Ok(PgType::INTERVAL),
425        SeaType::Boolean => Ok(PgType::BOOL),
426        SeaType::Point => Ok(PgType::POINT),
427        SeaType::Uuid => Ok(PgType::UUID),
428        SeaType::Json => Ok(PgType::JSON),
429        SeaType::JsonBinary => Ok(PgType::JSONB),
430        SeaType::Array(t) => {
431            let Some(t) = t.col_type.as_ref() else {
432                bail!("missing array type")
433            };
434            match t.as_ref() {
435                // RW only supports 1 level of nesting.
436                SeaType::SmallInt => Ok(PgType::INT2_ARRAY),
437                SeaType::Integer => Ok(PgType::INT4_ARRAY),
438                SeaType::BigInt => Ok(PgType::INT8_ARRAY),
439                SeaType::Decimal(_) => Ok(PgType::NUMERIC_ARRAY),
440                SeaType::Numeric(_) => Ok(PgType::NUMERIC_ARRAY),
441                SeaType::Real => Ok(PgType::FLOAT4_ARRAY),
442                SeaType::DoublePrecision => Ok(PgType::FLOAT8_ARRAY),
443                SeaType::Varchar(_) => Ok(PgType::VARCHAR_ARRAY),
444                SeaType::Char(_) => Ok(PgType::CHAR_ARRAY),
445                SeaType::Text => Ok(PgType::TEXT_ARRAY),
446                SeaType::Bytea => Ok(PgType::BYTEA_ARRAY),
447                SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP_ARRAY),
448                SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ_ARRAY),
449                SeaType::Date => Ok(PgType::DATE_ARRAY),
450                SeaType::Time(_) => Ok(PgType::TIME_ARRAY),
451                SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ_ARRAY),
452                SeaType::Interval(_) => Ok(PgType::INTERVAL_ARRAY),
453                SeaType::Boolean => Ok(PgType::BOOL_ARRAY),
454                SeaType::Point => Ok(PgType::POINT_ARRAY),
455                SeaType::Uuid => Ok(PgType::UUID_ARRAY),
456                SeaType::Json => Ok(PgType::JSON_ARRAY),
457                SeaType::JsonBinary => Ok(PgType::JSONB_ARRAY),
458                SeaType::Array(_) => bail!("nested array type is not supported"),
459                SeaType::Unknown(name) => {
460                    // Treat as enum type
461                    Ok(PgType::new(
462                        name.clone(),
463                        0,
464                        PgKind::Array(PgType::new(
465                            name.clone(),
466                            0,
467                            PgKind::Enum(vec![]),
468                            "".into(),
469                        )),
470                        "".into(),
471                    ))
472                }
473                _ => bail!("unsupported array type: {:?}", t),
474            }
475        }
476        SeaType::Unknown(name) => {
477            // Treat as enum type
478            Ok(PgType::new(
479                name.clone(),
480                0,
481                PgKind::Enum(vec![]),
482                "".into(),
483            ))
484        }
485        _ => bail!("unsupported type: {:?}", sea_type),
486    }
487}