1use std::collections::{BTreeMap, HashMap};
16use std::fmt;
17
18use anyhow::{Context, anyhow};
19use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
20use postgres_openssl::MakeTlsConnector;
21use risingwave_common::bail;
22use risingwave_common::catalog::{ColumnDesc, ColumnId};
23use risingwave_common::types::{DataType, ScalarImpl, StructType};
24use sea_schema::postgres::def::{ColumnType as SeaType, TableDef, TableInfo};
25use sea_schema::postgres::discovery::SchemaDiscovery;
26use sea_schema::sea_query::{Alias, IntoIden};
27use serde::Deserialize;
28use sqlx::postgres::{PgConnectOptions, PgSslMode};
29use sqlx::{PgPool, Row};
30use thiserror_ext::AsReport;
31use tokio_postgres::types::Kind as PgKind;
32use tokio_postgres::{Client as PgClient, NoTls};
33
34#[cfg(not(madsim))]
35use super::maybe_tls_connector::MaybeMakeTlsConnector;
36use crate::error::ConnectorResult;
37use crate::sink::postgres::TcpKeepaliveConfig;
38
/// Lists the primary-key column names of table `$2` in schema `$1`.
///
/// Each row has a single `column_name` column. Rows are ordered by the column's position in
/// the primary-key index (`pg_index.indkey`), i.e. key order, not table column order.
/// Returns no rows when the table has no primary key (or does not exist).
const DISCOVER_PRIMARY_KEY_QUERY: &str = r#"
 SELECT a.attname as column_name
 FROM pg_index i
 JOIN pg_class c ON c.oid = i.indrelid
 JOIN pg_namespace n ON n.oid = c.relnamespace
 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
 WHERE n.nspname = $1
 AND c.relname = $2
 AND i.indisprimary = true
 ORDER BY array_position(i.indkey, a.attnum)
"#;
55
/// Lists the pgvector columns of table `$2` in schema `$1`.
///
/// Only user columns (`attnum > 0`) that are not dropped and whose type is named `vector`
/// are returned, in column order. Each row carries:
/// - `column_name`: the column name,
/// - `atttypmod`: the raw type modifier (read as the vector dimension when positive by
///   `discover_pgvector_dimensions`),
/// - `formatted_type`: the type rendered by `format_type`, which `parse_pgvector_dimension`
///   expects to look like `vector(n)`.
const DISCOVER_PGVECTOR_COLUMNS_QUERY: &str = r#"
 SELECT
 a.attname as column_name,
 a.atttypmod as atttypmod,
 format_type(a.atttypid, a.atttypmod) as formatted_type
 FROM pg_attribute a
 JOIN pg_class c ON c.oid = a.attrelid
 JOIN pg_namespace n ON n.oid = c.relnamespace
 JOIN pg_type t ON t.oid = a.atttypid
 WHERE n.nspname = $1
 AND c.relname = $2
 AND t.typname = 'vector'
 AND a.attnum > 0
 AND NOT a.attisdropped
 ORDER BY a.attnum
"#;
75
/// Connection parameters for an upstream Postgres database, as read from postgres-cdc
/// connector properties by `pg_connection_config_from_properties`.
pub struct PgConnectionConfig {
    /// Server host name (property `hostname`).
    pub host: String,
    /// Server port, kept as the raw property string (property `port`); it is parsed into a
    /// `u16` when a client is created.
    pub port: String,
    /// Login user (property `username`).
    pub user: String,
    /// Login password (property `password`); empty when the property is absent.
    pub password: String,
    /// Database to connect to (property `database.name`).
    pub database: String,
    /// TLS mode (property `ssl.mode`).
    pub ssl_mode: SslMode,
    /// Optional path to a root CA certificate (property `ssl.root.cert`).
    pub ssl_root_cert: Option<String>,
}
85
86pub fn pg_connection_config_from_properties(
87 props: &BTreeMap<String, String>,
88) -> ConnectorResult<PgConnectionConfig> {
89 Ok(PgConnectionConfig {
90 host: props
91 .get("hostname")
92 .context("missing `hostname` in postgres-cdc properties")?
93 .clone(),
94 port: props
95 .get("port")
96 .context("missing `port` in postgres-cdc properties")?
97 .clone(),
98 user: props
99 .get("username")
100 .context("missing `username` in postgres-cdc properties")?
101 .clone(),
102 password: props.get("password").cloned().unwrap_or_default(),
103 database: props
104 .get("database.name")
105 .context("missing `database.name` in postgres-cdc properties")?
106 .clone(),
107 ssl_mode: props
108 .get("ssl.mode")
109 .and_then(|v| v.parse::<SslMode>().ok())
110 .unwrap_or_default(),
111 ssl_root_cert: props.get("ssl.root.cert").cloned(),
112 })
113}
114
115pub async fn create_pg_client_from_properties(
116 props: &BTreeMap<String, String>,
117 tcp_keepalive: Option<TcpKeepaliveConfig>,
118) -> ConnectorResult<PgClient> {
119 let config = pg_connection_config_from_properties(props)?;
120 create_pg_client(
121 &config.user,
122 &config.password,
123 &config.host,
124 &config.port,
125 &config.database,
126 &config.ssl_mode,
127 &config.ssl_root_cert,
128 tcp_keepalive,
129 )
130 .await
131 .map_err(Into::into)
132}
133
134pub async fn discover_pgvector_dimensions(
135 client: &PgClient,
136 schema: &str,
137 table: &str,
138) -> ConnectorResult<HashMap<String, usize>> {
139 let rows = client
140 .query(DISCOVER_PGVECTOR_COLUMNS_QUERY, &[&schema, &table])
141 .await?;
142
143 let mut dims = HashMap::new();
144 for row in rows {
145 let col_name: String = row.get("column_name");
146 let atttypmod: i32 = row.get("atttypmod");
147 if atttypmod > 0
148 && let Ok(dim) = usize::try_from(atttypmod)
149 {
150 dims.insert(col_name, dim);
151 }
152 }
153 Ok(dims)
154}
155
/// TLS mode used when connecting to the upstream Postgres database.
///
/// Accepted spellings, via `Deserialize` and therefore also `FromStr`, are the lowercase
/// variant name (`disabled`, `preferred`, `required`, `verifyca`, `verifyfull`) or the
/// alias listed on each variant (`disable`, `prefer`, `require`, `verify-ca`,
/// `verify-full`). Parsing is case-sensitive.
#[derive(Debug, Clone, PartialEq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum SslMode {
    /// No TLS.
    #[serde(alias = "disable")]
    Disabled,
    /// Use TLS when possible, without verifying the server certificate. This is the default.
    #[serde(alias = "prefer")]
    #[default]
    Preferred,
    /// Require TLS, without verifying the server certificate.
    #[serde(alias = "require")]
    Required,
    /// Require TLS and verify the server certificate chain, but not the host name.
    #[serde(alias = "verify-ca")]
    VerifyCa,
    /// Require TLS and verify both the certificate chain and the server host name.
    #[serde(alias = "verify-full")]
    VerifyFull,
}
175
/// Column and primary-key metadata of an upstream Postgres table, as discovered by
/// `PostgresExternalTable::connect`.
pub struct PostgresExternalTable {
    /// RisingWave column descriptors, one per upstream column, in discovery order.
    column_descs: Vec<ColumnDesc>,
    /// Primary-key column names, in primary-key index order; may be empty for
    /// append-only tables.
    pk_names: Vec<String>,
}
180
181impl PostgresExternalTable {
182 async fn discover_primary_key(
185 connection: &PgPool,
186 schema_name: &str,
187 table_name: &str,
188 ) -> ConnectorResult<Vec<String>> {
189 let rows = sqlx::query(DISCOVER_PRIMARY_KEY_QUERY)
190 .bind(schema_name)
191 .bind(table_name)
192 .fetch_all(connection)
193 .await
194 .context("Failed to discover primary key columns")?;
195
196 let pk_columns = rows
197 .into_iter()
198 .map(|row| row.get::<String, _>("column_name"))
199 .collect();
200
201 Ok(pk_columns)
202 }
203
204 async fn discover_pk_and_full_columns(
208 username: &str,
209 password: &str,
210 host: &str,
211 port: u16,
212 database: &str,
213 schema: &str,
214 table: &str,
215 ssl_mode: &SslMode,
216 ssl_root_cert: &Option<String>,
217 ) -> ConnectorResult<(Vec<sea_schema::postgres::def::ColumnInfo>, Vec<String>)> {
218 let mut options = PgConnectOptions::new()
219 .username(username)
220 .password(password)
221 .host(host)
222 .port(port)
223 .database(database)
224 .ssl_mode(match ssl_mode {
225 SslMode::Disabled => PgSslMode::Disable,
226 SslMode::Preferred => PgSslMode::Prefer,
227 SslMode::Required => PgSslMode::Require,
228 SslMode::VerifyCa => PgSslMode::VerifyCa,
229 SslMode::VerifyFull => PgSslMode::VerifyFull,
230 });
231
232 if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
233 && let Some(root_cert) = ssl_root_cert
234 {
235 options = options.ssl_root_cert(root_cert.as_str());
236 }
237
238 let connection = PgPool::connect_with(options).await?;
239
240 let schema_discovery = SchemaDiscovery::new(connection.clone(), schema);
242 let empty_map: HashMap<String, Vec<String>> = HashMap::new();
243 let columns = schema_discovery
244 .discover_columns(
245 Alias::new(schema).into_iden(),
246 Alias::new(table).into_iden(),
247 &empty_map,
248 )
249 .await?;
250
251 let pgvector_columns = sqlx::query(DISCOVER_PGVECTOR_COLUMNS_QUERY)
252 .bind(schema)
253 .bind(table)
254 .fetch_all(&connection)
255 .await
256 .context("Failed to discover PostgreSQL pgvector columns")?;
257 let formatted_type_by_column: HashMap<String, String> = pgvector_columns
258 .into_iter()
259 .map(|row| {
260 (
261 row.get::<String, _>("column_name"),
262 row.get::<String, _>("formatted_type"),
263 )
264 })
265 .collect();
266
267 let mut columns = columns;
270 for col in &mut columns {
271 if let SeaType::Unknown(name) = &col.col_type
272 && name.eq_ignore_ascii_case("vector")
273 && let Some(formatted_type) = formatted_type_by_column.get(&col.name)
274 {
275 col.col_type = SeaType::Unknown(formatted_type.clone());
276 }
277 }
278
279 let pk_columns = Self::discover_primary_key(&connection, schema, table).await?;
281
282 Ok((columns, pk_columns))
283 }
284
285 async fn discover_schema(
286 username: &str,
287 password: &str,
288 host: &str,
289 port: u16,
290 database: &str,
291 schema: &str,
292 table: &str,
293 ssl_mode: &SslMode,
294 ssl_root_cert: &Option<String>,
295 ) -> ConnectorResult<TableDef> {
296 let mut options = PgConnectOptions::new()
297 .username(username)
298 .password(password)
299 .host(host)
300 .port(port)
301 .database(database)
302 .ssl_mode(match ssl_mode {
303 SslMode::Disabled => PgSslMode::Disable,
304 SslMode::Preferred => PgSslMode::Prefer,
305 SslMode::Required => PgSslMode::Require,
306 SslMode::VerifyCa => PgSslMode::VerifyCa,
307 SslMode::VerifyFull => PgSslMode::VerifyFull,
308 });
309
310 if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
311 && let Some(root_cert) = ssl_root_cert
312 {
313 options = options.ssl_root_cert(root_cert.as_str());
314 }
315
316 let connection = PgPool::connect_with(options).await?;
317 let schema_discovery = SchemaDiscovery::new(connection, schema);
318 let empty_map = HashMap::new();
320 let table_schema = schema_discovery
321 .discover_table(
322 TableInfo {
323 name: table.to_owned(),
324 of_type: None,
325 },
326 &empty_map,
327 )
328 .await?;
329 Ok(table_schema)
330 }
331
332 pub async fn connect(
333 username: &str,
334 password: &str,
335 host: &str,
336 port: u16,
337 database: &str,
338 schema: &str,
339 table: &str,
340 ssl_mode: &SslMode,
341 ssl_root_cert: &Option<String>,
342 is_append_only: bool,
343 ) -> ConnectorResult<Self> {
344 tracing::debug!("connect to postgres external table");
345
346 let (columns, pk_names) = Self::discover_pk_and_full_columns(
347 username,
348 password,
349 host,
350 port,
351 database,
352 schema,
353 table,
354 ssl_mode,
355 ssl_root_cert,
356 )
357 .await?;
358
359 let mut column_descs = vec![];
360 for col in &columns {
361 let rw_data_type = sea_type_to_rw_type(&col.col_type)?;
362 let column_desc = if let Some(ref default_expr) = col.default {
363 let val_text = default_expr
366 .0
367 .split("::")
368 .map(|s| s.trim_matches('\''))
369 .next()
370 .expect("default value expression");
371
372 match ScalarImpl::from_text(val_text, &rw_data_type) {
373 Ok(scalar) => ColumnDesc::named_with_default_value(
374 col.name.clone(),
375 ColumnId::placeholder(),
376 rw_data_type.clone(),
377 Some(scalar),
378 ),
379 Err(err) => {
380 tracing::warn!(error=%err.as_report(), "failed to parse postgres default value expression, only constant is supported");
381 ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
382 }
383 }
384 } else {
385 ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
386 };
387 column_descs.push(column_desc);
388 }
389
390 if !is_append_only && pk_names.is_empty() {
392 return Err(anyhow!(
393 "Postgres table should define the primary key for non-append-only tables"
394 )
395 .into());
396 }
397
398 Ok(Self {
399 column_descs,
400 pk_names,
401 })
402 }
403
404 pub async fn type_mapping(
406 username: &str,
407 password: &str,
408 host: &str,
409 port: u16,
410 database: &str,
411 schema: &str,
412 table: &str,
413 ssl_mode: &SslMode,
414 ssl_root_cert: &Option<String>,
415 is_append_only: bool,
416 ) -> ConnectorResult<HashMap<String, tokio_postgres::types::Type>> {
417 tracing::debug!("connect to postgres external table to get type mapping");
418 let table_schema = Self::discover_schema(
419 username,
420 password,
421 host,
422 port,
423 database,
424 schema,
425 table,
426 ssl_mode,
427 ssl_root_cert,
428 )
429 .await?;
430 let mut column_name_to_pg_type = HashMap::new();
431 for col in &table_schema.columns {
432 let pg_type = sea_type_to_pg_type(&col.col_type)?;
433 column_name_to_pg_type.insert(col.name.clone(), pg_type);
434 }
435 if !is_append_only && table_schema.primary_key_constraints.is_empty() {
436 return Err(anyhow!(
437 "Postgres table should define the primary key for non-append-only tables"
438 )
439 .into());
440 }
441 Ok(column_name_to_pg_type)
442 }
443
444 pub fn column_descs(&self) -> &Vec<ColumnDesc> {
445 &self.column_descs
446 }
447
448 pub fn pk_names(&self) -> &Vec<String> {
449 &self.pk_names
450 }
451}
452
453impl fmt::Display for SslMode {
454 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
455 f.write_str(match self {
456 SslMode::Disabled => "disabled",
457 SslMode::Preferred => "preferred",
458 SslMode::Required => "required",
459 SslMode::VerifyCa => "verify-ca",
460 SslMode::VerifyFull => "verify-full",
461 })
462 }
463}
464
465impl std::str::FromStr for SslMode {
466 type Err = serde_json::Error;
467
468 fn from_str(s: &str) -> Result<Self, Self::Err> {
469 serde_json::from_value(serde_json::Value::String(s.to_owned()))
470 }
471}
472
473pub async fn create_pg_client(
474 user: &str,
475 password: &str,
476 host: &str,
477 port: &str,
478 database: &str,
479 ssl_mode: &SslMode,
480 ssl_root_cert: &Option<String>,
481 tcp_keepalive: Option<TcpKeepaliveConfig>,
482) -> anyhow::Result<PgClient> {
483 let mut pg_config = tokio_postgres::Config::new();
484 pg_config
485 .user(user)
486 .password(password)
487 .host(host)
488 .port(port.parse::<u16>().unwrap())
489 .dbname(database);
490
491 if let Some(keepalive) = tcp_keepalive {
493 pg_config.keepalives(true);
494 pg_config.keepalives_idle(std::time::Duration::from_secs(
495 keepalive.tcp_keepalive_idle as u64,
496 ));
497 #[cfg(not(target_os = "windows"))]
498 {
499 pg_config.keepalives_interval(std::time::Duration::from_secs(
500 keepalive.tcp_keepalive_interval as u64,
501 ));
502 pg_config.keepalives_retries(keepalive.tcp_keepalive_count);
503 }
504 tracing::info!(
505 "TCP keepalive enabled: idle={}s, interval={}s, retries={}",
506 keepalive.tcp_keepalive_idle,
507 keepalive.tcp_keepalive_interval,
508 keepalive.tcp_keepalive_count
509 );
510 }
511
512 let (_verify_ca, verify_hostname) = match ssl_mode {
513 SslMode::VerifyCa => (true, false),
514 SslMode::VerifyFull => (true, true),
515 _ => (false, false),
516 };
517
518 #[cfg(not(madsim))]
519 let connector = match ssl_mode {
520 SslMode::Disabled => {
521 pg_config.ssl_mode(tokio_postgres::config::SslMode::Disable);
522 MaybeMakeTlsConnector::NoTls(NoTls)
523 }
524 SslMode::Preferred => {
525 pg_config.ssl_mode(tokio_postgres::config::SslMode::Prefer);
526 match SslConnector::builder(SslMethod::tls()) {
527 Ok(mut builder) => {
528 builder.set_verify(SslVerifyMode::NONE);
530 MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
531 }
532 Err(e) => {
533 tracing::warn!(error = %e.as_report(), "SSL connector error");
534 MaybeMakeTlsConnector::NoTls(NoTls)
535 }
536 }
537 }
538 SslMode::Required => {
539 pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
540 let mut builder = SslConnector::builder(SslMethod::tls())?;
541 builder.set_verify(SslVerifyMode::NONE);
543 MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
544 }
545
546 SslMode::VerifyCa | SslMode::VerifyFull => {
547 pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
548 let mut builder = SslConnector::builder(SslMethod::tls())?;
549 if let Some(ssl_root_cert) = ssl_root_cert {
550 builder.set_ca_file(ssl_root_cert).map_err(|e| {
551 anyhow!(format!("bad ssl root cert error: {}", e.to_report_string()))
552 })?;
553 }
554 let mut connector = MakeTlsConnector::new(builder.build());
555 if !verify_hostname {
556 connector.set_callback(|config, _| {
557 config.set_verify_hostname(false);
558 Ok(())
559 });
560 }
561 MaybeMakeTlsConnector::Tls(connector)
562 }
563 };
564 #[cfg(madsim)]
565 let connector = NoTls;
566
567 let (client, connection) = pg_config.connect(connector).await?;
568
569 tokio::spawn(async move {
570 if let Err(e) = connection.await {
571 tracing::error!(error = %e.as_report(), "postgres connection error");
572 }
573 });
574
575 Ok(client)
576}
577
578pub fn sea_type_to_rw_type(col_type: &SeaType) -> ConnectorResult<DataType> {
580 let dtype = match col_type {
581 SeaType::SmallInt | SeaType::SmallSerial => DataType::Int16,
582 SeaType::Integer | SeaType::Serial => DataType::Int32,
583 SeaType::BigInt | SeaType::BigSerial => DataType::Int64,
584 SeaType::Money | SeaType::Decimal(_) | SeaType::Numeric(_) => DataType::Decimal,
585 SeaType::Real => DataType::Float32,
586 SeaType::DoublePrecision => DataType::Float64,
587 SeaType::Varchar(_) | SeaType::Char(_) | SeaType::Text => DataType::Varchar,
588 SeaType::Bytea => DataType::Bytea,
589 SeaType::Timestamp(_) => DataType::Timestamp,
590 SeaType::TimestampWithTimeZone(_) => DataType::Timestamptz,
591 SeaType::Date => DataType::Date,
592 SeaType::Time(_) | SeaType::TimeWithTimeZone(_) => DataType::Time,
593 SeaType::Interval(_) => DataType::Interval,
594 SeaType::Boolean => DataType::Boolean,
595 SeaType::Point => DataType::Struct(StructType::new(vec![
596 ("x", DataType::Float32),
597 ("y", DataType::Float32),
598 ])),
599 SeaType::Uuid => DataType::Varchar,
600 SeaType::Xml => DataType::Varchar,
601 SeaType::Json => DataType::Jsonb,
602 SeaType::JsonBinary => DataType::Jsonb,
603 SeaType::Array(def) => {
604 let item_type = match def.col_type.as_ref() {
605 Some(ty) => sea_type_to_rw_type(ty.as_ref())?,
606 None => {
607 return Err(anyhow!("ARRAY type missing element type").into());
608 }
609 };
610
611 DataType::list(item_type)
612 }
613 SeaType::PgLsn => DataType::Int64,
614 SeaType::Cidr
615 | SeaType::Inet
616 | SeaType::MacAddr
617 | SeaType::MacAddr8
618 | SeaType::Int4Range
619 | SeaType::Int8Range
620 | SeaType::NumRange
621 | SeaType::TsRange
622 | SeaType::TsTzRange
623 | SeaType::DateRange
624 | SeaType::Enum(_) => DataType::Varchar,
625 SeaType::Line
626 | SeaType::Lseg
627 | SeaType::Box
628 | SeaType::Path
629 | SeaType::Polygon
630 | SeaType::Circle
631 | SeaType::Bit(_)
632 | SeaType::VarBit(_)
633 | SeaType::TsVector
634 | SeaType::TsQuery => {
635 bail!("{:?} type not supported", col_type);
636 }
637 SeaType::Unknown(name) => {
638 if let Some(dim) = parse_pgvector_dimension(name)? {
639 DataType::Vector(dim)
640 } else {
641 tracing::warn!("Unknown Postgres data type: {name}, map to varchar");
643 DataType::Varchar
644 }
645 }
646 };
647
648 Ok(dtype)
649}
650
651fn parse_pgvector_dimension(type_name: &str) -> ConnectorResult<Option<usize>> {
652 let normalized = type_name.trim().to_ascii_lowercase();
653 if normalized == "vector" {
654 bail!("pgvector type `vector` is missing dimension, expected `vector(n)`")
655 }
656 if !normalized.starts_with("vector(") || !normalized.ends_with(')') {
657 return Ok(None);
658 }
659
660 let dim_text = normalized
661 .trim_start_matches("vector(")
662 .trim_end_matches(')')
663 .trim();
664 let dim = dim_text
665 .parse::<usize>()
666 .map_err(|_| anyhow!("invalid pgvector dimension in type `{type_name}`"))?;
667
668 if !(1..=DataType::VEC_MAX_SIZE).contains(&dim) {
669 bail!(
670 "pgvector dimension out of range in type `{}`: expect 1..={}",
671 type_name,
672 DataType::VEC_MAX_SIZE
673 );
674 }
675
676 Ok(Some(dim))
677}
678
679fn sea_type_to_pg_type(sea_type: &SeaType) -> ConnectorResult<tokio_postgres::types::Type> {
684 use tokio_postgres::types::Type as PgType;
685 match sea_type {
686 SeaType::SmallInt => Ok(PgType::INT2),
687 SeaType::Integer => Ok(PgType::INT4),
688 SeaType::BigInt => Ok(PgType::INT8),
689 SeaType::Decimal(_) => Ok(PgType::NUMERIC),
690 SeaType::Numeric(_) => Ok(PgType::NUMERIC),
691 SeaType::Real => Ok(PgType::FLOAT4),
692 SeaType::DoublePrecision => Ok(PgType::FLOAT8),
693 SeaType::Varchar(_) => Ok(PgType::VARCHAR),
694 SeaType::Char(_) => Ok(PgType::CHAR),
695 SeaType::Text => Ok(PgType::TEXT),
696 SeaType::Bytea => Ok(PgType::BYTEA),
697 SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP),
698 SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ),
699 SeaType::Date => Ok(PgType::DATE),
700 SeaType::Time(_) => Ok(PgType::TIME),
701 SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ),
702 SeaType::Interval(_) => Ok(PgType::INTERVAL),
703 SeaType::Boolean => Ok(PgType::BOOL),
704 SeaType::Point => Ok(PgType::POINT),
705 SeaType::Uuid => Ok(PgType::UUID),
706 SeaType::Json => Ok(PgType::JSON),
707 SeaType::JsonBinary => Ok(PgType::JSONB),
708 SeaType::Array(t) => {
709 let Some(t) = t.col_type.as_ref() else {
710 bail!("missing array type")
711 };
712 match t.as_ref() {
713 SeaType::SmallInt => Ok(PgType::INT2_ARRAY),
715 SeaType::Integer => Ok(PgType::INT4_ARRAY),
716 SeaType::BigInt => Ok(PgType::INT8_ARRAY),
717 SeaType::Decimal(_) => Ok(PgType::NUMERIC_ARRAY),
718 SeaType::Numeric(_) => Ok(PgType::NUMERIC_ARRAY),
719 SeaType::Real => Ok(PgType::FLOAT4_ARRAY),
720 SeaType::DoublePrecision => Ok(PgType::FLOAT8_ARRAY),
721 SeaType::Varchar(_) => Ok(PgType::VARCHAR_ARRAY),
722 SeaType::Char(_) => Ok(PgType::CHAR_ARRAY),
723 SeaType::Text => Ok(PgType::TEXT_ARRAY),
724 SeaType::Bytea => Ok(PgType::BYTEA_ARRAY),
725 SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP_ARRAY),
726 SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ_ARRAY),
727 SeaType::Date => Ok(PgType::DATE_ARRAY),
728 SeaType::Time(_) => Ok(PgType::TIME_ARRAY),
729 SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ_ARRAY),
730 SeaType::Interval(_) => Ok(PgType::INTERVAL_ARRAY),
731 SeaType::Boolean => Ok(PgType::BOOL_ARRAY),
732 SeaType::Point => Ok(PgType::POINT_ARRAY),
733 SeaType::Uuid => Ok(PgType::UUID_ARRAY),
734 SeaType::Json => Ok(PgType::JSON_ARRAY),
735 SeaType::JsonBinary => Ok(PgType::JSONB_ARRAY),
736 SeaType::Array(_) => bail!("nested array type is not supported"),
737 SeaType::Unknown(name) => {
738 Ok(PgType::new(
740 name.clone(),
741 0,
742 PgKind::Array(PgType::new(
743 name.clone(),
744 0,
745 PgKind::Enum(vec![]),
746 "".into(),
747 )),
748 "".into(),
749 ))
750 }
751 _ => bail!("unsupported array type: {:?}", t),
752 }
753 }
754 SeaType::Unknown(name) => {
755 Ok(PgType::new(
757 name.clone(),
758 0,
759 PgKind::Enum(vec![]),
760 "".into(),
761 ))
762 }
763 _ => bail!("unsupported type: {:?}", sea_type),
764 }
765}
766
#[cfg(test)]
mod tests {
    use super::parse_pgvector_dimension;

    /// Well-formed pgvector names yield their dimension; other type names yield `None`.
    #[test]
    fn test_parse_pgvector_dimension() {
        let cases = [
            ("vector(3)", Some(3)),
            ("VECTOR(768)", Some(768)),
            ("varchar", None),
        ];
        for (type_name, expected) in cases {
            assert_eq!(
                parse_pgvector_dimension(type_name).unwrap(),
                expected,
                "type name: {type_name}"
            );
        }
    }

    /// A bare `vector` without a dimension is rejected with a descriptive error.
    #[test]
    fn test_parse_pgvector_dimension_requires_size() {
        let message = parse_pgvector_dimension("vector").unwrap_err().to_string();
        assert!(
            message.contains("missing dimension"),
            "unexpected error: {message}"
        );
    }
}