risingwave_connector/parser/
postgres.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::LazyLock;
16
17use bytes::Buf;
18use risingwave_common::array::Finite32;
19use risingwave_common::catalog::Schema;
20use risingwave_common::log::LogSuppressor;
21use risingwave_common::row::OwnedRow;
22use risingwave_common::types::{DataType, Decimal, ScalarImpl, VectorVal};
23use thiserror_ext::AsReport;
24use tokio_postgres::types::{FromSql, Type};
25
26use crate::parser::scalar_adapter::ScalarAdapter;
27use crate::parser::utils::log_error;
28
// Module-level `LogSuppressor` shared by the error logging in this file. Nothing in this
// file names it directly, so presumably the `log_error!` macro (imported from
// `crate::parser::utils`) refers to `LOG_SUPPRESSOR` by name at its call sites — verify
// against the macro definition before renaming or removing this static.
static LOG_SUPPRESSOR: LazyLock<LogSuppressor> = LazyLock::new(LogSuppressor::default);
30
/// Adapter for PostgreSQL `vector` type in CDC snapshot reads.
/// It parses pgvector binary format.
///
/// The wrapped `Vec<f32>` holds the decoded components in wire order. Decoding does not
/// check the dimension against the column type or reject non-finite values;
/// `postgres_cell_to_scalar_impl` performs both checks before building a vector scalar.
struct PgVectorAdapter(Vec<f32>);
34
35impl<'a> FromSql<'a> for PgVectorAdapter {
36    fn accepts(ty: &Type) -> bool {
37        ty.name() == "vector"
38    }
39
40    fn from_sql(
41        _ty: &Type,
42        raw: &'a [u8],
43    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
44        Self::parse_binary(raw)
45    }
46}
47
48impl PgVectorAdapter {
49    fn parse_binary(raw: &[u8]) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
50        // Binary format from pgvector extension:
51        // int16 dimension, int16 unused, repeated float4 values.
52        if raw.len() < 4 {
53            return Err("invalid vector binary payload".into());
54        }
55        let mut buf = raw;
56        let dim = buf.get_u16() as usize;
57        let _unused = buf.get_u16();
58        if buf.remaining() != dim * std::mem::size_of::<f32>() {
59            return Err("invalid vector binary payload length".into());
60        }
61        let mut elems = Vec::with_capacity(dim);
62        for _ in 0..dim {
63            elems.push(buf.get_f32());
64        }
65        Ok(Self(elems))
66    }
67}
68
/// Reads column `$i` of `$row` as `Option<$type>`, then converts a present value into a
/// `ScalarImpl` through its `From<$type>` impl.
///
/// If decoding fails, the error is reported via `log_error!` (tagged with the column name
/// `$name`) and the cell becomes `None` instead of failing the whole row.
macro_rules! handle_data_type {
    ($row:expr, $i:expr, $name:expr, $type:ty) => {{
        match $row.try_get::<_, Option<$type>>($i) {
            Err(err) => {
                log_error!($name, err, "parse column failed");
                None
            }
            Ok(val) => val.map(ScalarImpl::from),
        }
    }};
}
81
82pub fn postgres_row_to_owned_row(row: tokio_postgres::Row, schema: &Schema) -> OwnedRow {
83    let mut datums = vec![];
84    for i in 0..schema.fields.len() {
85        let rw_field = &schema.fields[i];
86        let name = rw_field.name.as_str();
87        let datum = postgres_cell_to_scalar_impl(&row, &rw_field.data_type, i, name);
88        datums.push(datum);
89    }
90    OwnedRow::new(datums)
91}
92
93pub fn postgres_cell_to_scalar_impl(
94    row: &tokio_postgres::Row,
95    data_type: &DataType,
96    i: usize,
97    name: &str,
98) -> Option<ScalarImpl> {
99    // We observe several incompatibility issue in Debezium's Postgres connector. We summarize them here:
100    // Issue #1. The null of enum list is not supported in Debezium. An enum list contains `NULL` will fallback to `NULL`.
101    // Issue #2. In our parser, when there's inf, -inf, nan or invalid item in a list, the whole list will fallback null.
102    match data_type {
103        DataType::Boolean
104        | DataType::Int16
105        | DataType::Int32
106        | DataType::Int64
107        | DataType::Float32
108        | DataType::Float64
109        | DataType::Date
110        | DataType::Time
111        | DataType::Timestamp
112        | DataType::Timestamptz
113        | DataType::Jsonb
114        | DataType::Interval
115        | DataType::Bytea => {
116            // ScalarAdapter is also fine. But ScalarImpl is more efficient
117            let res = row.try_get::<_, Option<ScalarImpl>>(i);
118            match res {
119                Ok(val) => val,
120                Err(err) => {
121                    log_error!(name, err, "parse column failed");
122                    None
123                }
124            }
125        }
126        DataType::Decimal => {
127            // Decimal is more efficient than PgNumeric in ScalarAdapter
128            handle_data_type!(row, i, name, Decimal)
129        }
130        DataType::Varchar | DataType::Int256 => {
131            let res = row.try_get::<_, Option<ScalarAdapter>>(i);
132            match res {
133                Ok(val) => val.and_then(|v| v.into_scalar(data_type)),
134                Err(err) => {
135                    log_error!(name, err, "parse column failed");
136                    None
137                }
138            }
139        }
140        DataType::Vector(expected_size) => {
141            let res = row.try_get::<_, Option<PgVectorAdapter>>(i);
142            match res {
143                Ok(Some(PgVectorAdapter(v))) => {
144                    if v.len() != *expected_size {
145                        log_error!(
146                            name,
147                            anyhow::anyhow!(
148                                "vector dimension mismatch: expected {}, got {}",
149                                expected_size,
150                                v.len()
151                            ),
152                            "parse column failed"
153                        );
154                        return None;
155                    }
156                    let finite = v
157                        .into_iter()
158                        .map(Finite32::try_from)
159                        .collect::<Result<Vec<_>, _>>();
160                    match finite {
161                        Ok(finite) => Some(ScalarImpl::Vector(VectorVal::from(finite))),
162                        Err(err) => {
163                            log_error!(name, anyhow::anyhow!(err), "parse column failed");
164                            None
165                        }
166                    }
167                }
168                Ok(None) => None,
169                Err(err) => {
170                    log_error!(name, err, "parse column failed");
171                    None
172                }
173            }
174        }
175        DataType::List(list) => match list.elem() {
176            // TODO(Kexiang): allow DataType::List(_)
177            elem @ (DataType::Struct(_) | DataType::List(_) | DataType::Serial) => {
178                tracing::warn!(
179                    "unsupported List data type {:?}, set the List to empty",
180                    elem
181                );
182                None
183            }
184            _ => {
185                let res = row.try_get::<_, Option<ScalarAdapter>>(i);
186                match res {
187                    Ok(val) => val.and_then(|v| v.into_scalar(data_type)),
188                    Err(err) => {
189                        log_error!(name, err, "parse list column failed");
190                        None
191                    }
192                }
193            }
194        },
195        DataType::Struct(_) | DataType::Serial | DataType::Map(_) => {
196            // Is this branch reachable?
197            // Struct and Serial are not supported
198            tracing::warn!(name, ?data_type, "unsupported data type, set to null");
199            None
200        }
201    }
202}
203
#[cfg(test)]
mod tests {
    use tokio_postgres::NoTls;

    use crate::parser::postgres::PgVectorAdapter;
    use crate::parser::scalar_adapter::EnumString;
    // Local-development connection settings for the `#[ignore]`d integration test below.
    const DB: &str = "postgres";
    const USER: &str = "kexiang";
    const PASSWORD: &str = "postgres";

    /// Encodes pgvector's binary format: big-endian `dim: u16`, a zero `u16`,
    /// then each value as a big-endian `f32`.
    ///
    /// `dim` is written as given (not derived from `values`) so tests can build
    /// malformed payloads.
    fn encode_vector(dim: u16, values: &[f32]) -> Vec<u8> {
        let mut raw = Vec::with_capacity(4 + values.len() * std::mem::size_of::<f32>());
        raw.extend_from_slice(&dim.to_be_bytes());
        // unused
        raw.extend_from_slice(&0u16.to_be_bytes());
        for value in values {
            raw.extend_from_slice(&value.to_be_bytes());
        }
        raw
    }

    #[test]
    fn test_pg_vector_adapter_parse_binary() {
        let raw = encode_vector(3, &[1.5, -2.25, 3.0]);
        let v = PgVectorAdapter::parse_binary(&raw).unwrap();
        assert_eq!(v.0, vec![1.5, -2.25, 3.0]);
    }

    #[test]
    fn test_pg_vector_adapter_parse_binary_zero_dim() {
        let raw = encode_vector(0, &[]);
        let v = PgVectorAdapter::parse_binary(&raw).unwrap();
        assert!(v.0.is_empty());
    }

    #[test]
    fn test_pg_vector_adapter_parse_binary_rejects_malformed() {
        // Shorter than the 4-byte header.
        assert!(PgVectorAdapter::parse_binary(&[]).is_err());
        assert!(PgVectorAdapter::parse_binary(&[0, 1, 0]).is_err());
        // Header declares 3 elements but only 2 follow.
        assert!(PgVectorAdapter::parse_binary(&encode_vector(3, &[1.0, 2.0])).is_err());
        // Trailing bytes beyond the declared dimension.
        assert!(PgVectorAdapter::parse_binary(&encode_vector(1, &[1.0, 2.0])).is_err());
        // Payload that is not a whole number of f32s.
        let mut ragged = encode_vector(1, &[1.0]);
        ragged.pop();
        assert!(PgVectorAdapter::parse_binary(&ragged).is_err());
    }

    #[ignore]
    #[tokio::test]
    async fn enum_string_integration_test() {
        let connect = format!(
            "host=localhost port=5432 user={} password={} dbname={}",
            USER, PASSWORD, DB
        );
        let (client, connection) = tokio_postgres::connect(connect.as_str(), NoTls)
            .await
            .unwrap();

        // The connection object performs the actual communication with the database,
        // so spawn it off to run on its own.
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                eprintln!("connection error: {}", e);
            }
        });

        // allow type existed
        let _ = client
            .execute("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')", &[])
            .await;
        client
            .execute(
                "CREATE TABLE IF NOT EXISTS person(id int PRIMARY KEY, current_mood mood)",
                &[],
            )
            .await
            .unwrap();
        client.execute("DELETE FROM person;", &[]).await.unwrap();
        client
            .execute("INSERT INTO person VALUES (1, 'happy')", &[])
            .await
            .unwrap();

        // test from_sql
        let got: EnumString = client
            .query_one("SELECT * FROM person", &[])
            .await
            .unwrap()
            .get::<usize, Option<EnumString>>(1)
            .unwrap();
        assert_eq!("happy", got.0.as_str());

        client.execute("DELETE FROM person", &[]).await.unwrap();

        // test to_sql
        client
            .execute("INSERT INTO person VALUES (2, $1)", &[&got])
            .await
            .unwrap();

        let got_new: EnumString = client
            .query_one("SELECT * FROM person", &[])
            .await
            .unwrap()
            .get::<usize, Option<EnumString>>(1)
            .unwrap();
        assert_eq!("happy", got_new.0.as_str());
        client.execute("DROP TABLE person", &[]).await.unwrap();
        client.execute("DROP TYPE mood", &[]).await.unwrap();
    }
}