1use std::collections::{BTreeMap, HashMap};
16use std::fmt;
17
18use anyhow::{Context, anyhow};
19use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
20use postgres_openssl::MakeTlsConnector;
21use risingwave_common::bail;
22use risingwave_common::catalog::{ColumnDesc, ColumnId};
23use risingwave_common::types::{DataType, ScalarImpl, StructType};
24use sea_schema::postgres::def::{ColumnType as SeaType, TableDef, TableInfo};
25use sea_schema::postgres::discovery::SchemaDiscovery;
26use sea_schema::sea_query::{Alias, IntoIden};
27use serde::Deserialize;
28use sqlx::postgres::{PgConnectOptions, PgSslMode};
29use sqlx::{PgPool, Row};
30use thiserror_ext::AsReport;
31use tokio_postgres::types::Kind as PgKind;
32use tokio_postgres::{Client as PgClient, NoTls};
33
34#[cfg(not(madsim))]
35use super::maybe_tls_connector::MaybeMakeTlsConnector;
36use crate::error::ConnectorResult;
37use crate::sink::postgres::TcpKeepaliveConfig;
38
/// Lists the primary-key column names of table `$2` in schema `$1`, ordered by their
/// position in the key.
///
/// The schema and table are raw (unquoted) names, matching how the other discovery queries
/// compare them. They are passed through `quote_ident` before the `regclass` cast so that
/// mixed-case or otherwise special identifiers (e.g. a table created as `"MyTable"`) resolve
/// to the right relation instead of being case-folded.
const DISCOVER_PRIMARY_KEY_QUERY: &str = r#"
    SELECT a.attname as column_name
    FROM pg_index i
    JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
    WHERE i.indrelid = (quote_ident($1) || '.' || quote_ident($2))::regclass
    AND i.indisprimary = true
    ORDER BY array_position(i.indkey, a.attnum)
"#;
49
/// Lists the pgvector columns (columns whose type is named `vector`) of table `$2` in
/// schema `$1`.
///
/// Returns one row per live user column (`attnum > 0`, not dropped), in attribute order:
/// - `column_name`: the column name;
/// - `atttypmod`: the raw type modifier. Callers treat a positive value as the declared
///   vector dimension;
/// - `formatted_type`: the `format_type()` rendering of the column type, which is later
///   parsed for a `vector(n)` dimension.
const DISCOVER_PGVECTOR_COLUMNS_QUERY: &str = r#"
    SELECT
        a.attname as column_name,
        a.atttypmod as atttypmod,
        format_type(a.atttypid, a.atttypmod) as formatted_type
    FROM pg_attribute a
    JOIN pg_class c ON c.oid = a.attrelid
    JOIN pg_namespace n ON n.oid = c.relnamespace
    JOIN pg_type t ON t.oid = a.atttypid
    WHERE n.nspname = $1
        AND c.relname = $2
        AND t.typname = 'vector'
        AND a.attnum > 0
        AND NOT a.attisdropped
    ORDER BY a.attnum
"#;
69
/// Connection settings for an upstream Postgres server, as extracted from `postgres-cdc`
/// source properties by `pg_connection_config_from_properties`.
pub struct PgConnectionConfig {
    /// Server host name (`hostname` property).
    pub host: String,
    /// Server port as the raw property string; it is parsed to a `u16` when connecting.
    pub port: String,
    /// Login role (`username` property).
    pub user: String,
    /// Login password (`password` property, empty when absent).
    pub password: String,
    /// Database to connect to (`database.name` property).
    pub database: String,
    /// TLS mode (`ssl.mode` property).
    pub ssl_mode: SslMode,
    /// Optional path of a CA certificate file (`ssl.root.cert` property).
    pub ssl_root_cert: Option<String>,
}
79
80pub fn pg_connection_config_from_properties(
81 props: &BTreeMap<String, String>,
82) -> ConnectorResult<PgConnectionConfig> {
83 Ok(PgConnectionConfig {
84 host: props
85 .get("hostname")
86 .context("missing `hostname` in postgres-cdc properties")?
87 .clone(),
88 port: props
89 .get("port")
90 .context("missing `port` in postgres-cdc properties")?
91 .clone(),
92 user: props
93 .get("username")
94 .context("missing `username` in postgres-cdc properties")?
95 .clone(),
96 password: props.get("password").cloned().unwrap_or_default(),
97 database: props
98 .get("database.name")
99 .context("missing `database.name` in postgres-cdc properties")?
100 .clone(),
101 ssl_mode: props
102 .get("ssl.mode")
103 .and_then(|v| v.parse::<SslMode>().ok())
104 .unwrap_or_default(),
105 ssl_root_cert: props.get("ssl.root.cert").cloned(),
106 })
107}
108
109pub async fn create_pg_client_from_properties(
110 props: &BTreeMap<String, String>,
111 tcp_keepalive: Option<TcpKeepaliveConfig>,
112) -> ConnectorResult<PgClient> {
113 let config = pg_connection_config_from_properties(props)?;
114 create_pg_client(
115 &config.user,
116 &config.password,
117 &config.host,
118 &config.port,
119 &config.database,
120 &config.ssl_mode,
121 &config.ssl_root_cert,
122 tcp_keepalive,
123 )
124 .await
125 .map_err(Into::into)
126}
127
128pub async fn discover_pgvector_dimensions(
129 client: &PgClient,
130 schema: &str,
131 table: &str,
132) -> ConnectorResult<HashMap<String, usize>> {
133 let rows = client
134 .query(DISCOVER_PGVECTOR_COLUMNS_QUERY, &[&schema, &table])
135 .await?;
136
137 let mut dims = HashMap::new();
138 for row in rows {
139 let col_name: String = row.get("column_name");
140 let atttypmod: i32 = row.get("atttypmod");
141 if atttypmod > 0
142 && let Ok(dim) = usize::try_from(atttypmod)
143 {
144 dims.insert(col_name, dim);
145 }
146 }
147 Ok(dims)
148}
149
/// TLS mode used when connecting to an upstream Postgres server.
///
/// Deserialized from the lowercase variant name (`disabled`, `preferred`, `required`,
/// `verifyca`, `verifyfull`) or from the libpq-style aliases (`disable`, `prefer`,
/// `require`, `verify-ca`, `verify-full`). Defaults to [`SslMode::Preferred`].
#[derive(Debug, Clone, PartialEq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum SslMode {
    /// No TLS.
    #[serde(alias = "disable")]
    Disabled,
    /// Try TLS without certificate verification, and fall back to a plain connection if it
    /// is unavailable.
    #[serde(alias = "prefer")]
    #[default]
    Preferred,
    /// Require TLS, but do not verify the server certificate.
    #[serde(alias = "require")]
    Required,
    /// Require TLS and verify the server certificate (against `ssl.root.cert` when given),
    /// without checking the host name.
    #[serde(alias = "verify-ca")]
    VerifyCa,
    /// Like `VerifyCa`, and additionally verify the server host name.
    #[serde(alias = "verify-full")]
    VerifyFull,
}
169
/// Schema of an upstream Postgres table, discovered by `PostgresExternalTable::connect`.
pub struct PostgresExternalTable {
    /// RisingWave column descriptors (with placeholder column ids), in discovery order.
    column_descs: Vec<ColumnDesc>,
    /// Primary-key column names, in key order; empty if the table has no primary key.
    pk_names: Vec<String>,
}
174
175impl PostgresExternalTable {
176 async fn discover_primary_key(
179 connection: &PgPool,
180 schema_name: &str,
181 table_name: &str,
182 ) -> ConnectorResult<Vec<String>> {
183 let rows = sqlx::query(DISCOVER_PRIMARY_KEY_QUERY)
184 .bind(schema_name)
185 .bind(table_name)
186 .fetch_all(connection)
187 .await
188 .context("Failed to discover primary key columns")?;
189
190 let pk_columns = rows
191 .into_iter()
192 .map(|row| row.get::<String, _>("column_name"))
193 .collect();
194
195 Ok(pk_columns)
196 }
197
198 async fn discover_pk_and_full_columns(
202 username: &str,
203 password: &str,
204 host: &str,
205 port: u16,
206 database: &str,
207 schema: &str,
208 table: &str,
209 ssl_mode: &SslMode,
210 ssl_root_cert: &Option<String>,
211 ) -> ConnectorResult<(Vec<sea_schema::postgres::def::ColumnInfo>, Vec<String>)> {
212 let mut options = PgConnectOptions::new()
213 .username(username)
214 .password(password)
215 .host(host)
216 .port(port)
217 .database(database)
218 .ssl_mode(match ssl_mode {
219 SslMode::Disabled => PgSslMode::Disable,
220 SslMode::Preferred => PgSslMode::Prefer,
221 SslMode::Required => PgSslMode::Require,
222 SslMode::VerifyCa => PgSslMode::VerifyCa,
223 SslMode::VerifyFull => PgSslMode::VerifyFull,
224 });
225
226 if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
227 && let Some(root_cert) = ssl_root_cert
228 {
229 options = options.ssl_root_cert(root_cert.as_str());
230 }
231
232 let connection = PgPool::connect_with(options).await?;
233
234 let schema_discovery = SchemaDiscovery::new(connection.clone(), schema);
236 let empty_map: HashMap<String, Vec<String>> = HashMap::new();
237 let columns = schema_discovery
238 .discover_columns(
239 Alias::new(schema).into_iden(),
240 Alias::new(table).into_iden(),
241 &empty_map,
242 )
243 .await?;
244
245 let pgvector_columns = sqlx::query(DISCOVER_PGVECTOR_COLUMNS_QUERY)
246 .bind(schema)
247 .bind(table)
248 .fetch_all(&connection)
249 .await
250 .context("Failed to discover PostgreSQL pgvector columns")?;
251 let formatted_type_by_column: HashMap<String, String> = pgvector_columns
252 .into_iter()
253 .map(|row| {
254 (
255 row.get::<String, _>("column_name"),
256 row.get::<String, _>("formatted_type"),
257 )
258 })
259 .collect();
260
261 let mut columns = columns;
264 for col in &mut columns {
265 if let SeaType::Unknown(name) = &col.col_type
266 && name.eq_ignore_ascii_case("vector")
267 && let Some(formatted_type) = formatted_type_by_column.get(&col.name)
268 {
269 col.col_type = SeaType::Unknown(formatted_type.clone());
270 }
271 }
272
273 let pk_columns = Self::discover_primary_key(&connection, schema, table).await?;
275
276 Ok((columns, pk_columns))
277 }
278
279 async fn discover_schema(
280 username: &str,
281 password: &str,
282 host: &str,
283 port: u16,
284 database: &str,
285 schema: &str,
286 table: &str,
287 ssl_mode: &SslMode,
288 ssl_root_cert: &Option<String>,
289 ) -> ConnectorResult<TableDef> {
290 let mut options = PgConnectOptions::new()
291 .username(username)
292 .password(password)
293 .host(host)
294 .port(port)
295 .database(database)
296 .ssl_mode(match ssl_mode {
297 SslMode::Disabled => PgSslMode::Disable,
298 SslMode::Preferred => PgSslMode::Prefer,
299 SslMode::Required => PgSslMode::Require,
300 SslMode::VerifyCa => PgSslMode::VerifyCa,
301 SslMode::VerifyFull => PgSslMode::VerifyFull,
302 });
303
304 if (*ssl_mode == SslMode::VerifyCa || *ssl_mode == SslMode::VerifyFull)
305 && let Some(root_cert) = ssl_root_cert
306 {
307 options = options.ssl_root_cert(root_cert.as_str());
308 }
309
310 let connection = PgPool::connect_with(options).await?;
311 let schema_discovery = SchemaDiscovery::new(connection, schema);
312 let empty_map = HashMap::new();
314 let table_schema = schema_discovery
315 .discover_table(
316 TableInfo {
317 name: table.to_owned(),
318 of_type: None,
319 },
320 &empty_map,
321 )
322 .await?;
323 Ok(table_schema)
324 }
325
326 pub async fn connect(
327 username: &str,
328 password: &str,
329 host: &str,
330 port: u16,
331 database: &str,
332 schema: &str,
333 table: &str,
334 ssl_mode: &SslMode,
335 ssl_root_cert: &Option<String>,
336 is_append_only: bool,
337 ) -> ConnectorResult<Self> {
338 tracing::debug!("connect to postgres external table");
339
340 let (columns, pk_names) = Self::discover_pk_and_full_columns(
341 username,
342 password,
343 host,
344 port,
345 database,
346 schema,
347 table,
348 ssl_mode,
349 ssl_root_cert,
350 )
351 .await?;
352
353 let mut column_descs = vec![];
354 for col in &columns {
355 let rw_data_type = sea_type_to_rw_type(&col.col_type)?;
356 let column_desc = if let Some(ref default_expr) = col.default {
357 let val_text = default_expr
360 .0
361 .split("::")
362 .map(|s| s.trim_matches('\''))
363 .next()
364 .expect("default value expression");
365
366 match ScalarImpl::from_text(val_text, &rw_data_type) {
367 Ok(scalar) => ColumnDesc::named_with_default_value(
368 col.name.clone(),
369 ColumnId::placeholder(),
370 rw_data_type.clone(),
371 Some(scalar),
372 ),
373 Err(err) => {
374 tracing::warn!(error=%err.as_report(), "failed to parse postgres default value expression, only constant is supported");
375 ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
376 }
377 }
378 } else {
379 ColumnDesc::named(col.name.clone(), ColumnId::placeholder(), rw_data_type)
380 };
381 column_descs.push(column_desc);
382 }
383
384 if !is_append_only && pk_names.is_empty() {
386 return Err(anyhow!(
387 "Postgres table should define the primary key for non-append-only tables"
388 )
389 .into());
390 }
391
392 Ok(Self {
393 column_descs,
394 pk_names,
395 })
396 }
397
398 pub async fn type_mapping(
400 username: &str,
401 password: &str,
402 host: &str,
403 port: u16,
404 database: &str,
405 schema: &str,
406 table: &str,
407 ssl_mode: &SslMode,
408 ssl_root_cert: &Option<String>,
409 is_append_only: bool,
410 ) -> ConnectorResult<HashMap<String, tokio_postgres::types::Type>> {
411 tracing::debug!("connect to postgres external table to get type mapping");
412 let table_schema = Self::discover_schema(
413 username,
414 password,
415 host,
416 port,
417 database,
418 schema,
419 table,
420 ssl_mode,
421 ssl_root_cert,
422 )
423 .await?;
424 let mut column_name_to_pg_type = HashMap::new();
425 for col in &table_schema.columns {
426 let pg_type = sea_type_to_pg_type(&col.col_type)?;
427 column_name_to_pg_type.insert(col.name.clone(), pg_type);
428 }
429 if !is_append_only && table_schema.primary_key_constraints.is_empty() {
430 return Err(anyhow!(
431 "Postgres table should define the primary key for non-append-only tables"
432 )
433 .into());
434 }
435 Ok(column_name_to_pg_type)
436 }
437
438 pub fn column_descs(&self) -> &Vec<ColumnDesc> {
439 &self.column_descs
440 }
441
442 pub fn pk_names(&self) -> &Vec<String> {
443 &self.pk_names
444 }
445}
446
447impl fmt::Display for SslMode {
448 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
449 f.write_str(match self {
450 SslMode::Disabled => "disabled",
451 SslMode::Preferred => "preferred",
452 SslMode::Required => "required",
453 SslMode::VerifyCa => "verify-ca",
454 SslMode::VerifyFull => "verify-full",
455 })
456 }
457}
458
459impl std::str::FromStr for SslMode {
460 type Err = serde_json::Error;
461
462 fn from_str(s: &str) -> Result<Self, Self::Err> {
463 serde_json::from_value(serde_json::Value::String(s.to_owned()))
464 }
465}
466
467pub async fn create_pg_client(
468 user: &str,
469 password: &str,
470 host: &str,
471 port: &str,
472 database: &str,
473 ssl_mode: &SslMode,
474 ssl_root_cert: &Option<String>,
475 tcp_keepalive: Option<TcpKeepaliveConfig>,
476) -> anyhow::Result<PgClient> {
477 let mut pg_config = tokio_postgres::Config::new();
478 pg_config
479 .user(user)
480 .password(password)
481 .host(host)
482 .port(port.parse::<u16>().unwrap())
483 .dbname(database);
484
485 if let Some(keepalive) = tcp_keepalive {
487 pg_config.keepalives(true);
488 pg_config.keepalives_idle(std::time::Duration::from_secs(
489 keepalive.tcp_keepalive_idle as u64,
490 ));
491 #[cfg(not(target_os = "windows"))]
492 {
493 pg_config.keepalives_interval(std::time::Duration::from_secs(
494 keepalive.tcp_keepalive_interval as u64,
495 ));
496 pg_config.keepalives_retries(keepalive.tcp_keepalive_count);
497 }
498 tracing::info!(
499 "TCP keepalive enabled: idle={}s, interval={}s, retries={}",
500 keepalive.tcp_keepalive_idle,
501 keepalive.tcp_keepalive_interval,
502 keepalive.tcp_keepalive_count
503 );
504 }
505
506 let (_verify_ca, verify_hostname) = match ssl_mode {
507 SslMode::VerifyCa => (true, false),
508 SslMode::VerifyFull => (true, true),
509 _ => (false, false),
510 };
511
512 #[cfg(not(madsim))]
513 let connector = match ssl_mode {
514 SslMode::Disabled => {
515 pg_config.ssl_mode(tokio_postgres::config::SslMode::Disable);
516 MaybeMakeTlsConnector::NoTls(NoTls)
517 }
518 SslMode::Preferred => {
519 pg_config.ssl_mode(tokio_postgres::config::SslMode::Prefer);
520 match SslConnector::builder(SslMethod::tls()) {
521 Ok(mut builder) => {
522 builder.set_verify(SslVerifyMode::NONE);
524 MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
525 }
526 Err(e) => {
527 tracing::warn!(error = %e.as_report(), "SSL connector error");
528 MaybeMakeTlsConnector::NoTls(NoTls)
529 }
530 }
531 }
532 SslMode::Required => {
533 pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
534 let mut builder = SslConnector::builder(SslMethod::tls())?;
535 builder.set_verify(SslVerifyMode::NONE);
537 MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
538 }
539
540 SslMode::VerifyCa | SslMode::VerifyFull => {
541 pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
542 let mut builder = SslConnector::builder(SslMethod::tls())?;
543 if let Some(ssl_root_cert) = ssl_root_cert {
544 builder.set_ca_file(ssl_root_cert).map_err(|e| {
545 anyhow!(format!("bad ssl root cert error: {}", e.to_report_string()))
546 })?;
547 }
548 let mut connector = MakeTlsConnector::new(builder.build());
549 if !verify_hostname {
550 connector.set_callback(|config, _| {
551 config.set_verify_hostname(false);
552 Ok(())
553 });
554 }
555 MaybeMakeTlsConnector::Tls(connector)
556 }
557 };
558 #[cfg(madsim)]
559 let connector = NoTls;
560
561 let (client, connection) = pg_config.connect(connector).await?;
562
563 tokio::spawn(async move {
564 if let Err(e) = connection.await {
565 tracing::error!(error = %e.as_report(), "postgres connection error");
566 }
567 });
568
569 Ok(client)
570}
571
572pub fn sea_type_to_rw_type(col_type: &SeaType) -> ConnectorResult<DataType> {
574 let dtype = match col_type {
575 SeaType::SmallInt | SeaType::SmallSerial => DataType::Int16,
576 SeaType::Integer | SeaType::Serial => DataType::Int32,
577 SeaType::BigInt | SeaType::BigSerial => DataType::Int64,
578 SeaType::Money | SeaType::Decimal(_) | SeaType::Numeric(_) => DataType::Decimal,
579 SeaType::Real => DataType::Float32,
580 SeaType::DoublePrecision => DataType::Float64,
581 SeaType::Varchar(_) | SeaType::Char(_) | SeaType::Text => DataType::Varchar,
582 SeaType::Bytea => DataType::Bytea,
583 SeaType::Timestamp(_) => DataType::Timestamp,
584 SeaType::TimestampWithTimeZone(_) => DataType::Timestamptz,
585 SeaType::Date => DataType::Date,
586 SeaType::Time(_) | SeaType::TimeWithTimeZone(_) => DataType::Time,
587 SeaType::Interval(_) => DataType::Interval,
588 SeaType::Boolean => DataType::Boolean,
589 SeaType::Point => DataType::Struct(StructType::new(vec![
590 ("x", DataType::Float32),
591 ("y", DataType::Float32),
592 ])),
593 SeaType::Uuid => DataType::Varchar,
594 SeaType::Xml => DataType::Varchar,
595 SeaType::Json => DataType::Jsonb,
596 SeaType::JsonBinary => DataType::Jsonb,
597 SeaType::Array(def) => {
598 let item_type = match def.col_type.as_ref() {
599 Some(ty) => sea_type_to_rw_type(ty.as_ref())?,
600 None => {
601 return Err(anyhow!("ARRAY type missing element type").into());
602 }
603 };
604
605 DataType::list(item_type)
606 }
607 SeaType::PgLsn => DataType::Int64,
608 SeaType::Cidr
609 | SeaType::Inet
610 | SeaType::MacAddr
611 | SeaType::MacAddr8
612 | SeaType::Int4Range
613 | SeaType::Int8Range
614 | SeaType::NumRange
615 | SeaType::TsRange
616 | SeaType::TsTzRange
617 | SeaType::DateRange
618 | SeaType::Enum(_) => DataType::Varchar,
619 SeaType::Line
620 | SeaType::Lseg
621 | SeaType::Box
622 | SeaType::Path
623 | SeaType::Polygon
624 | SeaType::Circle
625 | SeaType::Bit(_)
626 | SeaType::VarBit(_)
627 | SeaType::TsVector
628 | SeaType::TsQuery => {
629 bail!("{:?} type not supported", col_type);
630 }
631 SeaType::Unknown(name) => {
632 if let Some(dim) = parse_pgvector_dimension(name)? {
633 DataType::Vector(dim)
634 } else {
635 tracing::warn!("Unknown Postgres data type: {name}, map to varchar");
637 DataType::Varchar
638 }
639 }
640 };
641
642 Ok(dtype)
643}
644
645fn parse_pgvector_dimension(type_name: &str) -> ConnectorResult<Option<usize>> {
646 let normalized = type_name.trim().to_ascii_lowercase();
647 if normalized == "vector" {
648 bail!("pgvector type `vector` is missing dimension, expected `vector(n)`")
649 }
650 if !normalized.starts_with("vector(") || !normalized.ends_with(')') {
651 return Ok(None);
652 }
653
654 let dim_text = normalized
655 .trim_start_matches("vector(")
656 .trim_end_matches(')')
657 .trim();
658 let dim = dim_text
659 .parse::<usize>()
660 .map_err(|_| anyhow!("invalid pgvector dimension in type `{type_name}`"))?;
661
662 if !(1..=DataType::VEC_MAX_SIZE).contains(&dim) {
663 bail!(
664 "pgvector dimension out of range in type `{}`: expect 1..={}",
665 type_name,
666 DataType::VEC_MAX_SIZE
667 );
668 }
669
670 Ok(Some(dim))
671}
672
673fn sea_type_to_pg_type(sea_type: &SeaType) -> ConnectorResult<tokio_postgres::types::Type> {
678 use tokio_postgres::types::Type as PgType;
679 match sea_type {
680 SeaType::SmallInt => Ok(PgType::INT2),
681 SeaType::Integer => Ok(PgType::INT4),
682 SeaType::BigInt => Ok(PgType::INT8),
683 SeaType::Decimal(_) => Ok(PgType::NUMERIC),
684 SeaType::Numeric(_) => Ok(PgType::NUMERIC),
685 SeaType::Real => Ok(PgType::FLOAT4),
686 SeaType::DoublePrecision => Ok(PgType::FLOAT8),
687 SeaType::Varchar(_) => Ok(PgType::VARCHAR),
688 SeaType::Char(_) => Ok(PgType::CHAR),
689 SeaType::Text => Ok(PgType::TEXT),
690 SeaType::Bytea => Ok(PgType::BYTEA),
691 SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP),
692 SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ),
693 SeaType::Date => Ok(PgType::DATE),
694 SeaType::Time(_) => Ok(PgType::TIME),
695 SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ),
696 SeaType::Interval(_) => Ok(PgType::INTERVAL),
697 SeaType::Boolean => Ok(PgType::BOOL),
698 SeaType::Point => Ok(PgType::POINT),
699 SeaType::Uuid => Ok(PgType::UUID),
700 SeaType::Json => Ok(PgType::JSON),
701 SeaType::JsonBinary => Ok(PgType::JSONB),
702 SeaType::Array(t) => {
703 let Some(t) = t.col_type.as_ref() else {
704 bail!("missing array type")
705 };
706 match t.as_ref() {
707 SeaType::SmallInt => Ok(PgType::INT2_ARRAY),
709 SeaType::Integer => Ok(PgType::INT4_ARRAY),
710 SeaType::BigInt => Ok(PgType::INT8_ARRAY),
711 SeaType::Decimal(_) => Ok(PgType::NUMERIC_ARRAY),
712 SeaType::Numeric(_) => Ok(PgType::NUMERIC_ARRAY),
713 SeaType::Real => Ok(PgType::FLOAT4_ARRAY),
714 SeaType::DoublePrecision => Ok(PgType::FLOAT8_ARRAY),
715 SeaType::Varchar(_) => Ok(PgType::VARCHAR_ARRAY),
716 SeaType::Char(_) => Ok(PgType::CHAR_ARRAY),
717 SeaType::Text => Ok(PgType::TEXT_ARRAY),
718 SeaType::Bytea => Ok(PgType::BYTEA_ARRAY),
719 SeaType::Timestamp(_) => Ok(PgType::TIMESTAMP_ARRAY),
720 SeaType::TimestampWithTimeZone(_) => Ok(PgType::TIMESTAMPTZ_ARRAY),
721 SeaType::Date => Ok(PgType::DATE_ARRAY),
722 SeaType::Time(_) => Ok(PgType::TIME_ARRAY),
723 SeaType::TimeWithTimeZone(_) => Ok(PgType::TIMETZ_ARRAY),
724 SeaType::Interval(_) => Ok(PgType::INTERVAL_ARRAY),
725 SeaType::Boolean => Ok(PgType::BOOL_ARRAY),
726 SeaType::Point => Ok(PgType::POINT_ARRAY),
727 SeaType::Uuid => Ok(PgType::UUID_ARRAY),
728 SeaType::Json => Ok(PgType::JSON_ARRAY),
729 SeaType::JsonBinary => Ok(PgType::JSONB_ARRAY),
730 SeaType::Array(_) => bail!("nested array type is not supported"),
731 SeaType::Unknown(name) => {
732 Ok(PgType::new(
734 name.clone(),
735 0,
736 PgKind::Array(PgType::new(
737 name.clone(),
738 0,
739 PgKind::Enum(vec![]),
740 "".into(),
741 )),
742 "".into(),
743 ))
744 }
745 _ => bail!("unsupported array type: {:?}", t),
746 }
747 }
748 SeaType::Unknown(name) => {
749 Ok(PgType::new(
751 name.clone(),
752 0,
753 PgKind::Enum(vec![]),
754 "".into(),
755 ))
756 }
757 _ => bail!("unsupported type: {:?}", sea_type),
758 }
759}
760
#[cfg(test)]
mod tests {
    use super::{DataType, parse_pgvector_dimension};

    #[test]
    fn test_parse_pgvector_dimension() {
        assert_eq!(parse_pgvector_dimension("vector(3)").unwrap(), Some(3));
        assert_eq!(parse_pgvector_dimension("VECTOR(768)").unwrap(), Some(768));
        assert_eq!(parse_pgvector_dimension(" vector(16) ").unwrap(), Some(16));
        assert_eq!(parse_pgvector_dimension("varchar").unwrap(), None);
        assert_eq!(parse_pgvector_dimension("varchar(10)").unwrap(), None);
    }

    #[test]
    fn test_parse_pgvector_dimension_requires_size() {
        let err = parse_pgvector_dimension("vector").unwrap_err();
        assert!(err.to_string().contains("missing dimension"));
    }

    /// Non-numeric, negative, or empty dimensions are rejected rather than ignored.
    #[test]
    fn test_parse_pgvector_dimension_rejects_invalid_size() {
        for bad in ["vector(abc)", "vector(-1)", "vector()"] {
            let err = parse_pgvector_dimension(bad).unwrap_err();
            assert!(
                err.to_string().contains("invalid pgvector dimension"),
                "{bad}: {err}"
            );
        }
    }

    /// Dimensions must lie in `1..=DataType::VEC_MAX_SIZE`.
    #[test]
    fn test_parse_pgvector_dimension_out_of_range() {
        let too_large = format!("vector({})", DataType::VEC_MAX_SIZE + 1);
        for bad in ["vector(0)", too_large.as_str()] {
            let err = parse_pgvector_dimension(bad).unwrap_err();
            assert!(err.to_string().contains("out of range"), "{bad}: {err}");
        }
    }
}