risingwave_connector/source/cdc/external/
postgres.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::Ordering;
16
17use anyhow::Context;
18use futures::stream::BoxStream;
19use futures::{StreamExt, pin_mut};
20use futures_async_stream::try_stream;
21use itertools::Itertools;
22use risingwave_common::catalog::Schema;
23use risingwave_common::row::{OwnedRow, Row};
24use risingwave_common::util::iter_util::ZipEqFast;
25use serde_derive::{Deserialize, Serialize};
26use tokio_postgres::types::PgLsn;
27
28use crate::connector_common::create_pg_client;
29use crate::error::{ConnectorError, ConnectorResult};
30use crate::parser::postgres_row_to_owned_row;
31use crate::parser::scalar_adapter::ScalarAdapter;
32use crate::source::cdc::external::{
33    CdcOffset, CdcOffsetParseFunc, DebeziumOffset, ExternalTableConfig, ExternalTableReader,
34    SchemaTableName,
35};
36
37#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
38pub struct PostgresOffset {
39    pub txid: i64,
40    // In postgres, an LSN is a 64-bit integer, representing a byte position in the write-ahead log stream.
41    // It is printed as two hexadecimal numbers of up to 8 digits each, separated by a slash; for example, 16/B374D848
42    pub lsn: u64,
43}
44
45// only compare the lsn field
46impl PartialOrd for PostgresOffset {
47    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
48        self.lsn.partial_cmp(&other.lsn)
49    }
50}
51
52impl PostgresOffset {
53    pub fn parse_debezium_offset(offset: &str) -> ConnectorResult<Self> {
54        let dbz_offset: DebeziumOffset = serde_json::from_str(offset)
55            .with_context(|| format!("invalid upstream offset: {}", offset))?;
56
57        Ok(Self {
58            txid: dbz_offset
59                .source_offset
60                .txid
61                .context("invalid postgres txid")?,
62            lsn: dbz_offset
63                .source_offset
64                .lsn
65                .context("invalid postgres lsn")?,
66        })
67    }
68}
69
70pub struct PostgresExternalTableReader {
71    rw_schema: Schema,
72    field_names: String,
73    pk_indices: Vec<usize>,
74    client: tokio::sync::Mutex<tokio_postgres::Client>,
75}
76
77impl ExternalTableReader for PostgresExternalTableReader {
78    async fn current_cdc_offset(&self) -> ConnectorResult<CdcOffset> {
79        let mut client = self.client.lock().await;
80        // start a transaction to read current lsn and txid
81        let trxn = client.transaction().await?;
82        let row = trxn.query_one("SELECT pg_current_wal_lsn()", &[]).await?;
83        let mut pg_offset = PostgresOffset::default();
84        let pg_lsn = row.get::<_, PgLsn>(0);
85        tracing::debug!("current lsn: {}", pg_lsn);
86        pg_offset.lsn = pg_lsn.into();
87
88        let txid_row = trxn.query_one("SELECT txid_current()", &[]).await?;
89        let txid: i64 = txid_row.get::<_, i64>(0);
90        pg_offset.txid = txid;
91
92        // commit the transaction
93        trxn.commit().await?;
94
95        Ok(CdcOffset::Postgres(pg_offset))
96    }
97
98    fn snapshot_read(
99        &self,
100        table_name: SchemaTableName,
101        start_pk: Option<OwnedRow>,
102        primary_keys: Vec<String>,
103        limit: u32,
104    ) -> BoxStream<'_, ConnectorResult<OwnedRow>> {
105        self.snapshot_read_inner(table_name, start_pk, primary_keys, limit)
106    }
107}
108
109impl PostgresExternalTableReader {
110    pub async fn new(
111        config: ExternalTableConfig,
112        rw_schema: Schema,
113        pk_indices: Vec<usize>,
114    ) -> ConnectorResult<Self> {
115        tracing::info!(
116            ?rw_schema,
117            ?pk_indices,
118            "create postgres external table reader"
119        );
120
121        let client = create_pg_client(
122            &config.username,
123            &config.password,
124            &config.host,
125            &config.port,
126            &config.database,
127            &config.ssl_mode,
128            &config.ssl_root_cert,
129        )
130        .await?;
131
132        let field_names = rw_schema
133            .fields
134            .iter()
135            .map(|f| Self::quote_column(&f.name))
136            .join(",");
137
138        Ok(Self {
139            rw_schema,
140            field_names,
141            pk_indices,
142            client: tokio::sync::Mutex::new(client),
143        })
144    }
145
146    pub fn get_normalized_table_name(table_name: &SchemaTableName) -> String {
147        format!(
148            "\"{}\".\"{}\"",
149            table_name.schema_name, table_name.table_name
150        )
151    }
152
153    pub fn get_cdc_offset_parser() -> CdcOffsetParseFunc {
154        Box::new(move |offset| {
155            Ok(CdcOffset::Postgres(PostgresOffset::parse_debezium_offset(
156                offset,
157            )?))
158        })
159    }
160
161    #[try_stream(boxed, ok = OwnedRow, error = ConnectorError)]
162    async fn snapshot_read_inner(
163        &self,
164        table_name: SchemaTableName,
165        start_pk_row: Option<OwnedRow>,
166        primary_keys: Vec<String>,
167        scan_limit: u32,
168    ) {
169        let order_key = Self::get_order_key(&primary_keys);
170        let client = self.client.lock().await;
171        client.execute("set time zone '+00:00'", &[]).await?;
172
173        let stream = match start_pk_row {
174            Some(ref pk_row) => {
175                // prepare the scan statement, since we may need to convert the RW data type to postgres data type
176                // e.g. varchar to uuid
177                let prepared_scan_stmt = {
178                    let primary_keys = self
179                        .pk_indices
180                        .iter()
181                        .map(|i| self.rw_schema.fields[*i].name.clone())
182                        .collect_vec();
183
184                    let order_key = Self::get_order_key(&primary_keys);
185                    let scan_sql = format!(
186                        "SELECT {} FROM {} WHERE {} ORDER BY {} LIMIT {scan_limit}",
187                        self.field_names,
188                        Self::get_normalized_table_name(&table_name),
189                        Self::filter_expression(&primary_keys),
190                        order_key,
191                    );
192                    client.prepare(&scan_sql).await?
193                };
194
195                let params: Vec<Option<ScalarAdapter>> = pk_row
196                    .iter()
197                    .zip_eq_fast(prepared_scan_stmt.params())
198                    .map(|(datum, ty)| {
199                        datum
200                            .map(|scalar| ScalarAdapter::from_scalar(scalar, ty))
201                            .transpose()
202                    })
203                    .try_collect()?;
204
205                client.query_raw(&prepared_scan_stmt, &params).await?
206            }
207            None => {
208                let sql = format!(
209                    "SELECT {} FROM {} ORDER BY {} LIMIT {scan_limit}",
210                    self.field_names,
211                    Self::get_normalized_table_name(&table_name),
212                    order_key,
213                );
214                let params: Vec<Option<ScalarAdapter>> = vec![];
215                client.query_raw(&sql, &params).await?
216            }
217        };
218
219        let row_stream = stream.map(|row| {
220            let row = row?;
221            Ok::<_, crate::error::ConnectorError>(postgres_row_to_owned_row(row, &self.rw_schema))
222        });
223
224        pin_mut!(row_stream);
225        #[for_await]
226        for row in row_stream {
227            let row = row?;
228            yield row;
229        }
230    }
231
232    // row filter expression: (v1, v2, v3) > ($1, $2, $3)
233    fn filter_expression(columns: &[String]) -> String {
234        let mut col_expr = String::new();
235        let mut arg_expr = String::new();
236        for (i, column) in columns.iter().enumerate() {
237            if i > 0 {
238                col_expr.push_str(", ");
239                arg_expr.push_str(", ");
240            }
241            col_expr.push_str(&Self::quote_column(column));
242            arg_expr.push_str(format!("${}", i + 1).as_str());
243        }
244        format!("({}) > ({})", col_expr, arg_expr)
245    }
246
247    fn get_order_key(primary_keys: &Vec<String>) -> String {
248        primary_keys
249            .iter()
250            .map(|col| Self::quote_column(col))
251            .join(",")
252    }
253
254    fn quote_column(column: &str) -> String {
255        format!("\"{}\"", column)
256    }
257}
258
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use futures::pin_mut;
    use futures_async_stream::for_await;
    use maplit::{convert_args, hashmap};
    use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema};
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::{DataType, ScalarImpl};

    use crate::connector_common::PostgresExternalTable;
    use crate::source::cdc::external::postgres::{PostgresExternalTableReader, PostgresOffset};
    use crate::source::cdc::external::{ExternalTableConfig, ExternalTableReader, SchemaTableName};

    // Manual test: needs a Postgres instance listening on localhost:8432.
    #[ignore]
    #[tokio::test]
    async fn test_postgres_schema() {
        let cfg = ExternalTableConfig {
            connector: "postgres-cdc".to_owned(),
            host: "localhost".to_owned(),
            port: "8432".to_owned(),
            username: "myuser".to_owned(),
            password: "123456".to_owned(),
            database: "mydb".to_owned(),
            schema: "public".to_owned(),
            table: "mytest".to_owned(),
            ssl_mode: Default::default(),
            ssl_root_cert: None,
            encrypt: "false".to_owned(),
        };
        let port = cfg.port.parse::<u16>().unwrap();

        let external_table = PostgresExternalTable::connect(
            &cfg.username,
            &cfg.password,
            &cfg.host,
            port,
            &cfg.database,
            &cfg.schema,
            &cfg.table,
            &cfg.ssl_mode,
            &cfg.ssl_root_cert,
            false,
        )
        .await
        .unwrap();

        println!("columns: {:?}", &external_table.column_descs());
        println!("primary keys: {:?}", &external_table.pk_names());
    }

    // Ordering must follow `lsn` only, whatever the `txid` values are.
    #[test]
    fn test_postgres_offset() {
        let low = PostgresOffset { txid: 5, lsn: 1 };
        let mid = PostgresOffset { txid: 4, lsn: 2 };
        let high = PostgresOffset { txid: 1, lsn: 3 };

        assert!(mid < high);
        assert!(low < mid);
        assert!(high > low);
    }

    #[test]
    fn test_filter_expression() {
        // Renders the filter expression for the given column names.
        let render = |names: &[&str]| {
            let cols: Vec<String> = names.iter().map(|name| (*name).to_owned()).collect();
            PostgresExternalTableReader::filter_expression(&cols)
        };

        assert_eq!(render(&["v1"]), "(\"v1\") > ($1)");
        assert_eq!(render(&["v1", "v2"]), "(\"v1\", \"v2\") > ($1, $2)");
        assert_eq!(
            render(&["v1", "v2", "v3"]),
            "(\"v1\", \"v2\", \"v3\") > ($1, $2, $3)"
        );
    }

    // manual test
    #[ignore]
    #[tokio::test]
    async fn test_pg_table_reader() {
        let column_descs = vec![
            ColumnDesc::named("v1", ColumnId::new(1), DataType::Int32),
            ColumnDesc::named("v2", ColumnId::new(2), DataType::Varchar),
            ColumnDesc::named("v3", ColumnId::new(3), DataType::Decimal),
            ColumnDesc::named("v4", ColumnId::new(4), DataType::Date),
        ];
        let fields = column_descs.iter().map(Field::from).collect();
        let rw_schema = Schema { fields };

        let props: HashMap<String, String> = convert_args!(hashmap!(
            "hostname" => "localhost",
            "port" => "8432",
            "username" => "myuser",
            "password" => "123456",
            "database.name" => "mydb",
            "schema.name" => "public",
            "table.name" => "t1"
        ));

        let props_json = serde_json::to_value(props).unwrap();
        let config = serde_json::from_value::<ExternalTableConfig>(props_json).unwrap();
        let reader = PostgresExternalTableReader::new(config, rw_schema, vec![0, 1])
            .await
            .unwrap();

        let offset = reader.current_cdc_offset().await.unwrap();
        println!("CdcOffset: {:?}", offset);

        // Resume the scan after the key (3, "c").
        let start_pk = OwnedRow::new(vec![Some(ScalarImpl::from(3)), Some(ScalarImpl::from("c"))]);
        let target = SchemaTableName {
            schema_name: "public".to_owned(),
            table_name: "t1".to_owned(),
        };
        let pk_names = vec!["v1".to_owned(), "v2".to_owned()];
        let stream = reader.snapshot_read(target, Some(start_pk), pk_names, 1000);

        pin_mut!(stream);
        #[for_await]
        for row in stream {
            println!("OwnedRow: {:?}", row);
        }
    }
}