risingwave_connector/source/cdc/external/postgres.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cmp::Ordering;
use std::sync::LazyLock;

use anyhow::Context;
use futures::stream::BoxStream;
use futures::{StreamExt, pin_mut};
use futures_async_stream::{for_await, try_stream};
use itertools::Itertools;
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::log::LogSuppresser;
use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::types::{DataType, Datum, ScalarImpl, ToOwnedDatum};
use risingwave_common::util::iter_util::ZipEqFast;
use serde::{Deserialize, Serialize};
use tokio_postgres::types::{PgLsn, Type as PgType};

use crate::connector_common::create_pg_client;
use crate::error::{ConnectorError, ConnectorResult};
use crate::parser::scalar_adapter::ScalarAdapter;
use crate::parser::{postgres_cell_to_scalar_impl, postgres_row_to_owned_row};
use crate::source::CdcTableSnapshotSplit;
use crate::source::cdc::external::{
    CDC_TABLE_SPLIT_ID_START, CdcOffset, CdcOffsetParseFunc, CdcTableSnapshotSplitOption,
    DebeziumOffset, ExternalTableConfig, ExternalTableReader, SchemaTableName,
};

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PostgresOffset {
    pub txid: i64,
    // In postgres, an LSN is a 64-bit integer, representing a byte position in the write-ahead log stream.
    // It is printed as two hexadecimal numbers of up to 8 digits each, separated by a slash; for example, 16/B374D848
    pub lsn: u64,
    // Additional LSN fields for improved tracking
    #[serde(default)]
    pub lsn_commit: Option<u64>,
    #[serde(default)]
    pub lsn_proc: Option<u64>,
}

impl PartialOrd for PostgresOffset {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Eq for PostgresOffset {}
impl PartialEq for PostgresOffset {
    fn eq(&self, other: &Self) -> bool {
        match (
            self.lsn_commit,
            self.lsn_proc,
            other.lsn_commit,
            other.lsn_proc,
        ) {
            (_, Some(_), _, Some(_)) => {
                self.lsn_commit == other.lsn_commit && self.lsn_proc == other.lsn_proc
            }
            _ => self.lsn == other.lsn,
        }
    }
}

// Compare offsets by LSN; prefer `lsn_commit` and `lsn_proc` when both sides provide `lsn_proc`.
impl Ord for PostgresOffset {
    fn cmp(&self, other: &Self) -> Ordering {
        match (
            self.lsn_commit,
            self.lsn_proc,
            other.lsn_commit,
            other.lsn_proc,
        ) {
            (_, Some(self_proc), _, Some(other_proc)) => {
                // Compare `lsn_commit` first; if the two are equal (including both `None`),
                // break the tie with `lsn_proc`.
                match self.lsn_commit.cmp(&other.lsn_commit) {
                    Ordering::Equal => self_proc.cmp(&other_proc),
                    other_result => other_result,
                }
            }
            _ => {
                // Fall back to `lsn` comparison when either side is missing `lsn_proc`.
                static LOG_SUPPRESSER: LazyLock<LogSuppresser> =
                    LazyLock::new(LogSuppresser::default);
                if let Ok(suppressed_count) = LOG_SUPPRESSER.check() {
                    tracing::warn!(
                        suppressed_count,
                        self_lsn = self.lsn,
                        other_lsn = other.lsn,
                        "lsn_proc is missing, fall back to lsn comparison"
                    );
                }
                self.lsn.cmp(&other.lsn)
            }
        }
    }
}

impl PostgresOffset {
    pub fn parse_debezium_offset(offset: &str) -> ConnectorResult<Self> {
        let dbz_offset: DebeziumOffset = serde_json::from_str(offset)
            .with_context(|| format!("invalid upstream offset: {}", offset))?;

        let lsn = dbz_offset
            .source_offset
            .lsn
            .context("invalid postgres lsn")?;

        // `lsn_commit` may not be present in the offset for the first tx.
        let lsn_commit = dbz_offset.source_offset.lsn_commit;

        let lsn_proc = dbz_offset
            .source_offset
            .lsn_proc
            .context("invalid postgres lsn_proc")?;

        Ok(Self {
            txid: dbz_offset
                .source_offset
                .txid
                .context("invalid postgres txid")?,
            lsn,
            lsn_commit,
            lsn_proc: Some(lsn_proc),
        })
    }
}

pub struct PostgresExternalTableReader {
    rw_schema: Schema,
    field_names: String,
    pk_indices: Vec<usize>,
    client: tokio::sync::Mutex<tokio_postgres::Client>,
    schema_table_name: SchemaTableName,
}

impl ExternalTableReader for PostgresExternalTableReader {
    async fn current_cdc_offset(&self) -> ConnectorResult<CdcOffset> {
        let mut client = self.client.lock().await;
        // start a transaction to read current lsn and txid
        let trxn = client.transaction().await?;
        let row = trxn.query_one("SELECT pg_current_wal_lsn()", &[]).await?;
        let mut pg_offset = PostgresOffset::default();
        let pg_lsn = row.get::<_, PgLsn>(0);
        tracing::debug!("current lsn: {}", pg_lsn);
        pg_offset.lsn = pg_lsn.into();

        let txid_row = trxn.query_one("SELECT txid_current()", &[]).await?;
        let txid: i64 = txid_row.get::<_, i64>(0);
        pg_offset.txid = txid;

        // commit the transaction
        trxn.commit().await?;

        Ok(CdcOffset::Postgres(pg_offset))
    }

    fn snapshot_read(
        &self,
        table_name: SchemaTableName,
        start_pk: Option<OwnedRow>,
        primary_keys: Vec<String>,
        limit: u32,
    ) -> BoxStream<'_, ConnectorResult<OwnedRow>> {
        assert_eq!(table_name, self.schema_table_name);
        self.snapshot_read_inner(table_name, start_pk, primary_keys, limit)
    }

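    // Splits the table snapshot for parallel backfill. If the split column is an
    // integer type and `backfill_as_even_splits` is set, evenly-sized value ranges
    // are cut between MIN and MAX; otherwise the data is walked to derive
    // row-count-based (uneven) split bounds.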
    #[try_stream(boxed, ok = CdcTableSnapshotSplit, error = ConnectorError)]
    async fn get_parallel_cdc_splits(&self, options: CdcTableSnapshotSplitOption) {
        let backfill_num_rows_per_split = options.backfill_num_rows_per_split;
        if backfill_num_rows_per_split == 0 {
            return Err(anyhow::anyhow!(
                "invalid backfill_num_rows_per_split, must be greater than 0"
            )
            .into());
        }
        if options.backfill_split_pk_column_index as usize >= self.pk_indices.len() {
            return Err(anyhow::anyhow!(format!(
                "invalid backfill_split_pk_column_index {}, out of bound",
                options.backfill_split_pk_column_index
            ))
            .into());
        }
        let split_column = self.split_column(&options);
        let row_stream = if options.backfill_as_even_splits
            && is_supported_even_split_data_type(&split_column.data_type)
        {
            // For certain types, use evenly-sized partitions to optimize performance.
            tracing::info!(?self.schema_table_name, ?self.rw_schema, ?self.pk_indices, ?split_column, "Get parallel cdc table snapshot even splits.");
            self.as_even_splits(options)
        } else {
            tracing::info!(?self.schema_table_name, ?self.rw_schema, ?self.pk_indices, ?split_column, "Get parallel cdc table snapshot uneven splits.");
            self.as_uneven_splits(options)
        };
        pin_mut!(row_stream);
        #[for_await]
        for row in row_stream {
            let row = row?;
            yield row;
        }
    }

    fn split_snapshot_read(
        &self,
        table_name: SchemaTableName,
        left: OwnedRow,
        right: OwnedRow,
        split_columns: Vec<Field>,
    ) -> BoxStream<'_, ConnectorResult<OwnedRow>> {
        assert_eq!(table_name, self.schema_table_name);
        self.split_snapshot_read_inner(table_name, left, right, split_columns)
    }
}

impl PostgresExternalTableReader {
    pub async fn new(
        config: ExternalTableConfig,
        rw_schema: Schema,
        pk_indices: Vec<usize>,
        schema_table_name: SchemaTableName,
    ) -> ConnectorResult<Self> {
        tracing::info!(
            ?rw_schema,
            ?pk_indices,
            "create postgres external table reader"
        );
        let client = create_pg_client(
            &config.username,
            &config.password,
            &config.host,
            &config.port,
            &config.database,
            &config.ssl_mode,
            &config.ssl_root_cert,
        )
        .await?;
        let field_names = rw_schema
            .fields
            .iter()
            .map(|f| Self::quote_column(&f.name))
            .join(",");

        Ok(Self {
            rw_schema,
            field_names,
            pk_indices,
            client: tokio::sync::Mutex::new(client),
            schema_table_name,
        })
    }

    pub fn get_normalized_table_name(table_name: &SchemaTableName) -> String {
        format!(
            "\"{}\".\"{}\"",
            table_name.schema_name, table_name.table_name
        )
    }

    pub fn get_cdc_offset_parser() -> CdcOffsetParseFunc {
        Box::new(move |offset| {
            Ok(CdcOffset::Postgres(PostgresOffset::parse_debezium_offset(
                offset,
            )?))
        })
    }

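    // For illustration: with fields `v1, v2` that are also the primary key and a
    // start key present, the prepared scan below takes the shape
    //   SELECT "v1","v2" FROM "public"."t1" WHERE ("v1", "v2") > ($1, $2)
    //   ORDER BY "v1","v2" LIMIT {scan_limit}
    // where `"public"."t1"` is a hypothetical table used only for the example.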
    #[try_stream(boxed, ok = OwnedRow, error = ConnectorError)]
    async fn snapshot_read_inner(
        &self,
        table_name: SchemaTableName,
        start_pk_row: Option<OwnedRow>,
        primary_keys: Vec<String>,
        scan_limit: u32,
    ) {
        let order_key = Self::get_order_key(&primary_keys);
        let client = self.client.lock().await;
        client.execute("set time zone '+00:00'", &[]).await?;

        let stream = match start_pk_row {
            Some(ref pk_row) => {
                // prepare the scan statement, since we may need to convert the RW data type to postgres data type
                // e.g. varchar to uuid
                let prepared_scan_stmt = {
                    let primary_keys = self
                        .pk_indices
                        .iter()
                        .map(|i| self.rw_schema.fields[*i].name.clone())
                        .collect_vec();

                    let order_key = Self::get_order_key(&primary_keys);
                    let scan_sql = format!(
                        "SELECT {} FROM {} WHERE {} ORDER BY {} LIMIT {scan_limit}",
                        self.field_names,
                        Self::get_normalized_table_name(&table_name),
                        Self::filter_expression(&primary_keys),
                        order_key,
                    );
                    client.prepare(&scan_sql).await?
                };

                let params: Vec<Option<ScalarAdapter>> = pk_row
                    .iter()
                    .zip_eq_fast(prepared_scan_stmt.params())
                    .map(|(datum, ty)| {
                        datum
                            .map(|scalar| ScalarAdapter::from_scalar(scalar, ty))
                            .transpose()
                    })
                    .try_collect()?;

                client.query_raw(&prepared_scan_stmt, &params).await?
            }
            None => {
                let sql = format!(
                    "SELECT {} FROM {} ORDER BY {} LIMIT {scan_limit}",
                    self.field_names,
                    Self::get_normalized_table_name(&table_name),
                    order_key,
                );
                let params: Vec<Option<ScalarAdapter>> = vec![];
                client.query_raw(&sql, &params).await?
            }
        };

        let row_stream = stream.map(|row| {
            let row = row?;
            Ok::<_, crate::error::ConnectorError>(postgres_row_to_owned_row(row, &self.rw_schema))
        });

        pin_mut!(row_stream);
        #[for_await]
        for row in row_stream {
            let row = row?;
            yield row;
        }
    }

    // row filter expression: (v1, v2, v3) > ($1, $2, $3)
    fn filter_expression(columns: &[String]) -> String {
        let mut col_expr = String::new();
        let mut arg_expr = String::new();
        for (i, column) in columns.iter().enumerate() {
            if i > 0 {
                col_expr.push_str(", ");
                arg_expr.push_str(", ");
            }
            col_expr.push_str(&Self::quote_column(column));
            arg_expr.push_str(format!("${}", i + 1).as_str());
        }
        format!("({}) > ({})", col_expr, arg_expr)
    }

    // row filter expression: (v1, v2, v3) >= ($1, $2, $3) AND (v1, v2, v3) < ($4, $5, $6)
    fn split_filter_expression(
        columns: &[String],
        is_first_split: bool,
        is_last_split: bool,
    ) -> String {
        let mut left_col_expr = String::new();
        let mut left_arg_expr = String::new();
        let mut right_col_expr = String::new();
        let mut right_arg_expr = String::new();
        let mut c = 1;
        if !is_first_split {
            for (i, column) in columns.iter().enumerate() {
                if i > 0 {
                    left_col_expr.push_str(", ");
                    left_arg_expr.push_str(", ");
                }
                left_col_expr.push_str(&Self::quote_column(column));
                left_arg_expr.push_str(format!("${}", c).as_str());
                c += 1;
            }
        }
        if !is_last_split {
            for (i, column) in columns.iter().enumerate() {
                if i > 0 {
                    right_col_expr.push_str(", ");
                    right_arg_expr.push_str(", ");
                }
                right_col_expr.push_str(&Self::quote_column(column));
                right_arg_expr.push_str(format!("${}", c).as_str());
                c += 1;
            }
        }
        if is_first_split && is_last_split {
            "1 = 1".to_owned()
        } else if is_first_split {
            format!("({}) < ({})", right_col_expr, right_arg_expr)
        } else if is_last_split {
            format!("({}) >= ({})", left_col_expr, left_arg_expr)
        } else {
            format!(
                "({}) >= ({}) AND ({}) < ({})",
                left_col_expr, left_arg_expr, right_col_expr, right_arg_expr,
            )
        }
    }

    fn get_order_key(primary_keys: &[String]) -> String {
        primary_keys
            .iter()
            .map(|col| Self::quote_column(col))
            .join(",")
    }

    fn quote_column(column: &str) -> String {
        format!("\"{}\"", column)
    }

    async fn min_and_max(
        &self,
        split_column: &Field,
    ) -> ConnectorResult<Option<(ScalarImpl, ScalarImpl)>> {
        let sql = format!(
            "SELECT MIN({}), MAX({}) FROM {}",
            Self::quote_column(&split_column.name),
            Self::quote_column(&split_column.name),
            Self::get_normalized_table_name(&self.schema_table_name),
        );
        let client = self.client.lock().await;
        let rows = client.query(&sql, &[]).await?;
        if rows.is_empty() {
            Ok(None)
        } else {
            let row = &rows[0];
            let min =
                postgres_cell_to_scalar_impl(row, &split_column.data_type, 0, &split_column.name);
            let max =
                postgres_cell_to_scalar_impl(row, &split_column.data_type, 1, &split_column.name);
            match (min, max) {
                (Some(min), Some(max)) => Ok(Some((min, max))),
                _ => Ok(None),
            }
        }
    }

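    // Finds the exclusive right bound of the next split: scan up to `max_split_size`
    // rows starting at `left_value` and take the largest split-column value seen,
    // returning NULL once that value reaches `max_value`. For a split column `id` on
    // a hypothetical table "public"."t1" with a split size of 1000, the query below
    // expands to:
    //   WITH t as (SELECT "id" FROM "public"."t1" WHERE "id" >= $1
    //              ORDER BY "id" ASC LIMIT 1000)
    //   SELECT CASE WHEN MAX("id") < $2 THEN MAX("id") ELSE NULL END FROM t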
    async fn next_split_right_bound_exclusive(
        &self,
        left_value: &ScalarImpl,
        max_value: &ScalarImpl,
        max_split_size: u64,
        split_column: &Field,
    ) -> ConnectorResult<Option<Datum>> {
        let sql = format!(
            "WITH t as (SELECT {} FROM {} WHERE {} >= $1 ORDER BY {} ASC LIMIT {}) SELECT CASE WHEN MAX({}) < $2 THEN MAX({}) ELSE NULL END FROM t",
            Self::quote_column(&split_column.name),
            Self::get_normalized_table_name(&self.schema_table_name),
            Self::quote_column(&split_column.name),
            Self::quote_column(&split_column.name),
            max_split_size,
            Self::quote_column(&split_column.name),
            Self::quote_column(&split_column.name),
        );
        let client = self.client.lock().await;
        let prepared_stmt = client.prepare(&sql).await?;
        let params: Vec<Option<ScalarAdapter>> = vec![
            Some(ScalarAdapter::from_scalar(
                left_value.as_scalar_ref_impl(),
                &prepared_stmt.params()[0],
            )?),
            Some(ScalarAdapter::from_scalar(
                max_value.as_scalar_ref_impl(),
                &prepared_stmt.params()[1],
            )?),
        ];
        let stream = client.query_raw(&prepared_stmt, &params).await?;
        let datum_stream = stream.map(|row| {
            let row = row?;
            Ok::<_, ConnectorError>(postgres_cell_to_scalar_impl(
                &row,
                &split_column.data_type,
                0,
                &split_column.name,
            ))
        });
        pin_mut!(datum_stream);
        #[for_await]
        for datum in datum_stream {
            let right = datum?;
            return Ok(Some(right.to_owned_datum()));
        }
        Ok(None)
    }

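    // Finds the smallest split-column value strictly greater than `start_offset` and
    // below `max_value`. Used by `as_uneven_splits` when a candidate right bound
    // equals the left bound, i.e. more than `max_split_size` rows share one value.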
    async fn next_greater_bound(
        &self,
        start_offset: &ScalarImpl,
        max_value: &ScalarImpl,
        split_column: &Field,
    ) -> ConnectorResult<Option<Datum>> {
        let sql = format!(
            "SELECT MIN({}) FROM {} WHERE {} > $1 AND {} < $2",
            Self::quote_column(&split_column.name),
            Self::get_normalized_table_name(&self.schema_table_name),
            Self::quote_column(&split_column.name),
            Self::quote_column(&split_column.name),
        );
        let client = self.client.lock().await;
        let prepared_stmt = client.prepare(&sql).await?;
        let params: Vec<Option<ScalarAdapter>> = vec![
            Some(ScalarAdapter::from_scalar(
                start_offset.as_scalar_ref_impl(),
                &prepared_stmt.params()[0],
            )?),
            Some(ScalarAdapter::from_scalar(
                max_value.as_scalar_ref_impl(),
                &prepared_stmt.params()[1],
            )?),
        ];
        let stream = client.query_raw(&prepared_stmt, &params).await?;
        let datum_stream = stream.map(|row| {
            let row = row?;
            Ok::<_, ConnectorError>(postgres_cell_to_scalar_impl(
                &row,
                &split_column.data_type,
                0,
                &split_column.name,
            ))
        });
        pin_mut!(datum_stream);
        #[for_await]
        for datum in datum_stream {
            let right = datum?;
            return Ok(Some(right));
        }
        Ok(None)
    }

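    // For illustration: a middle split (both bounds present) on split column `v1`
    // scans with
    //   SELECT ... FROM "public"."t1" WHERE ("v1") >= ($1) AND ("v1") < ($2)
    // while the first and last splits drop the missing side of the predicate
    // (`"public"."t1"` is a hypothetical table used only for the example).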
    #[try_stream(boxed, ok = OwnedRow, error = ConnectorError)]
    async fn split_snapshot_read_inner(
        &self,
        table_name: SchemaTableName,
        left: OwnedRow,
        right: OwnedRow,
        split_columns: Vec<Field>,
    ) {
        assert_eq!(
            split_columns.len(),
            1,
            "multiple split columns are not supported yet"
        );
        assert_eq!(left.len(), 1, "multiple split columns are not supported yet");
        assert_eq!(
            right.len(),
            1,
            "multiple split columns are not supported yet"
        );
        let is_first_split = left[0].is_none();
        let is_last_split = right[0].is_none();
        let split_column_names = split_columns.iter().map(|c| c.name.clone()).collect_vec();
        let client = self.client.lock().await;
        client.execute("set time zone '+00:00'", &[]).await?;
        // prepare the scan statement, since we may need to convert the RW data type to postgres data type
        // e.g. varchar to uuid
        let prepared_scan_stmt = {
            let scan_sql = format!(
                "SELECT {} FROM {} WHERE {}",
                self.field_names,
                Self::get_normalized_table_name(&table_name),
                Self::split_filter_expression(&split_column_names, is_first_split, is_last_split),
            );
            client.prepare(&scan_sql).await?
        };

        let mut params: Vec<Option<ScalarAdapter>> = vec![];
        if !is_first_split {
            let left_params: Vec<Option<ScalarAdapter>> = left
                .iter()
                .zip_eq_fast(prepared_scan_stmt.params().iter().take(left.len()))
                .map(|(datum, ty)| {
                    datum
                        .map(|scalar| ScalarAdapter::from_scalar(scalar, ty))
                        .transpose()
                })
                .try_collect()?;
            params.extend(left_params);
        }
        if !is_last_split {
            let right_params: Vec<Option<ScalarAdapter>> = right
                .iter()
                .zip_eq_fast(prepared_scan_stmt.params().iter().skip(params.len()))
                .map(|(datum, ty)| {
                    datum
                        .map(|scalar| ScalarAdapter::from_scalar(scalar, ty))
                        .transpose()
                })
                .try_collect()?;
            params.extend(right_params);
        }

        let stream = client.query_raw(&prepared_scan_stmt, &params).await?;
        let row_stream = stream.map(|row| {
            let row = row?;
            Ok::<_, crate::error::ConnectorError>(postgres_row_to_owned_row(row, &self.rw_schema))
        });

        pin_mut!(row_stream);
        #[for_await]
        for row in row_stream {
            let row = row?;
            yield row;
        }
    }

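    // Walks the table in `backfill_num_rows_per_split`-row chunks: each iteration
    // asks for the next right bound, and when none is found the final split is
    // emitted with an open right bound. The first split's left bound and the last
    // split's right bound are stored as NULL, meaning "unbounded".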
    #[try_stream(boxed, ok = CdcTableSnapshotSplit, error = ConnectorError)]
    async fn as_uneven_splits(&self, options: CdcTableSnapshotSplitOption) {
        let split_column = self.split_column(&options);
        let mut split_id = CDC_TABLE_SPLIT_ID_START;
        let Some((min_value, max_value)) = self.min_and_max(&split_column).await? else {
            let left_bound_row = OwnedRow::new(vec![None]);
            let right_bound_row = OwnedRow::new(vec![None]);
            let split = CdcTableSnapshotSplit {
                split_id,
                left_bound_inclusive: left_bound_row,
                right_bound_exclusive: right_bound_row,
            };
            yield split;
            return Ok(());
        };
        // The left bound is never a NULL value.
        let mut next_left_bound_inclusive = min_value.clone();
        loop {
            let left_bound_inclusive: Datum = if next_left_bound_inclusive == min_value {
                None
            } else {
                Some(next_left_bound_inclusive.clone())
            };
            let right_bound_exclusive;
            let mut next_right = self
                .next_split_right_bound_exclusive(
                    &next_left_bound_inclusive,
                    &max_value,
                    options.backfill_num_rows_per_split,
                    &split_column,
                )
                .await?;
            if let Some(Some(ref inner)) = next_right
                && *inner == next_left_bound_inclusive
            {
                next_right = self
                    .next_greater_bound(&next_left_bound_inclusive, &max_value, &split_column)
                    .await?;
            }
            if let Some(next_right) = next_right {
                match next_right {
                    None => {
                        // NULL found.
                        right_bound_exclusive = None;
                    }
                    Some(next_right) => {
                        next_left_bound_inclusive = next_right.to_owned();
                        right_bound_exclusive = Some(next_right);
                    }
                }
            } else {
                // Not found.
                right_bound_exclusive = None;
            };
            let is_completed = right_bound_exclusive.is_none();
            if is_completed && left_bound_inclusive.is_none() {
                assert_eq!(split_id, 1);
            }
            tracing::info!(
                split_id,
                ?left_bound_inclusive,
                ?right_bound_exclusive,
                "New CDC table snapshot split."
            );
            let left_bound_row = OwnedRow::new(vec![left_bound_inclusive]);
            let right_bound_row = OwnedRow::new(vec![right_bound_exclusive]);
            let split = CdcTableSnapshotSplit {
                split_id,
                left_bound_inclusive: left_bound_row,
                right_bound_exclusive: right_bound_row,
            };
            try_increase_split_id(&mut split_id)?;
            yield split;
            if is_completed {
                break;
            }
        }
    }

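    // Cuts evenly-sized value ranges of width `backfill_num_rows_per_split` between
    // the MIN and MAX of an integral split column. For example, with min = 0,
    // max = 10, and a split size of 4, the emitted bounds are
    // [NULL, 4), [4, 8), [8, NULL).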
    #[try_stream(boxed, ok = CdcTableSnapshotSplit, error = ConnectorError)]
    async fn as_even_splits(&self, options: CdcTableSnapshotSplitOption) {
        let split_column = self.split_column(&options);
        let mut split_id = CDC_TABLE_SPLIT_ID_START;
        let Some((min_value, max_value)) = self.min_and_max(&split_column).await? else {
            let left_bound_row = OwnedRow::new(vec![None]);
            let right_bound_row = OwnedRow::new(vec![None]);
            let split = CdcTableSnapshotSplit {
                split_id,
                left_bound_inclusive: left_bound_row,
                right_bound_exclusive: right_bound_row,
            };
            yield split;
            return Ok(());
        };
        let min_value = min_value.as_integral();
        let max_value = max_value.as_integral();
        let saturated_split_max_size = options
            .backfill_num_rows_per_split
            .try_into()
            .unwrap_or(i64::MAX);
        let mut left = None;
        let mut right = Some(min_value.saturating_add(saturated_split_max_size));
        loop {
            let mut is_completed = false;
            if right.as_ref().map(|r| *r >= max_value).unwrap_or(true) {
                right = None;
                is_completed = true;
            }
            let split = CdcTableSnapshotSplit {
                split_id,
                left_bound_inclusive: OwnedRow::new(vec![
                    left.map(|l| to_int_scalar(l, &split_column.data_type)),
                ]),
                right_bound_exclusive: OwnedRow::new(vec![
                    right.map(|r| to_int_scalar(r, &split_column.data_type)),
                ]),
            };
            try_increase_split_id(&mut split_id)?;
            yield split;
            if is_completed {
                break;
            }
            left = right;
            right = left.map(|l| l.saturating_add(saturated_split_max_size));
        }
    }

    fn split_column(&self, options: &CdcTableSnapshotSplitOption) -> Field {
        self.rw_schema.fields[self.pk_indices[options.backfill_split_pk_column_index as usize]]
            .clone()
    }
}

fn to_int_scalar(i: i64, data_type: &DataType) -> ScalarImpl {
    match data_type {
        DataType::Int16 => ScalarImpl::Int16(i.try_into().unwrap()),
        DataType::Int32 => ScalarImpl::Int32(i.try_into().unwrap()),
        DataType::Int64 => ScalarImpl::Int64(i),
        _ => {
            panic!("Can't convert int {} to ScalarImpl::{}", i, data_type)
        }
    }
}

fn try_increase_split_id(split_id: &mut i64) -> ConnectorResult<()> {
    match split_id.checked_add(1) {
        Some(s) => {
            *split_id = s;
            Ok(())
        }
        None => Err(anyhow::anyhow!("too many CDC snapshot splits").into()),
    }
}

/// Whether the chosen split column's data type supports evenly-sized splits.
fn is_supported_even_split_data_type(data_type: &DataType) -> bool {
    matches!(
        data_type,
        DataType::Int16 | DataType::Int32 | DataType::Int64
    )
}

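// Maps a Postgres type name (typically `pg_type.typname`) to a `tokio_postgres`
// `Type`. Array types are spelled with a leading underscore in the catalog, e.g.
// `_int4` for `int4[]`.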
pub fn type_name_to_pg_type(ty_name: &str) -> Option<PgType> {
    let ty_name_lower = ty_name.to_lowercase();
    // Handle array types (prefixed with _)
    if let Some(base_type) = ty_name_lower.strip_prefix('_') {
        match base_type {
            "int2" => Some(PgType::INT2_ARRAY),
            "int4" => Some(PgType::INT4_ARRAY),
            "int8" => Some(PgType::INT8_ARRAY),
            "bit" => Some(PgType::BIT_ARRAY),
            "float4" => Some(PgType::FLOAT4_ARRAY),
            "float8" => Some(PgType::FLOAT8_ARRAY),
            "numeric" => Some(PgType::NUMERIC_ARRAY),
            "bool" => Some(PgType::BOOL_ARRAY),
            "xml" | "macaddr" | "macaddr8" | "cidr" | "inet" | "int4range" | "int8range"
            | "numrange" | "tsrange" | "tstzrange" | "daterange" | "citext" => {
                Some(PgType::VARCHAR_ARRAY)
            }
            "varchar" => Some(PgType::VARCHAR_ARRAY),
            "text" => Some(PgType::TEXT_ARRAY),
            "bytea" => Some(PgType::BYTEA_ARRAY),
            "date" => Some(PgType::DATE_ARRAY),
            "time" => Some(PgType::TIME_ARRAY),
            "timetz" => Some(PgType::TIMETZ_ARRAY),
            "timestamp" => Some(PgType::TIMESTAMP_ARRAY),
            "timestamptz" => Some(PgType::TIMESTAMPTZ_ARRAY),
            "interval" => Some(PgType::INTERVAL_ARRAY),
            "json" => Some(PgType::JSON_ARRAY),
            "jsonb" => Some(PgType::JSONB_ARRAY),
            "uuid" => Some(PgType::UUID_ARRAY),
            "point" => Some(PgType::POINT_ARRAY),
            "oid" => Some(PgType::OID_ARRAY),
            "money" => Some(PgType::MONEY_ARRAY),
            _ => None,
        }
    } else {
        // Handle non-array types
        match ty_name_lower.as_str() {
            "int2" => Some(PgType::INT2),
            "bit" => Some(PgType::BIT),
            "int" | "int4" => Some(PgType::INT4),
            "int8" => Some(PgType::INT8),
            "float4" => Some(PgType::FLOAT4),
            "float8" => Some(PgType::FLOAT8),
            "numeric" => Some(PgType::NUMERIC),
            "money" => Some(PgType::MONEY),
            "boolean" | "bool" => Some(PgType::BOOL),
            "inet" | "xml" | "varchar" | "character varying" | "int4range" | "int8range"
            | "numrange" | "tsrange" | "tstzrange" | "daterange" | "macaddr" | "macaddr8"
            | "cidr" => Some(PgType::VARCHAR),
            "char" | "character" | "bpchar" => Some(PgType::BPCHAR),
            "citext" | "text" => Some(PgType::TEXT),
            "bytea" => Some(PgType::BYTEA),
            "date" => Some(PgType::DATE),
            "time" => Some(PgType::TIME),
            "timetz" => Some(PgType::TIMETZ),
            "timestamp" => Some(PgType::TIMESTAMP),
            "timestamptz" => Some(PgType::TIMESTAMPTZ),
            "interval" => Some(PgType::INTERVAL),
            "json" => Some(PgType::JSON),
            "jsonb" => Some(PgType::JSONB),
            "uuid" => Some(PgType::UUID),
            "point" => Some(PgType::POINT),
            "oid" => Some(PgType::OID),
            _ => None,
        }
    }
}

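// Maps a Postgres wire type to the corresponding RisingWave `DataType`. Some
// mappings are lossy by design, e.g. TIMETZ -> Time, UUID -> Varchar, and
// MONEY -> Decimal; see the match arms below.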
pub fn pg_type_to_rw_type(pg_type: &PgType) -> ConnectorResult<DataType> {
    let data_type = match *pg_type {
        PgType::BOOL => DataType::Boolean,
        PgType::BIT => DataType::Boolean,
        PgType::INT2 => DataType::Int16,
        PgType::INT4 => DataType::Int32,
        PgType::INT8 => DataType::Int64,
        PgType::FLOAT4 => DataType::Float32,
        PgType::FLOAT8 => DataType::Float64,
        PgType::NUMERIC | PgType::MONEY => DataType::Decimal,
        PgType::DATE => DataType::Date,
        PgType::TIME => DataType::Time,
        PgType::TIMETZ => DataType::Time,
        PgType::POINT => DataType::Struct(risingwave_common::types::StructType::new(vec![
            ("x", DataType::Float32),
            ("y", DataType::Float32),
        ])),
        PgType::TIMESTAMP => DataType::Timestamp,
        PgType::TIMESTAMPTZ => DataType::Timestamptz,
        PgType::INTERVAL => DataType::Interval,
        PgType::VARCHAR | PgType::TEXT | PgType::BPCHAR | PgType::UUID => DataType::Varchar,
        PgType::BYTEA => DataType::Bytea,
        PgType::JSON | PgType::JSONB => DataType::Jsonb,
        // Array types
        PgType::BOOL_ARRAY => DataType::List(Box::new(DataType::Boolean)),
        PgType::BIT_ARRAY => DataType::List(Box::new(DataType::Boolean)),
        PgType::INT2_ARRAY => DataType::List(Box::new(DataType::Int16)),
        PgType::INT4_ARRAY => DataType::List(Box::new(DataType::Int32)),
        PgType::INT8_ARRAY => DataType::List(Box::new(DataType::Int64)),
        PgType::FLOAT4_ARRAY => DataType::List(Box::new(DataType::Float32)),
        PgType::FLOAT8_ARRAY => DataType::List(Box::new(DataType::Float64)),
        PgType::NUMERIC_ARRAY => DataType::List(Box::new(DataType::Decimal)),
        PgType::VARCHAR_ARRAY => DataType::List(Box::new(DataType::Varchar)),
        PgType::TEXT_ARRAY => DataType::List(Box::new(DataType::Varchar)),
        PgType::BYTEA_ARRAY => DataType::List(Box::new(DataType::Bytea)),
        PgType::DATE_ARRAY => DataType::List(Box::new(DataType::Date)),
        PgType::TIME_ARRAY => DataType::List(Box::new(DataType::Time)),
        PgType::TIMESTAMP_ARRAY => DataType::List(Box::new(DataType::Timestamp)),
        PgType::TIMESTAMPTZ_ARRAY => DataType::List(Box::new(DataType::Timestamptz)),
        PgType::INTERVAL_ARRAY => DataType::List(Box::new(DataType::Interval)),
        PgType::JSON_ARRAY => DataType::List(Box::new(DataType::Jsonb)),
        PgType::JSONB_ARRAY => DataType::List(Box::new(DataType::Jsonb)),
        PgType::UUID_ARRAY => DataType::List(Box::new(DataType::Varchar)),
        PgType::OID => DataType::Int64,
        PgType::OID_ARRAY => DataType::List(Box::new(DataType::Int64)),
        PgType::MONEY_ARRAY => DataType::List(Box::new(DataType::Decimal)),
        PgType::POINT_ARRAY => DataType::List(Box::new(DataType::Struct(
            risingwave_common::types::StructType::new(vec![
                ("x", DataType::Float32),
                ("y", DataType::Float32),
            ]),
        ))),
        _ => {
            return Err(anyhow::anyhow!("unsupported postgres type: {}", pg_type).into());
        }
    };
    Ok(data_type)
}

#[cfg(test)]
mod tests {
    use std::cmp::Ordering;
    use std::collections::HashMap;

    use futures::pin_mut;
    use futures_async_stream::for_await;
    use maplit::{convert_args, hashmap};
    use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema};
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::{DataType, ScalarImpl};

    use crate::connector_common::PostgresExternalTable;
    use crate::source::cdc::external::postgres::{PostgresExternalTableReader, PostgresOffset};
    use crate::source::cdc::external::{ExternalTableConfig, ExternalTableReader, SchemaTableName};

    #[ignore]
    #[tokio::test]
    async fn test_postgres_schema() {
        let config = ExternalTableConfig {
            connector: "postgres-cdc".to_owned(),
            host: "localhost".to_owned(),
            port: "8432".to_owned(),
            username: "myuser".to_owned(),
            password: "123456".to_owned(),
            database: "mydb".to_owned(),
            schema: "public".to_owned(),
            table: "mytest".to_owned(),
            ssl_mode: Default::default(),
            ssl_root_cert: None,
            encrypt: "false".to_owned(),
        };

        let table = PostgresExternalTable::connect(
            &config.username,
            &config.password,
            &config.host,
            config.port.parse::<u16>().unwrap(),
            &config.database,
            &config.schema,
            &config.table,
            &config.ssl_mode,
            &config.ssl_root_cert,
            false,
        )
        .await
        .unwrap();

        println!("columns: {:?}", &table.column_descs());
        println!("primary keys: {:?}", &table.pk_names());
    }

    #[test]
    fn test_postgres_offset() {
        let off1 = PostgresOffset {
            txid: 4,
            lsn: 2,
            ..Default::default()
        };
        let off2 = PostgresOffset {
            txid: 1,
            lsn: 3,
            ..Default::default()
        };
        let off3 = PostgresOffset {
            txid: 5,
            lsn: 1,
            ..Default::default()
        };

        assert!(off1 < off2);
        assert!(off3 < off1);
        assert!(off2 > off3);
    }

    #[test]
    fn test_postgres_offset_partial_ord_with_lsn_commit() {
        // Test comparison with both lsn_commit and lsn_proc fields
        let off1 = PostgresOffset {
            txid: 1,
            lsn: 100,
            lsn_commit: Some(200),
            lsn_proc: Some(150),
        };
        let off2 = PostgresOffset {
            txid: 2,
            lsn: 300,
            lsn_commit: Some(250),
            lsn_proc: Some(200),
        };

        // Should compare using lsn_commit first when both have both fields
        assert!(off1 < off2);

        // Test with same lsn_commit but different lsn_proc
        let off3 = PostgresOffset {
            txid: 3,
            lsn: 500,
            lsn_commit: Some(200), // same as off1
            lsn_proc: Some(160),   // higher than off1
        };

        // Should compare lsn_proc when lsn_commit is equal
        assert!(off1 < off3);

        // Test with missing lsn_proc - should fall back to lsn comparison
        let off4 = PostgresOffset {
            txid: 4,
            lsn: 400,
            lsn_commit: Some(100), // lower than off1's lsn_commit
            lsn_proc: None,        // missing lsn_proc
        };

        // Should fall back to lsn comparison (off1.lsn=100 < off4.lsn=400)
        assert!(off1 < off4);

        // Test with missing lsn_commit: both sides still carry lsn_proc, so the
        // Option-typed lsn_commit values are compared directly (None sorts before Some)
        let off5 = PostgresOffset {
            txid: 5,
            lsn: 50,             // lower than off1.lsn
            lsn_commit: None,    // missing lsn_commit
            lsn_proc: Some(300), // higher than off1's lsn_proc
        };

        // off5.lsn_commit=None < off1.lsn_commit=Some(200), so off5 < off1
        assert!(off5 < off1);

        // Additional test cases: equal lsn_commit values with different lsn_proc
        let off6 = PostgresOffset {
            txid: 6,
            lsn: 600,
            lsn_commit: Some(500),
            lsn_proc: Some(300),
        };
        let off7 = PostgresOffset {
            txid: 7,
            lsn: 700,
            lsn_commit: Some(500), // same as off6
            lsn_proc: Some(400),   // higher than off6
        };

        // Should compare lsn_proc since lsn_commit is equal
        assert!(off6 < off7);

        // Test reverse order
        let off8 = PostgresOffset {
            txid: 8,
            lsn: 800,
            lsn_commit: Some(500), // same as others
            lsn_proc: Some(200),   // lower than off6
        };

        assert!(off8 < off6);
        assert!(off8 < off7);

        // Test equal lsn_commit and lsn_proc
        let off9 = PostgresOffset {
            txid: 9,
            lsn: 900,
            lsn_commit: Some(500), // same as off6
            lsn_proc: Some(300),   // same as off6
        };

        // Should be equal
        assert_eq!(off6.partial_cmp(&off9), Some(Ordering::Equal));
    }

    #[test]
    fn test_debezium_offset_parsing() {
        // Test parsing with all required fields present
        let debezium_offset_with_fields = r#"{
            "sourcePartition": {"server": "RW_CDC_1004"},
            "sourceOffset": {
                "last_snapshot_record": false,
                "lsn": 29973552,
                "txId": 1046,
                "ts_usec": 1670826189008456,
                "snapshot": true,
                "lsn_commit": 29973600,
                "lsn_proc": 29973580
            },
            "isHeartbeat": false
        }"#;

        let offset = PostgresOffset::parse_debezium_offset(debezium_offset_with_fields).unwrap();
        assert_eq!(offset.txid, 1046);
        assert_eq!(offset.lsn, 29973552);
        assert_eq!(offset.lsn_commit, Some(29973600));
        assert_eq!(offset.lsn_proc, Some(29973580));

        // Test parsing should fail when required fields are missing
        let debezium_offset_missing_fields = r#"{
            "sourcePartition": {"server": "RW_CDC_1004"},
            "sourceOffset": {
                "last_snapshot_record": false,
                "lsn": 29973552,
                "txId": 1046,
                "ts_usec": 1670826189008456,
                "snapshot": true
            },
            "isHeartbeat": false
        }"#;

        let result = PostgresOffset::parse_debezium_offset(debezium_offset_missing_fields);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("invalid postgres lsn_proc"));
    }

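    // A minimal sketch of the LSN representation documented on `PostgresOffset::lsn`:
    // the textual form "16/B374D848" corresponds to (0x16 << 32) | 0xB374D848.
    #[test]
    fn test_pg_lsn_u64_roundtrip() {
        use std::str::FromStr;

        use tokio_postgres::types::PgLsn;

        let lsn = PgLsn::from_str("16/B374D848").unwrap();
        let lsn_u64: u64 = lsn.into();
        assert_eq!(lsn_u64, (0x16u64 << 32) | 0xB374D848);
    }
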
    #[test]
    fn test_filter_expression() {
        let cols = vec!["v1".to_owned()];
        let expr = PostgresExternalTableReader::filter_expression(&cols);
        assert_eq!(expr, "(\"v1\") > ($1)");

        let cols = vec!["v1".to_owned(), "v2".to_owned()];
        let expr = PostgresExternalTableReader::filter_expression(&cols);
        assert_eq!(expr, "(\"v1\", \"v2\") > ($1, $2)");

        let cols = vec!["v1".to_owned(), "v2".to_owned(), "v3".to_owned()];
        let expr = PostgresExternalTableReader::filter_expression(&cols);
        assert_eq!(expr, "(\"v1\", \"v2\", \"v3\") > ($1, $2, $3)");
    }

    #[test]
    fn test_split_filter_expression() {
        let cols = vec!["v1".to_owned()];
        let expr = PostgresExternalTableReader::split_filter_expression(&cols, true, true);
        assert_eq!(expr, "1 = 1");

        let expr = PostgresExternalTableReader::split_filter_expression(&cols, true, false);
        assert_eq!(expr, "(\"v1\") < ($1)");

        let expr = PostgresExternalTableReader::split_filter_expression(&cols, false, true);
        assert_eq!(expr, "(\"v1\") >= ($1)");

        let expr = PostgresExternalTableReader::split_filter_expression(&cols, false, false);
        assert_eq!(expr, "(\"v1\") >= ($1) AND (\"v1\") < ($2)");
    }

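    // A small sketch exercising the type mappings defined above; the names and the
    // expected results are taken directly from `type_name_to_pg_type` and
    // `pg_type_to_rw_type`.
    #[test]
    fn test_type_name_mapping() {
        use tokio_postgres::types::Type as PgType;

        use crate::source::cdc::external::postgres::{pg_type_to_rw_type, type_name_to_pg_type};

        assert_eq!(type_name_to_pg_type("int4"), Some(PgType::INT4));
        assert_eq!(type_name_to_pg_type("_int4"), Some(PgType::INT4_ARRAY));
        assert_eq!(
            type_name_to_pg_type("character varying"),
            Some(PgType::VARCHAR)
        );
        assert_eq!(type_name_to_pg_type("no_such_type"), None);

        assert_eq!(pg_type_to_rw_type(&PgType::INT4).unwrap(), DataType::Int32);
        assert_eq!(pg_type_to_rw_type(&PgType::JSONB).unwrap(), DataType::Jsonb);
    }
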
    // manual test
    #[ignore]
    #[tokio::test]
    async fn test_pg_table_reader() {
        let columns = vec![
            ColumnDesc::named("v1", ColumnId::new(1), DataType::Int32),
            ColumnDesc::named("v2", ColumnId::new(2), DataType::Varchar),
            ColumnDesc::named("v3", ColumnId::new(3), DataType::Decimal),
            ColumnDesc::named("v4", ColumnId::new(4), DataType::Date),
        ];
        let rw_schema = Schema {
            fields: columns.iter().map(Field::from).collect(),
        };

        let props: HashMap<String, String> = convert_args!(hashmap!(
                "hostname" => "localhost",
                "port" => "8432",
                "username" => "myuser",
                "password" => "123456",
                "database.name" => "mydb",
                "schema.name" => "public",
                "table.name" => "t1"));

        let config =
            serde_json::from_value::<ExternalTableConfig>(serde_json::to_value(props).unwrap())
                .unwrap();
        let schema_table_name = SchemaTableName {
            schema_name: "public".to_owned(),
            table_name: "t1".to_owned(),
        };
        let reader = PostgresExternalTableReader::new(
            config,
            rw_schema,
            vec![0, 1],
            schema_table_name.clone(),
        )
        .await
        .unwrap();

        let offset = reader.current_cdc_offset().await.unwrap();
        println!("CdcOffset: {:?}", offset);

        let start_pk = OwnedRow::new(vec![Some(ScalarImpl::from(3)), Some(ScalarImpl::from("c"))]);
        let stream = reader.snapshot_read(
            schema_table_name,
            Some(start_pk),
            vec!["v1".to_owned(), "v2".to_owned()],
            1000,
        );

        pin_mut!(stream);
        #[for_await]
        for row in stream {
            println!("OwnedRow: {:?}", row);
        }
    }
}