risingwave_connector/connector_common/
postgres.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16use std::fmt;
17
18use anyhow::{Context, anyhow};
19use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
20use postgres_openssl::MakeTlsConnector;
21use risingwave_common::bail;
22use risingwave_common::catalog::{ColumnDesc, ColumnId};
23use risingwave_common::types::{DataType, ScalarImpl, StructType};
24use sea_schema::postgres::def::{ColumnType as SeaType, TableDef, TableInfo};
25use sea_schema::postgres::discovery::SchemaDiscovery;
26use sea_schema::sea_query::{Alias, IntoIden};
27use serde::Deserialize;
28use sqlx::postgres::{PgConnectOptions, PgSslMode};
29use sqlx::{PgPool, Row};
30use thiserror_ext::AsReport;
31use tokio_postgres::types::Kind as PgKind;
32use tokio_postgres::{Client as PgClient, NoTls};
33
34#[cfg(not(madsim))]
35use super::maybe_tls_connector::MaybeMakeTlsConnector;
36use crate::error::ConnectorResult;
37use crate::sink::postgres::TcpKeepaliveConfig;
38
/// SQL query to discover primary key columns directly from PostgreSQL system tables.
/// This bypasses querying `information_schema.table_constraints` to avoid permission issues.
///
/// Bind parameters: `$1` is the schema name and `$2` is the table name, both given as raw
/// (unquoted) identifiers. Each is passed through `quote_ident` before the `regclass` cast,
/// so mixed-case or otherwise special names resolve to exactly that relation instead of
/// being case-folded or mis-parsed by the cast.
///
/// Result: one `column_name` row per primary key column, ordered by the column's position
/// inside the primary key index.
const DISCOVER_PRIMARY_KEY_QUERY: &str = r#"
    SELECT a.attname as column_name
    FROM pg_index i
    JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
    WHERE i.indrelid = (quote_ident($1) || '.' || quote_ident($2))::regclass
      AND i.indisprimary = true
    ORDER BY array_position(i.indkey, a.attnum)
"#;
49
/// Discover pgvector columns with both `atttypmod` (dimension) and `format_type` text.
/// `vector(n)` is stored as `atttypmod = n`, while dimension-less `vector` uses `-1`.
/// We rely on this to keep user-defined type modifiers that are not preserved by sea-schema.
///
/// Bind parameters: `$1` is the schema name (matched against `pg_namespace.nspname`) and
/// `$2` is the table name (matched against `pg_class.relname`).
///
/// Result columns: `column_name`, `atttypmod` and `formatted_type`. Only live user columns
/// (`attnum > 0`, not dropped) whose type name is `vector` are returned, ordered by `attnum`.
const DISCOVER_PGVECTOR_COLUMNS_QUERY: &str = r#"
    SELECT
      a.attname as column_name,
      a.atttypmod as atttypmod,
      format_type(a.atttypid, a.atttypmod) as formatted_type
    FROM pg_attribute a
    JOIN pg_class c ON c.oid = a.attrelid
    JOIN pg_namespace n ON n.oid = c.relnamespace
    JOIN pg_type t ON t.oid = a.atttypid
    WHERE n.nspname = $1
      AND c.relname = $2
      AND t.typname = 'vector'
      AND a.attnum > 0
      AND NOT a.attisdropped
    ORDER BY a.attnum
"#;
69
/// Connection settings for a PostgreSQL server, as extracted from connector properties.
pub struct PgConnectionConfig {
    /// Server host name (read from the `hostname` property).
    pub host: String,
    /// Server port, kept as the raw property string (read from the `port` property).
    pub port: String,
    /// Login user (read from the `username` property).
    pub user: String,
    /// Login password (read from the `password` property; empty when the property is absent).
    pub password: String,
    /// Database name (read from the `database.name` property).
    pub database: String,
    /// TLS behaviour for the connection (read from the `ssl.mode` property).
    pub ssl_mode: SslMode,
    /// Optional root certificate setting (read from the `ssl.root.cert` property).
    pub ssl_root_cert: Option<String>,
}
79
80pub fn pg_connection_config_from_properties(
81    props: &BTreeMap<String, String>,
82) -> ConnectorResult<PgConnectionConfig> {
83    Ok(PgConnectionConfig {
84        host: props
85            .get("hostname")
86            .context("missing `hostname` in postgres-cdc properties")?
87            .clone(),
88        port: props
89            .get("port")
90            .context("missing `port` in postgres-cdc properties")?
91            .clone(),
92        user: props
93            .get("username")
94            .context("missing `username` in postgres-cdc properties")?
95            .clone(),
96        password: props.get("password").cloned().unwrap_or_default(),
97        database: props
98            .get("database.name")
99            .context("missing `database.name` in postgres-cdc properties")?
100            .clone(),
101        ssl_mode: props
102            .get("ssl.mode")
103            .and_then(|v| v.parse::<SslMode>().ok())
104            .unwrap_or_default(),
105        ssl_root_cert: props.get("ssl.root.cert").cloned(),
106    })
107}
108
109pub async fn create_pg_client_from_properties(
110    props: &BTreeMap<String, String>,
111    tcp_keepalive: Option<TcpKeepaliveConfig>,
112) -> ConnectorResult<PgClient> {
113    let config = pg_connection_config_from_properties(props)?;
114    create_pg_client(
115        &config.user,
116        &config.password,
117        &config.host,
118        &config.port,
119        &config.database,
120        &config.ssl_mode,
121        &config.ssl_root_cert,
122        tcp_keepalive,
123    )
124    .await
125    .map_err(Into::into)
126}
127
128pub async fn discover_pgvector_dimensions(
129    client: &PgClient,
130    schema: &str,
131    table: &str,
132) -> ConnectorResult<HashMap<String, usize>> {
133    let rows = client
134        .query(DISCOVER_PGVECTOR_COLUMNS_QUERY, &[&schema, &table])
135        .await?;
136
137    let mut dims = HashMap::new();
138    for row in rows {
139        let col_name: String = row.get("column_name");
140        let atttypmod: i32 = row.get("atttypmod");
141        if atttypmod > 0
142            && let Ok(dim) = usize::try_from(atttypmod)
143        {
144            dims.insert(col_name, dim);
145        }
146    }
147    Ok(dims)
148}
149
/// TLS mode used when connecting to PostgreSQL.
///
/// When deserialized, each variant is accepted under its lowercased variant name
/// (`rename_all = "lowercase"`) as well as under the alias listed on the variant.
/// The default is [`SslMode::Preferred`].
#[derive(Debug, Clone, PartialEq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum SslMode {
    /// Do not use TLS.
    #[serde(alias = "disable")]
    Disabled,
    /// Prefer a TLS connection; the server certificate is not verified.
    #[serde(alias = "prefer")]
    #[default]
    Preferred,
    /// Require a TLS connection; the server certificate is not verified.
    #[serde(alias = "require")]
    Required,
    /// Verify that the server is trustworthy by checking the certificate chain
    /// up to the root certificate stored on the client.
    #[serde(alias = "verify-ca")]
    VerifyCa,
    /// Besides verifying the certificate, also verify that the server host name
    /// matches the name stored in the server certificate.
    #[serde(alias = "verify-full")]
    VerifyFull,
}
169
/// Schema information discovered from an upstream PostgreSQL table.
pub struct PostgresExternalTable {
    /// Column descriptors of the table, one per discovered column.
    column_descs: Vec<ColumnDesc>,
    /// Names of the primary key columns; may be empty for append-only tables.
    pk_names: Vec<String>,
}
174
175impl PostgresExternalTable {
176    /// Discover primary key columns directly from PostgreSQL system tables.
177    /// This bypasses querying `information_schema.table_constraints` to avoid requiring table owner permissions.
178    async fn discover_primary_key(
179        connection: &PgPool,
180        schema_name: &str,
181        table_name: &str,
182    ) -> ConnectorResult<Vec<String>> {
183        let rows = sqlx::query(DISCOVER_PRIMARY_KEY_QUERY)
184            .bind(schema_name)
185            .bind(table_name)
186            .fetch_all(connection)
187            .await
188            .context("Failed to discover primary key columns")?;
189
190        let pk_columns = rows
191            .into_iter()
192            .map(|row| row.get::<String, _>("column_name"))
193            .collect();
194
195        Ok(pk_columns)
196    }
197
198    /// Discover schema with workaround for primary key discovery
199    /// This method uses direct PostgreSQL system table queries for primary keys
200    /// to avoid permission issues when querying `information_schema.table_constraints`
201    async fn discover_pk_and_full_columns(
202        username: &str,
203        password: &str,
204        host: &str,
205        port: u16,
206        database: &str,
207        schema: &str,
208        table: &str,
209        ssl_mode: &SslMode,
210        ssl_root_cert: &Option<String>,
211    ) -> ConnectorResult<(Vec<sea_schema::postgres::def::ColumnInfo>, Vec<String>)> {
212        let mut options = PgConnectOptions::new()
213            .username(username)
214            .password(password)
215            .host(host)
216            .port(port)
217            .database(database)
218            .ssl_mode(match ssl_mode {
219                SslMode::Disabled => PgSslMode::Disable,
220                SslMode::Preferred => PgSslMode::Prefer,
221                SslMode::Required => PgSslMode::Require,
222                SslMode::VerifyCa => PgSslMode::VerifyCa,
223                SslMode::VerifyFull => PgSslMode::VerifyFull,
224            });
225
226        if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
227            && let Some(root_cert) = ssl_root_cert
228        {
229            options = options.ssl_root_cert(root_cert.as_str());
230        }
231
232        let connection = PgPool::connect_with(options).await?;
233
234        // Use sea-schema only for column discovery (no permission issues)
235        let schema_discovery = SchemaDiscovery::new(connection.clone(), schema);
236        let empty_map: HashMap<String, Vec<String>> = HashMap::new();
237        let columns = schema_discovery
238            .discover_columns(
239                Alias::new(schema).into_iden(),
240                Alias::new(table).into_iden(),
241                &empty_map,
242            )
243            .await?;
244
245        let pgvector_columns = sqlx::query(DISCOVER_PGVECTOR_COLUMNS_QUERY)
246            .bind(schema)
247            .bind(table)
248            .fetch_all(&connection)
249            .await
250            .context("Failed to discover PostgreSQL pgvector columns")?;
251        let formatted_type_by_column: HashMap<String, String> = pgvector_columns
252            .into_iter()
253            .map(|row| {
254                (
255                    row.get::<String, _>("column_name"),
256                    row.get::<String, _>("formatted_type"),
257                )
258            })
259            .collect();
260
261        // sea-schema reports pgvector as `Unknown("vector")` and drops the dimension.
262        // Patch it with PostgreSQL's formatted type text so we can derive vector(n).
263        let mut columns = columns;
264        for col in &mut columns {
265            if let SeaType::Unknown(name) = &col.col_type
266                && name.eq_ignore_ascii_case("vector")
267                && let Some(formatted_type) = formatted_type_by_column.get(&col.name)
268            {
269                col.col_type = SeaType::Unknown(formatted_type.clone());
270            }
271        }
272
273        // Use direct system table query for primary key discovery
274        let pk_columns = Self::discover_primary_key(&connection, schema, table).await?;
275
276        Ok((columns, pk_columns))
277    }
278
279    async fn discover_schema(
280        username: &str,
281        password: &str,
282        host: &str,
283        port: u16,
284        database: &str,
285        schema: &str,
286        table: &str,
287        ssl_mode: &SslMode,
288        ssl_root_cert: &Option<String>,
289    ) -> ConnectorResult<TableDef> {
290        let mut options = PgConnectOptions::new()
291            .username(username)
292            .password(password)
293            .host(host)
294            .port(port)
295            .database(database)
296            .ssl_mode(match ssl_mode {
297                SslMode::Disabled => PgSslMode::Disable,
298                SslMode::Preferred => PgSslMode::Prefer,
299                SslMode::Required => PgSslMode::Require,
300                SslMode::VerifyCa => PgSslMode::VerifyCa,
301                SslMode::VerifyFull => PgSslMode::VerifyFull,
302            });
303
304        if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
305            && let Some(root_cert) = ssl_root_cert
306        {
307            options = options.ssl_root_cert(root_cert.as_str());
308        }
309
310        let connection = PgPool::connect_with(options).await?;
311        let schema_discovery = SchemaDiscovery::new(connection, schema);
312        // fetch column schema and primary key
313        let empty_map = HashMap::new();
314        let table_schema = schema_discovery
315            .discover_table(
316                TableInfo {
317                    name: table.to_owned(),
318                    of_type: None,
319                },
320                &empty_map,
321            )
322            .await?;
323        Ok(table_schema)
324    }
325
326    pub async fn connect(
327        username: &str,
328        password: &str,
329        host: &str,
330        port: u16,
331        database: &str,
332        schema: &str,
333        table: &str,
334        ssl_mode: &SslMode,
335        ssl_root_cert: &Option<String>,
336        is_append_only: bool,
337    ) -> ConnectorResult<Self> {
338        tracing::debug!("connect to postgres external table");
339
340        let (columns, pk_names) = Self::discover_pk_and_full_columns(
341            username,
342            password,
343            host,
344            port,
345            database,
346            schema,
347            table,
348            ssl_mode,
349            ssl_root_cert,
350        )
351        .await?;
352
353        let mut column_descs = vec![];
354        for col in &columns {
355            let rw_data_type = sea_type_to_rw_type(&col.col_type)?;
356            let column_desc = if let Some(ref default_expr) = col.default {
357                // parse the value of "column_default" field in information_schema.columns,
358                // non number data type will be stored as "'value'::type"
359                let val_text = default_expr
360                    .0
361                    .split("::")
362                    .map(|s| s.trim_matches('\''))
363                    .next()
364                    .expect("default value expression");
365
366                match ScalarImpl::from_text(val_text, &rw_data_type) {
367                    Ok(scalar) => ColumnDesc::named_with_default_value(
368                        col.name.clone(),
369                        ColumnId::placeholder(),
370                        rw_data_type.clone(),
371                        Some(scalar),
372                    ),
373                    Err(err) => {
374                        tracing::warn!(error=%err.as_report(), "failed to parse postgres default value expression, only constant is supported");
375                        ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
376                    }
377                }
378            } else {
379                ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
380            };
381            column_descs.push(column_desc);
382        }
383
384        // Check primary key existence using the directly discovered pk_names
385        if !is_append_only && pk_names.is_empty() {
386            return Err(anyhow!(
387                "Postgres table should define the primary key for non-append-only tables"
388            )
389            .into());
390        }
391
392        Ok(Self {
393            column_descs,
394            pk_names,
395        })
396    }
397
398    // return the mapping from column name to pg type, the pg type is used for writing data to postgres
399    pub async fn type_mapping(
400        username: &str,
401        password: &str,
402        host: &str,
403        port: u16,
404        database: &str,
405        schema: &str,
406        table: &str,
407        ssl_mode: &SslMode,
408        ssl_root_cert: &Option<String>,
409        is_append_only: bool,
410    ) -> ConnectorResult<HashMap<String, tokio_postgres::types::Type>> {
411        tracing::debug!("connect to postgres external table to get type mapping");
412        let table_schema = Self::discover_schema(
413            username,
414            password,
415            host,
416            port,
417            database,
418            schema,
419            table,
420            ssl_mode,
421            ssl_root_cert,
422        )
423        .await?;
424        let mut column_name_to_pg_type = HashMap::new();
425        for col in &table_schema.columns {
426            let pg_type = sea_type_to_pg_type(&col.col_type)?;
427            column_name_to_pg_type.insert(col.name.clone(), pg_type);
428        }
429        if !is_append_only && table_schema.primary_key_constraints.is_empty() {
430            return Err(anyhow!(
431                "Postgres table should define the primary key for non-append-only tables"
432            )
433            .into());
434        }
435        Ok(column_name_to_pg_type)
436    }
437
    /// Returns the column descriptors discovered for the table.
    pub fn column_descs(&self) -> &Vec<ColumnDesc> {
        &self.column_descs
    }
441
    /// Returns the names of the primary key columns discovered for the table.
    pub fn pk_names(&self) -> &Vec<String> {
        &self.pk_names
    }
445}
446
447impl fmt::Display for SslMode {
448    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
449        f.write_str(match self {
450            SslMode::Disabled => "disabled",
451            SslMode::Preferred => "preferred",
452            SslMode::Required => "required",
453            SslMode::VerifyCa => "verify-ca",
454            SslMode::VerifyFull => "verify-full",
455        })
456    }
457}
458
459impl std::str::FromStr for SslMode {
460    type Err = serde_json::Error;
461
462    fn from_str(s: &str) -> Result<Self, Self::Err> {
463        serde_json::from_value(serde_json::Value::String(s.to_owned()))
464    }
465}
466
467pub async fn create_pg_client(
468    user: &str,
469    password: &str,
470    host: &str,
471    port: &str,
472    database: &str,
473    ssl_mode: &SslMode,
474    ssl_root_cert: &Option<String>,
475    tcp_keepalive: Option<TcpKeepaliveConfig>,
476) -> anyhow::Result<PgClient> {
477    let mut pg_config = tokio_postgres::Config::new();
478    pg_config
479        .user(user)
480        .password(password)
481        .host(host)
482        .port(port.parse::<u16>().unwrap())
483        .dbname(database);
484
485    // Configure TCP keepalive if provided
486    if let Some(keepalive) = tcp_keepalive {
487        pg_config.keepalives(true);
488        pg_config.keepalives_idle(std::time::Duration::from_secs(
489            keepalive.tcp_keepalive_idle as u64,
490        ));
491        #[cfg(not(target_os = "windows"))]
492        {
493            pg_config.keepalives_interval(std::time::Duration::from_secs(
494                keepalive.tcp_keepalive_interval as u64,
495            ));
496            pg_config.keepalives_retries(keepalive.tcp_keepalive_count);
497        }
498        tracing::info!(
499            "TCP keepalive enabled: idle={}s, interval={}s, retries={}",
500            keepalive.tcp_keepalive_idle,
501            keepalive.tcp_keepalive_interval,
502            keepalive.tcp_keepalive_count
503        );
504    }
505
506    let (_verify_ca, verify_hostname) = match ssl_mode {
507        SslMode::VerifyCa => (true, false),
508        SslMode::VerifyFull => (true, true),
509        _ => (false, false),
510    };
511
512    #[cfg(not(madsim))]
513    let connector = match ssl_mode {
514        SslMode::Disabled => {
515            pg_config.ssl_mode(tokio_postgres::config::SslMode::Disable);
516            MaybeMakeTlsConnector::NoTls(NoTls)
517        }
518        SslMode::Preferred => {
519            pg_config.ssl_mode(tokio_postgres::config::SslMode::Prefer);
520            match SslConnector::builder(SslMethod::tls()) {
521                Ok(mut builder) => {
522                    // disable certificate verification for `prefer`
523                    builder.set_verify(SslVerifyMode::NONE);
524                    MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
525                }
526                Err(e) => {
527                    tracing::warn!(error = %e.as_report(), "SSL connector error");
528                    MaybeMakeTlsConnector::NoTls(NoTls)
529                }
530            }
531        }
532        SslMode::Required => {
533            pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
534            let mut builder = SslConnector::builder(SslMethod::tls())?;
535            // disable certificate verification for `require`
536            builder.set_verify(SslVerifyMode::NONE);
537            MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
538        }
539
540        SslMode::VerifyCa | SslMode::VerifyFull => {
541            pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
542            let mut builder = SslConnector::builder(SslMethod::tls())?;
543            if let Some(ssl_root_cert) = ssl_root_cert {
544                builder.set_ca_file(ssl_root_cert).map_err(|e| {
545                    anyhow!(format!("bad ssl root cert error: {}", e.to_report_string()))
546                })?;
547            }
548            let mut connector = MakeTlsConnector::new(builder.build());
549            if !verify_hostname {
550                connector.set_callback(|config, _| {
551                    config.set_verify_hostname(false);
552                    Ok(())
553                });
554            }
555            MaybeMakeTlsConnector::Tls(connector)
556        }
557    };
558    #[cfg(madsim)]
559    let connector = NoTls;
560
561    let (client, connection) = pg_config.connect(connector).await?;
562
563    tokio::spawn(async move {
564        if let Err(e) = connection.await {
565            tracing::error!(error = %e.as_report(), "postgres connection error");
566        }
567    });
568
569    Ok(client)
570}
571
572// Used for both source and sink connector
573pub fn sea_type_to_rw_type(col_type: &SeaType) -> ConnectorResult<DataType> {
574    let dtype = match col_type {
575        SeaType::SmallInt | SeaType::SmallSerial => DataType::Int16,
576        SeaType::Integer | SeaType::Serial => DataType::Int32,
577        SeaType::BigInt | SeaType::BigSerial => DataType::Int64,
578        SeaType::Money | SeaType::Decimal(_) | SeaType::Numeric(_) => DataType::Decimal,
579        SeaType::Real => DataType::Float32,
580        SeaType::DoublePrecision => DataType::Float64,
581        SeaType::Varchar(_) | SeaType::Char(_) | SeaType::Text => DataType::Varchar,
582        SeaType::Bytea => DataType::Bytea,
583        SeaType::Timestamp(_) => DataType::Timestamp,
584        SeaType::TimestampWithTimeZone(_) => DataType::Timestamptz,
585        SeaType::Date => DataType::Date,
586        SeaType::Time(_) | SeaType::TimeWithTimeZone(_) => DataType::Time,
587        SeaType::Interval(_) => DataType::Interval,
588        SeaType::Boolean => DataType::Boolean,
589        SeaType::Point => DataType::Struct(StructType::new(vec![
590            ("x", DataType::Float32),
591            ("y", DataType::Float32),
592        ])),
593        SeaType::Uuid => DataType::Varchar,
594        SeaType::Xml => DataType::Varchar,
595        SeaType::Json => DataType::Jsonb,
596        SeaType::JsonBinary => DataType::Jsonb,
597        SeaType::Array(def) => {
598            let item_type = match def.col_type.as_ref() {
599                Some(ty) => sea_type_to_rw_type(ty.as_ref())?,
600                None => {
601                    return Err(anyhow!("ARRAY type missing element type").into());
602                }
603            };
604
605            DataType::list(item_type)
606        }
607        SeaType::PgLsn => DataType::Int64,
608        SeaType::Cidr
609        | SeaType::Inet
610        | SeaType::MacAddr
611        | SeaType::MacAddr8
612        | SeaType::Int4Range
613        | SeaType::Int8Range
614        | SeaType::NumRange
615        | SeaType::TsRange
616        | SeaType::TsTzRange
617        | SeaType::DateRange
618        | SeaType::Enum(_) => DataType::Varchar,
619        SeaType::Line
620        | SeaType::Lseg
621        | SeaType::Box
622        | SeaType::Path
623        | SeaType::Polygon
624        | SeaType::Circle
625        | SeaType::Bit(_)
626        | SeaType::VarBit(_)
627        | SeaType::TsVector
628        | SeaType::TsQuery => {
629            bail!("{:?} type not supported", col_type);
630        }
631        SeaType::Unknown(name) => {
632            if let Some(dim) = parse_pgvector_dimension(name)? {
633                DataType::Vector(dim)
634            } else {
635                // NOTES: user-defined enum type is classified as `Unknown`
636                tracing::warn!("Unknown Postgres data type: {name}, map to varchar");
637                DataType::Varchar
638            }
639        }
640    };
641
642    Ok(dtype)
643}
644
645fn parse_pgvector_dimension(type_name: &str) -> ConnectorResult<Option<usize>> {
646    let normalized = type_name.trim().to_ascii_lowercase();
647    if normalized == "vector" {
648        bail!("pgvector type `vector` is missing dimension, expected `vector(n)`")
649    }
650    if !normalized.starts_with("vector(") || !normalized.ends_with(')') {
651        return Ok(None);
652    }
653
654    let dim_text = normalized
655        .trim_start_matches("vector(")
656        .trim_end_matches(')')
657        .trim();
658    let dim = dim_text
659        .parse::<usize>()
660        .map_err(|_| anyhow!("invalid pgvector dimension in type `{type_name}`"))?;
661
662    if !(1..=DataType::VEC_MAX_SIZE).contains(&dim) {
663        bail!(
664            "pgvector dimension out of range in type `{}`: expect 1..={}",
665            type_name,
666            DataType::VEC_MAX_SIZE
667        );
668    }
669
670    Ok(Some(dim))
671}
672
673// Used for sink connector
674// We use `sea-schema` for table schema discovery.
675// So we have to map `sea-schema` pg types
676// to `tokio-postgres` pg types (which we use for query binding).
677fn sea_type_to_pg_type(sea_type: &SeaType) -> ConnectorResult<tokio_postgres::types::Type> {
678    use tokio_postgres::types::Type as PgType;
679    match sea_type {
680        SeaType::SmallInt => Ok(PgType::INT2),
681        SeaType::Integer => Ok(PgType::INT4),
682        SeaType::BigInt => Ok(PgType::INT8),
683        SeaType::Decimal(_) => Ok(PgType::NUMERIC),
684        SeaType::Numeric(_) => Ok(PgType::NUMERIC),
685        SeaType::Real => Ok(PgType::FLOAT4),
686        SeaType::DoublePrecision => Ok(PgType::FLOAT8),
687        SeaType::Varchar(_) => Ok(PgType::VARCHAR),
688        SeaType::Char(_) => Ok(PgType::CHAR),
689        SeaType::Text => Ok(PgType::TEXT),
690        SeaType::Bytea => Ok(PgType::BYTEA),
691        SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP),
692        SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ),
693        SeaType::Date => Ok(PgType::DATE),
694        SeaType::Time(_) => Ok(PgType::TIME),
695        SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ),
696        SeaType::Interval(_) => Ok(PgType::INTERVAL),
697        SeaType::Boolean => Ok(PgType::BOOL),
698        SeaType::Point => Ok(PgType::POINT),
699        SeaType::Uuid => Ok(PgType::UUID),
700        SeaType::Json => Ok(PgType::JSON),
701        SeaType::JsonBinary => Ok(PgType::JSONB),
702        SeaType::Array(t) => {
703            let Some(t) = t.col_type.as_ref() else {
704                bail!("missing array type")
705            };
706            match t.as_ref() {
707                // RW only supports 1 level of nesting.
708                SeaType::SmallInt => Ok(PgType::INT2_ARRAY),
709                SeaType::Integer => Ok(PgType::INT4_ARRAY),
710                SeaType::BigInt => Ok(PgType::INT8_ARRAY),
711                SeaType::Decimal(_) => Ok(PgType::NUMERIC_ARRAY),
712                SeaType::Numeric(_) => Ok(PgType::NUMERIC_ARRAY),
713                SeaType::Real => Ok(PgType::FLOAT4_ARRAY),
714                SeaType::DoublePrecision => Ok(PgType::FLOAT8_ARRAY),
715                SeaType::Varchar(_) => Ok(PgType::VARCHAR_ARRAY),
716                SeaType::Char(_) => Ok(PgType::CHAR_ARRAY),
717                SeaType::Text => Ok(PgType::TEXT_ARRAY),
718                SeaType::Bytea => Ok(PgType::BYTEA_ARRAY),
719                SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP_ARRAY),
720                SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ_ARRAY),
721                SeaType::Date => Ok(PgType::DATE_ARRAY),
722                SeaType::Time(_) => Ok(PgType::TIME_ARRAY),
723                SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ_ARRAY),
724                SeaType::Interval(_) => Ok(PgType::INTERVAL_ARRAY),
725                SeaType::Boolean => Ok(PgType::BOOL_ARRAY),
726                SeaType::Point => Ok(PgType::POINT_ARRAY),
727                SeaType::Uuid => Ok(PgType::UUID_ARRAY),
728                SeaType::Json => Ok(PgType::JSON_ARRAY),
729                SeaType::JsonBinary => Ok(PgType::JSONB_ARRAY),
730                SeaType::Array(_) => bail!("nested array type is not supported"),
731                SeaType::Unknown(name) => {
732                    // Treat as enum type
733                    Ok(PgType::new(
734                        name.clone(),
735                        0,
736                        PgKind::Array(PgType::new(
737                            name.clone(),
738                            0,
739                            PgKind::Enum(vec![]),
740                            "".into(),
741                        )),
742                        "".into(),
743                    ))
744                }
745                _ => bail!("unsupported array type: {:?}", t),
746            }
747        }
748        SeaType::Unknown(name) => {
749            // Treat as enum type
750            Ok(PgType::new(
751                name.clone(),
752                0,
753                PgKind::Enum(vec![]),
754                "".into(),
755            ))
756        }
757        _ => bail!("unsupported type: {:?}", sea_type),
758    }
759}
760
#[cfg(test)]
mod tests {
    use super::parse_pgvector_dimension;

    /// Well-formed `vector(n)` names yield the dimension, regardless of case and
    /// surrounding whitespace; unrelated type names yield `None`.
    #[test]
    fn test_parse_pgvector_dimension() {
        assert_eq!(parse_pgvector_dimension("vector(3)").unwrap(), Some(3));
        assert_eq!(parse_pgvector_dimension("VECTOR(768)").unwrap(), Some(768));
        assert_eq!(
            parse_pgvector_dimension("  vector( 16 )  ").unwrap(),
            Some(16)
        );
        assert_eq!(parse_pgvector_dimension("varchar").unwrap(), None);
    }

    /// A bare `vector` without a dimension is rejected.
    #[test]
    fn test_parse_pgvector_dimension_requires_size() {
        let err = parse_pgvector_dimension("vector").unwrap_err();
        assert!(err.to_string().contains("missing dimension"));
    }

    /// A non-numeric dimension is rejected.
    #[test]
    fn test_parse_pgvector_dimension_rejects_non_numeric_size() {
        let err = parse_pgvector_dimension("vector(abc)").unwrap_err();
        assert!(err.to_string().contains("invalid pgvector dimension"));
    }

    /// A zero dimension is below the accepted range.
    #[test]
    fn test_parse_pgvector_dimension_rejects_zero_size() {
        let err = parse_pgvector_dimension("vector(0)").unwrap_err();
        assert!(err.to_string().contains("out of range"));
    }
}
777}