risingwave_connector/parser/
postgres.rs1use std::sync::LazyLock;
16
17use bytes::Buf;
18use risingwave_common::array::Finite32;
19use risingwave_common::catalog::Schema;
20use risingwave_common::log::LogSuppressor;
21use risingwave_common::row::OwnedRow;
22use risingwave_common::types::{DataType, Decimal, ScalarImpl, VectorVal};
23use thiserror_ext::AsReport;
24use tokio_postgres::types::{FromSql, Type};
25
26use crate::parser::scalar_adapter::ScalarAdapter;
27use crate::parser::utils::log_error;
28
// Lazily-initialised, module-level log suppressor built from `LogSuppressor::default`.
// NOTE(review): presumably referenced by name from inside `log_error!` to throttle repeated
// parse-failure logs — confirm against `crate::parser::utils`.
static LOG_SUPPRESSOR: LazyLock<LogSuppressor> = LazyLock::new(LogSuppressor::default);
30
31struct PgVectorAdapter(Vec<f32>);
34
35impl<'a> FromSql<'a> for PgVectorAdapter {
36 fn accepts(ty: &Type) -> bool {
37 ty.name() == "vector"
38 }
39
40 fn from_sql(
41 _ty: &Type,
42 raw: &'a [u8],
43 ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
44 Self::parse_binary(raw)
45 }
46}
47
48impl PgVectorAdapter {
49 fn parse_binary(raw: &[u8]) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
50 if raw.len() < 4 {
53 return Err("invalid vector binary payload".into());
54 }
55 let mut buf = raw;
56 let dim = buf.get_u16() as usize;
57 let _unused = buf.get_u16();
58 if buf.remaining() != dim * std::mem::size_of::<f32>() {
59 return Err("invalid vector binary payload length".into());
60 }
61 let mut elems = Vec::with_capacity(dim);
62 for _ in 0..dim {
63 elems.push(buf.get_f32());
64 }
65 Ok(Self(elems))
66 }
67}
68
/// Reads column `$i` of `$row` as an optional `$type` and converts a present value into a
/// [`ScalarImpl`]. When the read fails, the error is reported through `log_error!` under
/// `$name` and the expression evaluates to `None`.
macro_rules! handle_data_type {
    ($row:expr, $i:expr, $name:expr, $type:ty) => {{
        match $row.try_get::<_, Option<$type>>($i) {
            Ok(cell) => cell.map(ScalarImpl::from),
            Err(err) => {
                log_error!($name, err, "parse column failed");
                None
            }
        }
    }};
}
81
82pub fn postgres_row_to_owned_row(row: tokio_postgres::Row, schema: &Schema) -> OwnedRow {
83 let mut datums = vec![];
84 for i in 0..schema.fields.len() {
85 let rw_field = &schema.fields[i];
86 let name = rw_field.name.as_str();
87 let datum = postgres_cell_to_scalar_impl(&row, &rw_field.data_type, i, name);
88 datums.push(datum);
89 }
90 OwnedRow::new(datums)
91}
92
93pub fn postgres_cell_to_scalar_impl(
94 row: &tokio_postgres::Row,
95 data_type: &DataType,
96 i: usize,
97 name: &str,
98) -> Option<ScalarImpl> {
99 match data_type {
103 DataType::Boolean
104 | DataType::Int16
105 | DataType::Int32
106 | DataType::Int64
107 | DataType::Float32
108 | DataType::Float64
109 | DataType::Date
110 | DataType::Time
111 | DataType::Timestamp
112 | DataType::Timestamptz
113 | DataType::Jsonb
114 | DataType::Interval
115 | DataType::Bytea => {
116 let res = row.try_get::<_, Option<ScalarImpl>>(i);
118 match res {
119 Ok(val) => val,
120 Err(err) => {
121 log_error!(name, err, "parse column failed");
122 None
123 }
124 }
125 }
126 DataType::Decimal => {
127 handle_data_type!(row, i, name, Decimal)
129 }
130 DataType::Varchar | DataType::Int256 => {
131 let res = row.try_get::<_, Option<ScalarAdapter>>(i);
132 match res {
133 Ok(val) => val.and_then(|v| v.into_scalar(data_type)),
134 Err(err) => {
135 log_error!(name, err, "parse column failed");
136 None
137 }
138 }
139 }
140 DataType::Vector(expected_size) => {
141 let res = row.try_get::<_, Option<PgVectorAdapter>>(i);
142 match res {
143 Ok(Some(PgVectorAdapter(v))) => {
144 if v.len() != *expected_size {
145 log_error!(
146 name,
147 anyhow::anyhow!(
148 "vector dimension mismatch: expected {}, got {}",
149 expected_size,
150 v.len()
151 ),
152 "parse column failed"
153 );
154 return None;
155 }
156 let finite = v
157 .into_iter()
158 .map(Finite32::try_from)
159 .collect::<Result<Vec<_>, _>>();
160 match finite {
161 Ok(finite) => Some(ScalarImpl::Vector(VectorVal::from(finite))),
162 Err(err) => {
163 log_error!(name, anyhow::anyhow!(err), "parse column failed");
164 None
165 }
166 }
167 }
168 Ok(None) => None,
169 Err(err) => {
170 log_error!(name, err, "parse column failed");
171 None
172 }
173 }
174 }
175 DataType::List(list) => match list.elem() {
176 elem @ (DataType::Struct(_) | DataType::List(_) | DataType::Serial) => {
178 tracing::warn!(
179 "unsupported List data type {:?}, set the List to empty",
180 elem
181 );
182 None
183 }
184 _ => {
185 let res = row.try_get::<_, Option<ScalarAdapter>>(i);
186 match res {
187 Ok(val) => val.and_then(|v| v.into_scalar(data_type)),
188 Err(err) => {
189 log_error!(name, err, "parse list column failed");
190 None
191 }
192 }
193 }
194 },
195 DataType::Struct(_) | DataType::Serial | DataType::Map(_) => {
196 tracing::warn!(name, ?data_type, "unsupported data type, set to null");
199 None
200 }
201 }
202}
203
#[cfg(test)]
mod tests {
    use tokio_postgres::NoTls;

    use crate::parser::postgres::PgVectorAdapter;
    use crate::parser::scalar_adapter::EnumString;

    // Connection settings for the ignored integration test (Postgres on localhost:5432).
    const DB: &str = "postgres";
    const USER: &str = "kexiang";
    // Same value as `DB`; named separately so the connection string reads correctly.
    const PASSWORD: &str = "postgres";

    /// Builds a binary `vector` payload: big-endian `u16` dimension, a zero `u16`, then the
    /// given elements as big-endian `f32`.
    fn encode_vector(dim: u16, elems: &[f32]) -> Vec<u8> {
        let mut raw = Vec::with_capacity(4 + elems.len() * std::mem::size_of::<f32>());
        raw.extend_from_slice(&dim.to_be_bytes());
        raw.extend_from_slice(&0u16.to_be_bytes());
        for elem in elems {
            raw.extend_from_slice(&elem.to_be_bytes());
        }
        raw
    }

    #[test]
    fn test_pg_vector_adapter_parse_binary() {
        let raw = encode_vector(3, &[1.5, -2.25, 3.0]);

        let v = PgVectorAdapter::parse_binary(&raw).unwrap();
        assert_eq!(v.0, vec![1.5, -2.25, 3.0]);
    }

    #[test]
    fn test_pg_vector_adapter_parse_binary_rejects_malformed() {
        // Shorter than the 4-byte header.
        assert!(PgVectorAdapter::parse_binary(&[0, 1, 0]).is_err());
        // Declared dimension exceeds the elements actually present.
        assert!(PgVectorAdapter::parse_binary(&encode_vector(2, &[1.0])).is_err());
        // Extra bytes beyond the declared dimension.
        assert!(PgVectorAdapter::parse_binary(&encode_vector(1, &[1.0, 2.0])).is_err());
    }

    /// Round-trips a Postgres enum value through `EnumString`. Needs a live local Postgres,
    /// hence `#[ignore]`.
    #[ignore]
    #[tokio::test]
    async fn enum_string_integration_test() {
        let connect = format!(
            "host=localhost port=5432 user={} password={} dbname={}",
            USER, PASSWORD, DB
        );
        let (client, connection) = tokio_postgres::connect(connect.as_str(), NoTls)
            .await
            .unwrap();

        // The connection object drives the socket; run it in the background.
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                eprintln!("connection error: {}", e);
            }
        });

        // Result deliberately ignored: the type may already exist from an earlier run.
        let _ = client
            .execute("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')", &[])
            .await;
        client
            .execute(
                "CREATE TABLE IF NOT EXISTS person(id int PRIMARY KEY, current_mood mood)",
                &[],
            )
            .await
            .unwrap();
        client.execute("DELETE FROM person;", &[]).await.unwrap();
        client
            .execute("INSERT INTO person VALUES (1, 'happy')", &[])
            .await
            .unwrap();

        // Read the enum column back as `EnumString`.
        let got: EnumString = client
            .query_one("SELECT * FROM person", &[])
            .await
            .unwrap()
            .get::<usize, Option<EnumString>>(1)
            .unwrap();
        assert_eq!("happy", got.0.as_str());

        client.execute("DELETE FROM person", &[]).await.unwrap();

        // Write the same value back as a query parameter.
        client
            .execute("INSERT INTO person VALUES (2, $1)", &[&got])
            .await
            .unwrap();

        let got_new: EnumString = client
            .query_one("SELECT * FROM person", &[])
            .await
            .unwrap()
            .get::<usize, Option<EnumString>>(1)
            .unwrap();
        assert_eq!("happy", got_new.0.as_str());
        client.execute("DROP TABLE person", &[]).await.unwrap();
        client.execute("DROP TYPE mood", &[]).await.unwrap();
    }
}