risingwave_connector/connector_common/
postgres.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16use std::fmt;
17
18use anyhow::{Context, anyhow};
19use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
20use postgres_openssl::MakeTlsConnector;
21use risingwave_common::bail;
22use risingwave_common::catalog::{ColumnDesc, ColumnId};
23use risingwave_common::types::{DataType, ScalarImpl, StructType};
24use sea_schema::postgres::def::{ColumnType as SeaType, TableDef, TableInfo};
25use sea_schema::postgres::discovery::SchemaDiscovery;
26use sea_schema::sea_query::{Alias, IntoIden};
27use serde::Deserialize;
28use sqlx::postgres::{PgConnectOptions, PgSslMode};
29use sqlx::{PgPool, Row};
30use thiserror_ext::AsReport;
31use tokio_postgres::types::Kind as PgKind;
32use tokio_postgres::{Client as PgClient, NoTls};
33
34#[cfg(not(madsim))]
35use super::maybe_tls_connector::MaybeMakeTlsConnector;
36use crate::error::ConnectorResult;
37use crate::sink::postgres::TcpKeepaliveConfig;
38
/// SQL query to discover primary key columns directly from PostgreSQL system tables.
/// This bypasses querying `information_schema.table_constraints` to avoid permission issues.
/// Match `pg_class` and `pg_namespace` by exact catalog names instead of casting a
/// constructed string to `regclass`, as unquoted `regclass` input folds mixed-case
/// table names to lower case.
///
/// Parameters: `$1` is the schema name and `$2` the table name, both compared verbatim.
/// Returns one row per primary-key column (`column_name`). Rows are ordered by the column's
/// position inside the primary-key index (`array_position(i.indkey, a.attnum)`), so
/// composite keys come back in key order rather than in table column order.
const DISCOVER_PRIMARY_KEY_QUERY: &str = r#"
    SELECT a.attname as column_name
    FROM pg_index i
    JOIN pg_class c ON c.oid = i.indrelid
    JOIN pg_namespace n ON n.oid = c.relnamespace
    JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
    WHERE n.nspname = $1
      AND c.relname = $2
      AND i.indisprimary = true
    ORDER BY array_position(i.indkey, a.attnum)
"#;
55
/// Discover pgvector columns with both `atttypmod` (dimension) and `format_type` text.
/// `vector(n)` is stored as `atttypmod = n`, while dimension-less `vector` uses `-1`.
/// We rely on this to keep user-defined type modifiers that are not preserved by sea-schema.
///
/// Parameters: `$1` is the schema name and `$2` the table name.
/// Returns `column_name`, `atttypmod` and `formatted_type` for every non-dropped column with
/// `attnum > 0` whose type is named `vector`, in column-number order. The type is matched
/// by `typname` only, regardless of the schema the type lives in.
const DISCOVER_PGVECTOR_COLUMNS_QUERY: &str = r#"
    SELECT
      a.attname as column_name,
      a.atttypmod as atttypmod,
      format_type(a.atttypid, a.atttypmod) as formatted_type
    FROM pg_attribute a
    JOIN pg_class c ON c.oid = a.attrelid
    JOIN pg_namespace n ON n.oid = c.relnamespace
    JOIN pg_type t ON t.oid = a.atttypid
    WHERE n.nspname = $1
      AND c.relname = $2
      AND t.typname = 'vector'
      AND a.attnum > 0
      AND NOT a.attisdropped
    ORDER BY a.attnum
"#;
75
/// Connection parameters for a PostgreSQL server, as extracted from connector properties
/// by [`pg_connection_config_from_properties`].
pub struct PgConnectionConfig {
    /// Server host (`hostname` property).
    pub host: String,
    /// Server port, kept as the raw property string (`port` property).
    pub port: String,
    /// Login user (`username` property).
    pub user: String,
    /// Login password (`password` property; empty when the property is absent).
    pub password: String,
    /// Database to connect to (`database.name` property).
    pub database: String,
    /// TLS mode (`ssl.mode` property; the default mode is used when absent or unparseable).
    pub ssl_mode: SslMode,
    /// Optional root-certificate file path (`ssl.root.cert` property).
    pub ssl_root_cert: Option<String>,
}
85
86pub fn pg_connection_config_from_properties(
87    props: &BTreeMap<String, String>,
88) -> ConnectorResult<PgConnectionConfig> {
89    Ok(PgConnectionConfig {
90        host: props
91            .get("hostname")
92            .context("missing `hostname` in postgres-cdc properties")?
93            .clone(),
94        port: props
95            .get("port")
96            .context("missing `port` in postgres-cdc properties")?
97            .clone(),
98        user: props
99            .get("username")
100            .context("missing `username` in postgres-cdc properties")?
101            .clone(),
102        password: props.get("password").cloned().unwrap_or_default(),
103        database: props
104            .get("database.name")
105            .context("missing `database.name` in postgres-cdc properties")?
106            .clone(),
107        ssl_mode: props
108            .get("ssl.mode")
109            .and_then(|v| v.parse::<SslMode>().ok())
110            .unwrap_or_default(),
111        ssl_root_cert: props.get("ssl.root.cert").cloned(),
112    })
113}
114
115pub async fn create_pg_client_from_properties(
116    props: &BTreeMap<String, String>,
117    tcp_keepalive: Option<TcpKeepaliveConfig>,
118) -> ConnectorResult<PgClient> {
119    let config = pg_connection_config_from_properties(props)?;
120    create_pg_client(
121        &config.user,
122        &config.password,
123        &config.host,
124        &config.port,
125        &config.database,
126        &config.ssl_mode,
127        &config.ssl_root_cert,
128        tcp_keepalive,
129    )
130    .await
131    .map_err(Into::into)
132}
133
134pub async fn discover_pgvector_dimensions(
135    client: &PgClient,
136    schema: &str,
137    table: &str,
138) -> ConnectorResult<HashMap<String, usize>> {
139    let rows = client
140        .query(DISCOVER_PGVECTOR_COLUMNS_QUERY, &[&schema, &table])
141        .await?;
142
143    let mut dims = HashMap::new();
144    for row in rows {
145        let col_name: String = row.get("column_name");
146        let atttypmod: i32 = row.get("atttypmod");
147        if atttypmod > 0
148            && let Ok(dim) = usize::try_from(atttypmod)
149        {
150            dims.insert(col_name, dim);
151        }
152    }
153    Ok(dims)
154}
155
/// TLS mode used when connecting to PostgreSQL.
///
/// Deserializes (and parses via `FromStr`) from the lowercase variant name (`disabled`,
/// `preferred`, `required`, `verifyca`, `verifyfull`) or from the aliases `disable`,
/// `prefer`, `require`, `verify-ca` and `verify-full`. The default is `Preferred`.
#[derive(Debug, Clone, PartialEq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum SslMode {
    /// Connect without TLS.
    #[serde(alias = "disable")]
    Disabled,
    /// Prefer TLS; `create_pg_client` does not verify the server certificate in this mode.
    #[serde(alias = "prefer")]
    #[default]
    Preferred,
    /// Require TLS; `create_pg_client` does not verify the server certificate in this mode.
    #[serde(alias = "require")]
    Required,
    /// Verify that the server is trustworthy by checking the certificate chain
    /// up to the root certificate stored on the client.
    #[serde(alias = "verify-ca")]
    VerifyCa,
    /// Besides verifying the certificate, also verify that the server host name
    /// matches the name stored in the server certificate.
    #[serde(alias = "verify-full")]
    VerifyFull,
}
175
/// Discovered schema of an upstream PostgreSQL table: its columns converted to RisingWave
/// column descriptors plus its primary-key column names. Built by
/// `PostgresExternalTable::connect`.
pub struct PostgresExternalTable {
    /// Column descriptors, in the order returned by sea-schema column discovery.
    column_descs: Vec<ColumnDesc>,
    /// Primary-key column names in key order; may be empty for append-only tables.
    pk_names: Vec<String>,
}
180
181impl PostgresExternalTable {
182    /// Discover primary key columns directly from PostgreSQL system tables.
183    /// This bypasses querying `information_schema.table_constraints` to avoid requiring table owner permissions.
184    async fn discover_primary_key(
185        connection: &PgPool,
186        schema_name: &str,
187        table_name: &str,
188    ) -> ConnectorResult<Vec<String>> {
189        let rows = sqlx::query(DISCOVER_PRIMARY_KEY_QUERY)
190            .bind(schema_name)
191            .bind(table_name)
192            .fetch_all(connection)
193            .await
194            .context("Failed to discover primary key columns")?;
195
196        let pk_columns = rows
197            .into_iter()
198            .map(|row| row.get::<String, _>("column_name"))
199            .collect();
200
201        Ok(pk_columns)
202    }
203
204    /// Discover schema with workaround for primary key discovery
205    /// This method uses direct PostgreSQL system table queries for primary keys
206    /// to avoid permission issues when querying `information_schema.table_constraints`
207    async fn discover_pk_and_full_columns(
208        username: &str,
209        password: &str,
210        host: &str,
211        port: u16,
212        database: &str,
213        schema: &str,
214        table: &str,
215        ssl_mode: &SslMode,
216        ssl_root_cert: &Option<String>,
217    ) -> ConnectorResult<(Vec<sea_schema::postgres::def::ColumnInfo>, Vec<String>)> {
218        let mut options = PgConnectOptions::new()
219            .username(username)
220            .password(password)
221            .host(host)
222            .port(port)
223            .database(database)
224            .ssl_mode(match ssl_mode {
225                SslMode::Disabled => PgSslMode::Disable,
226                SslMode::Preferred => PgSslMode::Prefer,
227                SslMode::Required => PgSslMode::Require,
228                SslMode::VerifyCa => PgSslMode::VerifyCa,
229                SslMode::VerifyFull => PgSslMode::VerifyFull,
230            });
231
232        if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
233            && let Some(root_cert) = ssl_root_cert
234        {
235            options = options.ssl_root_cert(root_cert.as_str());
236        }
237
238        let connection = PgPool::connect_with(options).await?;
239
240        // Use sea-schema only for column discovery (no permission issues)
241        let schema_discovery = SchemaDiscovery::new(connection.clone(), schema);
242        let empty_map: HashMap<String, Vec<String>> = HashMap::new();
243        let columns = schema_discovery
244            .discover_columns(
245                Alias::new(schema).into_iden(),
246                Alias::new(table).into_iden(),
247                &empty_map,
248            )
249            .await?;
250
251        let pgvector_columns = sqlx::query(DISCOVER_PGVECTOR_COLUMNS_QUERY)
252            .bind(schema)
253            .bind(table)
254            .fetch_all(&connection)
255            .await
256            .context("Failed to discover PostgreSQL pgvector columns")?;
257        let formatted_type_by_column: HashMap<String, String> = pgvector_columns
258            .into_iter()
259            .map(|row| {
260                (
261                    row.get::<String, _>("column_name"),
262                    row.get::<String, _>("formatted_type"),
263                )
264            })
265            .collect();
266
267        // sea-schema reports pgvector as `Unknown("vector")` and drops the dimension.
268        // Patch it with PostgreSQL's formatted type text so we can derive vector(n).
269        let mut columns = columns;
270        for col in &mut columns {
271            if let SeaType::Unknown(name) = &col.col_type
272                && name.eq_ignore_ascii_case("vector")
273                && let Some(formatted_type) = formatted_type_by_column.get(&col.name)
274            {
275                col.col_type = SeaType::Unknown(formatted_type.clone());
276            }
277        }
278
279        // Use direct system table query for primary key discovery
280        let pk_columns = Self::discover_primary_key(&connection, schema, table).await?;
281
282        Ok((columns, pk_columns))
283    }
284
285    async fn discover_schema(
286        username: &str,
287        password: &str,
288        host: &str,
289        port: u16,
290        database: &str,
291        schema: &str,
292        table: &str,
293        ssl_mode: &SslMode,
294        ssl_root_cert: &Option<String>,
295    ) -> ConnectorResult<TableDef> {
296        let mut options = PgConnectOptions::new()
297            .username(username)
298            .password(password)
299            .host(host)
300            .port(port)
301            .database(database)
302            .ssl_mode(match ssl_mode {
303                SslMode::Disabled => PgSslMode::Disable,
304                SslMode::Preferred => PgSslMode::Prefer,
305                SslMode::Required => PgSslMode::Require,
306                SslMode::VerifyCa => PgSslMode::VerifyCa,
307                SslMode::VerifyFull => PgSslMode::VerifyFull,
308            });
309
310        if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
311            && let Some(root_cert) = ssl_root_cert
312        {
313            options = options.ssl_root_cert(root_cert.as_str());
314        }
315
316        let connection = PgPool::connect_with(options).await?;
317        let schema_discovery = SchemaDiscovery::new(connection, schema);
318        // fetch column schema and primary key
319        let empty_map = HashMap::new();
320        let table_schema = schema_discovery
321            .discover_table(
322                TableInfo {
323                    name: table.to_owned(),
324                    of_type: None,
325                },
326                &empty_map,
327            )
328            .await?;
329        Ok(table_schema)
330    }
331
332    pub async fn connect(
333        username: &str,
334        password: &str,
335        host: &str,
336        port: u16,
337        database: &str,
338        schema: &str,
339        table: &str,
340        ssl_mode: &SslMode,
341        ssl_root_cert: &Option<String>,
342        is_append_only: bool,
343    ) -> ConnectorResult<Self> {
344        tracing::debug!("connect to postgres external table");
345
346        let (columns, pk_names) = Self::discover_pk_and_full_columns(
347            username,
348            password,
349            host,
350            port,
351            database,
352            schema,
353            table,
354            ssl_mode,
355            ssl_root_cert,
356        )
357        .await?;
358
359        let mut column_descs = vec![];
360        for col in &columns {
361            let rw_data_type = sea_type_to_rw_type(&col.col_type)?;
362            let column_desc = if let Some(ref default_expr) = col.default {
363                // parse the value of "column_default" field in information_schema.columns,
364                // non number data type will be stored as "'value'::type"
365                let val_text = default_expr
366                    .0
367                    .split("::")
368                    .map(|s| s.trim_matches('\''))
369                    .next()
370                    .expect("default value expression");
371
372                match ScalarImpl::from_text(val_text, &rw_data_type) {
373                    Ok(scalar) => ColumnDesc::named_with_default_value(
374                        col.name.clone(),
375                        ColumnId::placeholder(),
376                        rw_data_type.clone(),
377                        Some(scalar),
378                    ),
379                    Err(err) => {
380                        tracing::warn!(error=%err.as_report(), "failed to parse postgres default value expression, only constant is supported");
381                        ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
382                    }
383                }
384            } else {
385                ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
386            };
387            column_descs.push(column_desc);
388        }
389
390        // Check primary key existence using the directly discovered pk_names
391        if !is_append_only && pk_names.is_empty() {
392            return Err(anyhow!(
393                "Postgres table should define the primary key for non-append-only tables"
394            )
395            .into());
396        }
397
398        Ok(Self {
399            column_descs,
400            pk_names,
401        })
402    }
403
404    // return the mapping from column name to pg type, the pg type is used for writing data to postgres
405    pub async fn type_mapping(
406        username: &str,
407        password: &str,
408        host: &str,
409        port: u16,
410        database: &str,
411        schema: &str,
412        table: &str,
413        ssl_mode: &SslMode,
414        ssl_root_cert: &Option<String>,
415        is_append_only: bool,
416    ) -> ConnectorResult<HashMap<String, tokio_postgres::types::Type>> {
417        tracing::debug!("connect to postgres external table to get type mapping");
418        let table_schema = Self::discover_schema(
419            username,
420            password,
421            host,
422            port,
423            database,
424            schema,
425            table,
426            ssl_mode,
427            ssl_root_cert,
428        )
429        .await?;
430        let mut column_name_to_pg_type = HashMap::new();
431        for col in &table_schema.columns {
432            let pg_type = sea_type_to_pg_type(&col.col_type)?;
433            column_name_to_pg_type.insert(col.name.clone(), pg_type);
434        }
435        if !is_append_only && table_schema.primary_key_constraints.is_empty() {
436            return Err(anyhow!(
437                "Postgres table should define the primary key for non-append-only tables"
438            )
439            .into());
440        }
441        Ok(column_name_to_pg_type)
442    }
443
444    pub fn column_descs(&self) -> &Vec<ColumnDesc> {
445        &self.column_descs
446    }
447
448    pub fn pk_names(&self) -> &Vec<String> {
449        &self.pk_names
450    }
451}
452
453impl fmt::Display for SslMode {
454    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
455        f.write_str(match self {
456            SslMode::Disabled => "disabled",
457            SslMode::Preferred => "preferred",
458            SslMode::Required => "required",
459            SslMode::VerifyCa => "verify-ca",
460            SslMode::VerifyFull => "verify-full",
461        })
462    }
463}
464
465impl std::str::FromStr for SslMode {
466    type Err = serde_json::Error;
467
468    fn from_str(s: &str) -> Result<Self, Self::Err> {
469        serde_json::from_value(serde_json::Value::String(s.to_owned()))
470    }
471}
472
473pub async fn create_pg_client(
474    user: &str,
475    password: &str,
476    host: &str,
477    port: &str,
478    database: &str,
479    ssl_mode: &SslMode,
480    ssl_root_cert: &Option<String>,
481    tcp_keepalive: Option<TcpKeepaliveConfig>,
482) -> anyhow::Result<PgClient> {
483    let mut pg_config = tokio_postgres::Config::new();
484    pg_config
485        .user(user)
486        .password(password)
487        .host(host)
488        .port(port.parse::<u16>().unwrap())
489        .dbname(database);
490
491    // Configure TCP keepalive if provided
492    if let Some(keepalive) = tcp_keepalive {
493        pg_config.keepalives(true);
494        pg_config.keepalives_idle(std::time::Duration::from_secs(
495            keepalive.tcp_keepalive_idle as u64,
496        ));
497        #[cfg(not(target_os = "windows"))]
498        {
499            pg_config.keepalives_interval(std::time::Duration::from_secs(
500                keepalive.tcp_keepalive_interval as u64,
501            ));
502            pg_config.keepalives_retries(keepalive.tcp_keepalive_count);
503        }
504        tracing::info!(
505            "TCP keepalive enabled: idle={}s, interval={}s, retries={}",
506            keepalive.tcp_keepalive_idle,
507            keepalive.tcp_keepalive_interval,
508            keepalive.tcp_keepalive_count
509        );
510    }
511
512    let (_verify_ca, verify_hostname) = match ssl_mode {
513        SslMode::VerifyCa => (true, false),
514        SslMode::VerifyFull => (true, true),
515        _ => (false, false),
516    };
517
518    #[cfg(not(madsim))]
519    let connector = match ssl_mode {
520        SslMode::Disabled => {
521            pg_config.ssl_mode(tokio_postgres::config::SslMode::Disable);
522            MaybeMakeTlsConnector::NoTls(NoTls)
523        }
524        SslMode::Preferred => {
525            pg_config.ssl_mode(tokio_postgres::config::SslMode::Prefer);
526            match SslConnector::builder(SslMethod::tls()) {
527                Ok(mut builder) => {
528                    // disable certificate verification for `prefer`
529                    builder.set_verify(SslVerifyMode::NONE);
530                    MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
531                }
532                Err(e) => {
533                    tracing::warn!(error = %e.as_report(), "SSL connector error");
534                    MaybeMakeTlsConnector::NoTls(NoTls)
535                }
536            }
537        }
538        SslMode::Required => {
539            pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
540            let mut builder = SslConnector::builder(SslMethod::tls())?;
541            // disable certificate verification for `require`
542            builder.set_verify(SslVerifyMode::NONE);
543            MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
544        }
545
546        SslMode::VerifyCa | SslMode::VerifyFull => {
547            pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
548            let mut builder = SslConnector::builder(SslMethod::tls())?;
549            if let Some(ssl_root_cert) = ssl_root_cert {
550                builder.set_ca_file(ssl_root_cert).map_err(|e| {
551                    anyhow!(format!("bad ssl root cert error: {}", e.to_report_string()))
552                })?;
553            }
554            let mut connector = MakeTlsConnector::new(builder.build());
555            if !verify_hostname {
556                connector.set_callback(|config, _| {
557                    config.set_verify_hostname(false);
558                    Ok(())
559                });
560            }
561            MaybeMakeTlsConnector::Tls(connector)
562        }
563    };
564    #[cfg(madsim)]
565    let connector = NoTls;
566
567    let (client, connection) = pg_config.connect(connector).await?;
568
569    tokio::spawn(async move {
570        if let Err(e) = connection.await {
571            tracing::error!(error = %e.as_report(), "postgres connection error");
572        }
573    });
574
575    Ok(client)
576}
577
578// Used for both source and sink connector
579pub fn sea_type_to_rw_type(col_type: &SeaType) -> ConnectorResult<DataType> {
580    let dtype = match col_type {
581        SeaType::SmallInt | SeaType::SmallSerial => DataType::Int16,
582        SeaType::Integer | SeaType::Serial => DataType::Int32,
583        SeaType::BigInt | SeaType::BigSerial => DataType::Int64,
584        SeaType::Money | SeaType::Decimal(_) | SeaType::Numeric(_) => DataType::Decimal,
585        SeaType::Real => DataType::Float32,
586        SeaType::DoublePrecision => DataType::Float64,
587        SeaType::Varchar(_) | SeaType::Char(_) | SeaType::Text => DataType::Varchar,
588        SeaType::Bytea => DataType::Bytea,
589        SeaType::Timestamp(_) => DataType::Timestamp,
590        SeaType::TimestampWithTimeZone(_) => DataType::Timestamptz,
591        SeaType::Date => DataType::Date,
592        SeaType::Time(_) | SeaType::TimeWithTimeZone(_) => DataType::Time,
593        SeaType::Interval(_) => DataType::Interval,
594        SeaType::Boolean => DataType::Boolean,
595        SeaType::Point => DataType::Struct(StructType::new(vec![
596            ("x", DataType::Float32),
597            ("y", DataType::Float32),
598        ])),
599        SeaType::Uuid => DataType::Varchar,
600        SeaType::Xml => DataType::Varchar,
601        SeaType::Json => DataType::Jsonb,
602        SeaType::JsonBinary => DataType::Jsonb,
603        SeaType::Array(def) => {
604            let item_type = match def.col_type.as_ref() {
605                Some(ty) => sea_type_to_rw_type(ty.as_ref())?,
606                None => {
607                    return Err(anyhow!("ARRAY type missing element type").into());
608                }
609            };
610
611            DataType::list(item_type)
612        }
613        SeaType::PgLsn => DataType::Int64,
614        SeaType::Cidr
615        | SeaType::Inet
616        | SeaType::MacAddr
617        | SeaType::MacAddr8
618        | SeaType::Int4Range
619        | SeaType::Int8Range
620        | SeaType::NumRange
621        | SeaType::TsRange
622        | SeaType::TsTzRange
623        | SeaType::DateRange
624        | SeaType::Enum(_) => DataType::Varchar,
625        SeaType::Line
626        | SeaType::Lseg
627        | SeaType::Box
628        | SeaType::Path
629        | SeaType::Polygon
630        | SeaType::Circle
631        | SeaType::Bit(_)
632        | SeaType::VarBit(_)
633        | SeaType::TsVector
634        | SeaType::TsQuery => {
635            bail!("{:?} type not supported", col_type);
636        }
637        SeaType::Unknown(name) => {
638            if let Some(dim) = parse_pgvector_dimension(name)? {
639                DataType::Vector(dim)
640            } else {
641                // NOTES: user-defined enum type is classified as `Unknown`
642                tracing::warn!("Unknown Postgres data type: {name}, map to varchar");
643                DataType::Varchar
644            }
645        }
646    };
647
648    Ok(dtype)
649}
650
651fn parse_pgvector_dimension(type_name: &str) -> ConnectorResult<Option<usize>> {
652    let normalized = type_name.trim().to_ascii_lowercase();
653    if normalized == "vector" {
654        bail!("pgvector type `vector` is missing dimension, expected `vector(n)`")
655    }
656    if !normalized.starts_with("vector(") || !normalized.ends_with(')') {
657        return Ok(None);
658    }
659
660    let dim_text = normalized
661        .trim_start_matches("vector(")
662        .trim_end_matches(')')
663        .trim();
664    let dim = dim_text
665        .parse::<usize>()
666        .map_err(|_| anyhow!("invalid pgvector dimension in type `{type_name}`"))?;
667
668    if !(1..=DataType::VEC_MAX_SIZE).contains(&dim) {
669        bail!(
670            "pgvector dimension out of range in type `{}`: expect 1..={}",
671            type_name,
672            DataType::VEC_MAX_SIZE
673        );
674    }
675
676    Ok(Some(dim))
677}
678
679// Used for sink connector
680// We use `sea-schema` for table schema discovery.
681// So we have to map `sea-schema` pg types
682// to `tokio-postgres` pg types (which we use for query binding).
683fn sea_type_to_pg_type(sea_type: &SeaType) -> ConnectorResult<tokio_postgres::types::Type> {
684    use tokio_postgres::types::Type as PgType;
685    match sea_type {
686        SeaType::SmallInt => Ok(PgType::INT2),
687        SeaType::Integer => Ok(PgType::INT4),
688        SeaType::BigInt => Ok(PgType::INT8),
689        SeaType::Decimal(_) => Ok(PgType::NUMERIC),
690        SeaType::Numeric(_) => Ok(PgType::NUMERIC),
691        SeaType::Real => Ok(PgType::FLOAT4),
692        SeaType::DoublePrecision => Ok(PgType::FLOAT8),
693        SeaType::Varchar(_) => Ok(PgType::VARCHAR),
694        SeaType::Char(_) => Ok(PgType::CHAR),
695        SeaType::Text => Ok(PgType::TEXT),
696        SeaType::Bytea => Ok(PgType::BYTEA),
697        SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP),
698        SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ),
699        SeaType::Date => Ok(PgType::DATE),
700        SeaType::Time(_) => Ok(PgType::TIME),
701        SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ),
702        SeaType::Interval(_) => Ok(PgType::INTERVAL),
703        SeaType::Boolean => Ok(PgType::BOOL),
704        SeaType::Point => Ok(PgType::POINT),
705        SeaType::Uuid => Ok(PgType::UUID),
706        SeaType::Json => Ok(PgType::JSON),
707        SeaType::JsonBinary => Ok(PgType::JSONB),
708        SeaType::Array(t) => {
709            let Some(t) = t.col_type.as_ref() else {
710                bail!("missing array type")
711            };
712            match t.as_ref() {
713                // RW only supports 1 level of nesting.
714                SeaType::SmallInt => Ok(PgType::INT2_ARRAY),
715                SeaType::Integer => Ok(PgType::INT4_ARRAY),
716                SeaType::BigInt => Ok(PgType::INT8_ARRAY),
717                SeaType::Decimal(_) => Ok(PgType::NUMERIC_ARRAY),
718                SeaType::Numeric(_) => Ok(PgType::NUMERIC_ARRAY),
719                SeaType::Real => Ok(PgType::FLOAT4_ARRAY),
720                SeaType::DoublePrecision => Ok(PgType::FLOAT8_ARRAY),
721                SeaType::Varchar(_) => Ok(PgType::VARCHAR_ARRAY),
722                SeaType::Char(_) => Ok(PgType::CHAR_ARRAY),
723                SeaType::Text => Ok(PgType::TEXT_ARRAY),
724                SeaType::Bytea => Ok(PgType::BYTEA_ARRAY),
725                SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP_ARRAY),
726                SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ_ARRAY),
727                SeaType::Date => Ok(PgType::DATE_ARRAY),
728                SeaType::Time(_) => Ok(PgType::TIME_ARRAY),
729                SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ_ARRAY),
730                SeaType::Interval(_) => Ok(PgType::INTERVAL_ARRAY),
731                SeaType::Boolean => Ok(PgType::BOOL_ARRAY),
732                SeaType::Point => Ok(PgType::POINT_ARRAY),
733                SeaType::Uuid => Ok(PgType::UUID_ARRAY),
734                SeaType::Json => Ok(PgType::JSON_ARRAY),
735                SeaType::JsonBinary => Ok(PgType::JSONB_ARRAY),
736                SeaType::Array(_) => bail!("nested array type is not supported"),
737                SeaType::Unknown(name) => {
738                    // Treat as enum type
739                    Ok(PgType::new(
740                        name.clone(),
741                        0,
742                        PgKind::Array(PgType::new(
743                            name.clone(),
744                            0,
745                            PgKind::Enum(vec![]),
746                            "".into(),
747                        )),
748                        "".into(),
749                    ))
750                }
751                _ => bail!("unsupported array type: {:?}", t),
752            }
753        }
754        SeaType::Unknown(name) => {
755            // Treat as enum type
756            Ok(PgType::new(
757                name.clone(),
758                0,
759                PgKind::Enum(vec![]),
760                "".into(),
761            ))
762        }
763        _ => bail!("unsupported type: {:?}", sea_type),
764    }
765}
766
#[cfg(test)]
mod tests {
    use super::parse_pgvector_dimension;

    /// Well-formed names yield their dimension; non-vector types are not recognized.
    #[test]
    fn test_parse_pgvector_dimension() {
        let cases = [
            ("vector(3)", Some(3)),
            ("VECTOR(768)", Some(768)),
            ("varchar", None),
        ];
        for (type_name, expected) in cases {
            let parsed = parse_pgvector_dimension(type_name).unwrap();
            assert_eq!(parsed, expected, "type name: {type_name}");
        }
    }

    /// A bare `vector` without a dimension is rejected.
    #[test]
    fn test_parse_pgvector_dimension_requires_size() {
        let message = parse_pgvector_dimension("vector").unwrap_err().to_string();
        assert!(
            message.contains("missing dimension"),
            "unexpected error: {message}"
        );
    }
}