risingwave_connector/source/cdc/external/
postgres.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::Ordering;
16use std::sync::LazyLock;
17
18use anyhow::Context;
19use futures::stream::BoxStream;
20use futures::{StreamExt, pin_mut};
21use futures_async_stream::{for_await, try_stream};
22use itertools::Itertools;
23use risingwave_common::catalog::{Field, Schema};
24use risingwave_common::log::LogSuppressor;
25use risingwave_common::row::{OwnedRow, Row};
26use risingwave_common::types::{DataType, Datum, ScalarImpl, ToOwnedDatum};
27use risingwave_common::util::iter_util::ZipEqFast;
28use serde::{Deserialize, Serialize};
29use thiserror_ext::AsReport;
30use tokio_postgres::types::{PgLsn, Type as PgType};
31
32use crate::connector_common::create_pg_client;
33use crate::error::{ConnectorError, ConnectorResult};
34use crate::parser::scalar_adapter::ScalarAdapter;
35use crate::parser::{postgres_cell_to_scalar_impl, postgres_row_to_owned_row};
36use crate::source::CdcTableSnapshotSplit;
37use crate::source::cdc::external::{
38    CDC_TABLE_SPLIT_ID_START, CdcOffset, CdcOffsetParseFunc, CdcTableSnapshotSplitOption,
39    DebeziumOffset, ExternalTableConfig, ExternalTableReader, SchemaTableName,
40};
41
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PostgresOffset {
    // Upstream Postgres transaction id (from the Debezium offset's `txid`).
    pub txid: i64,
    // In postgres, an LSN is a 64-bit integer, representing a byte position in the write-ahead log stream.
    // It is printed as two hexadecimal numbers of up to 8 digits each, separated by a slash; for example, 16/B374D848
    pub lsn: u64,
    // Additional LSN fields for improved tracking.
    // `#[serde(default)]` keeps deserialization compatible with offsets that were
    // persisted before these fields existed (they deserialize as `None`).
    // NOTE(review): per `parse_debezium_offset`, `lsn_commit` may be absent for
    // the first transaction, hence `Option`.
    #[serde(default)]
    pub lsn_commit: Option<u64>,
    #[serde(default)]
    pub lsn_proc: Option<u64>,
}
54
impl PartialOrd for PostgresOffset {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Delegate to the total order defined by `Ord` below.
        Some(self.cmp(other))
    }
}
60
61impl Eq for PostgresOffset {}
62impl PartialEq for PostgresOffset {
63    fn eq(&self, other: &Self) -> bool {
64        match (
65            self.lsn_commit,
66            self.lsn_proc,
67            other.lsn_commit,
68            other.lsn_proc,
69        ) {
70            (_, Some(_), _, Some(_)) => {
71                self.lsn_commit == other.lsn_commit && self.lsn_proc == other.lsn_proc
72            }
73            _ => self.lsn == other.lsn,
74        }
75    }
76}
77
// only compare the lsn field, prefer lsn_commit and lsn_proc if both available
impl Ord for PostgresOffset {
    fn cmp(&self, other: &Self) -> Ordering {
        match (
            self.lsn_commit,
            self.lsn_proc,
            other.lsn_commit,
            other.lsn_proc,
        ) {
            // Both sides carry `lsn_proc`. The pattern deliberately ignores
            // whether `lsn_commit` is present: `Option`'s derived ordering
            // treats `None` as smaller than any `Some`, which gives the
            // fallback described below.
            (_, Some(self_proc), _, Some(other_proc)) => {
                // if both have `lsn_commit` and `lsn_proc`, compare `lsn_commit` first, then `lsn_proc`
                // if `lsn_commit` is None, fall back to `lsn_proc`
                match self.lsn_commit.cmp(&other.lsn_commit) {
                    Ordering::Equal => self_proc.cmp(&other_proc),
                    other_result => other_result,
                }
            }
            _ => {
                // Fall back to lsn comparison when either lsn_commit or lsn_proc is missing
                // Rate-limited warning so a flood of comparisons on legacy
                // offsets does not spam the log.
                static LOG_SUPPRESSOR: LazyLock<LogSuppressor> =
                    LazyLock::new(LogSuppressor::default);
                if let Ok(suppressed_count) = LOG_SUPPRESSOR.check() {
                    tracing::warn!(
                        suppressed_count,
                        self_lsn = self.lsn,
                        other_lsn = other.lsn,
                        "lsn_commit and lsn_proc are missing, fall back to lsn comparison"
                    );
                }
                self.lsn.cmp(&other.lsn)
            }
        }
    }
}
112
113impl PostgresOffset {
114    pub fn parse_debezium_offset(offset: &str) -> ConnectorResult<Self> {
115        let dbz_offset: DebeziumOffset = serde_json::from_str(offset)
116            .with_context(|| format!("invalid upstream offset: {}", offset))?;
117
118        let lsn = dbz_offset
119            .source_offset
120            .lsn
121            .context("invalid postgres lsn")?;
122
123        // `lsn_commit` may not be present in the offset for the first tx.
124        let lsn_commit = dbz_offset.source_offset.lsn_commit;
125
126        let lsn_proc = dbz_offset
127            .source_offset
128            .lsn_proc
129            .context("invalid postgres lsn_proc")?;
130
131        Ok(Self {
132            txid: dbz_offset
133                .source_offset
134                .txid
135                .context("invalid postgres txid")?,
136            lsn,
137            lsn_commit,
138            lsn_proc: Some(lsn_proc),
139        })
140    }
141}
142
/// Snapshot reader for one upstream Postgres table, used by CDC backfill.
pub struct PostgresExternalTableReader {
    rw_schema: Schema,
    // Precomputed, comma-separated, quoted SELECT column list; composite-typed
    // columns get a `::text` cast (see `new`).
    field_names: String,
    // Indices of the primary-key columns within `rw_schema.fields`.
    pk_indices: Vec<usize>,
    // Async mutex: one query at a time on the single underlying connection.
    client: tokio::sync::Mutex<tokio_postgres::Client>,
    // The one table this reader serves; public methods assert against it.
    schema_table_name: SchemaTableName,
}
150
impl ExternalTableReader for PostgresExternalTableReader {
    /// Reads the current WAL position (`pg_current_wal_lsn()`) and transaction
    /// id (`txid_current()`) in a single transaction and returns them as a
    /// `CdcOffset::Postgres`. `lsn_commit`/`lsn_proc` are left unset (default).
    async fn current_cdc_offset(&self) -> ConnectorResult<CdcOffset> {
        let mut client = self.client.lock().await;
        // start a transaction to read current lsn and txid
        let trxn = client.transaction().await?;
        let row = trxn.query_one("SELECT pg_current_wal_lsn()", &[]).await?;
        let mut pg_offset = PostgresOffset::default();
        let pg_lsn = row.get::<_, PgLsn>(0);
        tracing::debug!("current lsn: {}", pg_lsn);
        pg_offset.lsn = pg_lsn.into();

        let txid_row = trxn.query_one("SELECT txid_current()", &[]).await?;
        let txid: i64 = txid_row.get::<_, i64>(0);
        pg_offset.txid = txid;

        // commit the transaction
        trxn.commit().await?;

        Ok(CdcOffset::Postgres(pg_offset))
    }

    /// Streams snapshot rows in primary-key order, starting strictly after
    /// `start_pk` (when given, see `filter_expression`) and limited to
    /// `limit` rows.
    ///
    /// Panics if `table_name` differs from the table this reader was built for.
    fn snapshot_read(
        &self,
        table_name: SchemaTableName,
        start_pk: Option<OwnedRow>,
        primary_keys: Vec<String>,
        limit: u32,
    ) -> BoxStream<'_, ConnectorResult<OwnedRow>> {
        assert_eq!(table_name, self.schema_table_name);
        self.snapshot_read_inner(table_name, start_pk, primary_keys, limit)
    }

    /// Yields parallel snapshot splits for CDC backfill.
    ///
    /// Validates the split options, then delegates to evenly-sized key ranges
    /// (integer split column + `backfill_as_even_splits`) or probed, uneven,
    /// row-count-based splits.
    #[try_stream(boxed, ok = CdcTableSnapshotSplit, error = ConnectorError)]
    async fn get_parallel_cdc_splits(&self, options: CdcTableSnapshotSplitOption) {
        let backfill_num_rows_per_split = options.backfill_num_rows_per_split;
        if backfill_num_rows_per_split == 0 {
            return Err(anyhow::anyhow!(
                "invalid backfill_num_rows_per_split, must be greater than 0"
            )
            .into());
        }
        // The split column index addresses `pk_indices`, not the full schema.
        if options.backfill_split_pk_column_index as usize >= self.pk_indices.len() {
            return Err(anyhow::anyhow!(format!(
                "invalid backfill_split_pk_column_index {}, out of bound",
                options.backfill_split_pk_column_index
            ))
            .into());
        }
        let split_column = self.split_column(&options);
        let row_stream = if options.backfill_as_even_splits
            && is_supported_even_split_data_type(&split_column.data_type)
        {
            // For certain types, use evenly-sized partition to optimize performance.
            tracing::info!(?self.schema_table_name, ?self.rw_schema, ?self.pk_indices, ?split_column, "Get parallel cdc table snapshot even splits.");
            self.as_even_splits(options)
        } else {
            tracing::info!(?self.schema_table_name, ?self.rw_schema, ?self.pk_indices, ?split_column, "Get parallel cdc table snapshot uneven splits.");
            self.as_uneven_splits(options)
        };
        pin_mut!(row_stream);
        #[for_await]
        for row in row_stream {
            let row = row?;
            yield row;
        }
    }

    /// Streams all rows of a single snapshot split: `left` is the inclusive
    /// lower bound, `right` the exclusive upper bound; a NULL bound marks an
    /// unbounded (first/last) split.
    ///
    /// Panics if `table_name` differs from the table this reader was built for.
    fn split_snapshot_read(
        &self,
        table_name: SchemaTableName,
        left: OwnedRow,
        right: OwnedRow,
        split_columns: Vec<Field>,
    ) -> BoxStream<'_, ConnectorResult<OwnedRow>> {
        assert_eq!(table_name, self.schema_table_name);
        self.split_snapshot_read_inner(table_name, left, right, split_columns)
    }
}
229
230impl PostgresExternalTableReader {
    /// Creates a reader: connects to the upstream Postgres, discovers
    /// composite-typed columns, and precomputes the quoted SELECT column list
    /// used by every snapshot query.
    pub async fn new(
        config: ExternalTableConfig,
        rw_schema: Schema,
        pk_indices: Vec<usize>,
        schema_table_name: SchemaTableName,
    ) -> ConnectorResult<Self> {
        tracing::info!(
            ?rw_schema,
            ?pk_indices,
            "create postgres external table reader"
        );
        let client = create_pg_client(
            &config.username,
            &config.password,
            &config.host,
            &config.port,
            &config.database,
            &config.ssl_mode,
            &config.ssl_root_cert,
            None, // No TCP keepalive for CDC source
        )
        .await?;

        // Discover user-defined composite columns. tokio-postgres cannot decode
        // composite values natively, so for these columns we cast to text in the
        // snapshot SELECT to get the `(a,b,c)` textual representation. Other
        // varchar columns stay as-is to preserve RW's own text rendering
        // (e.g. numeric -> "NaN"/"POSITIVE_INFINITY").
        // `typtype = 'c'` selects composite types in pg_type.
        let composite_columns = client
            .query(
                "SELECT a.attname \
                 FROM pg_attribute a \
                 JOIN pg_class c ON a.attrelid = c.oid \
                 JOIN pg_namespace n ON c.relnamespace = n.oid \
                 JOIN pg_type t ON a.atttypid = t.oid \
                 WHERE n.nspname = $1 \
                   AND c.relname = $2 \
                   AND a.attnum > 0 \
                   AND NOT a.attisdropped \
                   AND t.typtype = 'c'",
                &[
                    &schema_table_name.schema_name,
                    &schema_table_name.table_name,
                ],
            )
            .await
            .map(|rows| {
                rows.into_iter()
                    .map(|row| row.get::<_, String>(0))
                    .collect::<std::collections::HashSet<_>>()
            })
            .unwrap_or_else(|err| {
                // Discovery is best-effort: on failure, proceed without casts
                // rather than failing reader creation.
                tracing::warn!(
                    error = %err.as_report(),
                    schema = %schema_table_name.schema_name,
                    table = %schema_table_name.table_name,
                    "failed to discover postgres composite columns; falling back to no text cast"
                );
                std::collections::HashSet::new()
            });

        // Build the SELECT list once; composite columns (surfaced as Varchar in
        // the RW schema) are cast to text, everything else is just quoted.
        let field_names = rw_schema
            .fields
            .iter()
            .map(|f| {
                let quoted = Self::quote_column(&f.name);
                if matches!(f.data_type, DataType::Varchar) && composite_columns.contains(&f.name) {
                    format!("{quoted}::text AS {quoted}")
                } else {
                    quoted
                }
            })
            .join(",");

        Ok(Self {
            rw_schema,
            field_names,
            pk_indices,
            client: tokio::sync::Mutex::new(client),
            schema_table_name,
        })
    }
313
314    pub fn get_normalized_table_name(table_name: &SchemaTableName) -> String {
315        format!(
316            "\"{}\".\"{}\"",
317            table_name.schema_name, table_name.table_name
318        )
319    }
320
321    pub fn get_cdc_offset_parser() -> CdcOffsetParseFunc {
322        Box::new(move |offset| {
323            Ok(CdcOffset::Postgres(PostgresOffset::parse_debezium_offset(
324                offset,
325            )?))
326        })
327    }
328
    /// Implementation of `snapshot_read`: builds and runs the snapshot SELECT,
    /// converting each Postgres row to an `OwnedRow`.
    ///
    /// With a `start_pk_row`, a prepared statement with a `(pk...) > ($1...)`
    /// filter is used so RW scalars can be converted to the exact Postgres
    /// parameter types; without one, the first `scan_limit` rows are scanned.
    #[try_stream(boxed, ok = OwnedRow, error = ConnectorError)]
    async fn snapshot_read_inner(
        &self,
        table_name: SchemaTableName,
        start_pk_row: Option<OwnedRow>,
        primary_keys: Vec<String>,
        scan_limit: u32,
    ) {
        let order_key = Self::get_order_key(&primary_keys);
        let client = self.client.lock().await;
        // Pin the session time zone — presumably so timestamptz values come
        // back in UTC for the row conversion; confirm against the parser.
        client.execute("set time zone '+00:00'", &[]).await?;

        let stream = match start_pk_row {
            Some(ref pk_row) => {
                // prepare the scan statement, since we may need to convert the RW data type to postgres data type
                // e.g. varchar to uuid
                let prepared_scan_stmt = {
                    // Re-derive the pk column names from the reader's own
                    // schema/pk_indices (shadows the caller-provided list).
                    let primary_keys = self
                        .pk_indices
                        .iter()
                        .map(|i| self.rw_schema.fields[*i].name.clone())
                        .collect_vec();

                    let order_key = Self::get_order_key(&primary_keys);
                    let scan_sql = format!(
                        "SELECT {} FROM {} WHERE {} ORDER BY {} LIMIT {scan_limit}",
                        self.field_names,
                        Self::get_normalized_table_name(&table_name),
                        Self::filter_expression(&primary_keys),
                        order_key,
                    );
                    client.prepare(&scan_sql).await?
                };

                // Bind each pk datum, adapting it to the parameter type that
                // the prepared statement reports.
                let params: Vec<Option<ScalarAdapter>> = pk_row
                    .iter()
                    .zip_eq_fast(prepared_scan_stmt.params())
                    .map(|(datum, ty)| {
                        datum
                            .map(|scalar| ScalarAdapter::from_scalar(scalar, ty))
                            .transpose()
                    })
                    .try_collect()?;

                client.query_raw(&prepared_scan_stmt, &params).await?
            }
            None => {
                let sql = format!(
                    "SELECT {} FROM {} ORDER BY {} LIMIT {scan_limit}",
                    self.field_names,
                    Self::get_normalized_table_name(&table_name),
                    order_key,
                );
                let params: Vec<Option<ScalarAdapter>> = vec![];
                client.query_raw(&sql, &params).await?
            }
        };

        let row_stream = stream.map(|row| {
            let row = row?;
            Ok::<_, crate::error::ConnectorError>(postgres_row_to_owned_row(row, &self.rw_schema))
        });

        pin_mut!(row_stream);
        #[for_await]
        for row in row_stream {
            let row = row?;
            yield row;
        }
    }
399
400    // row filter expression: (v1, v2, v3) > ($1, $2, $3)
401    fn filter_expression(columns: &[String]) -> String {
402        let mut col_expr = String::new();
403        let mut arg_expr = String::new();
404        for (i, column) in columns.iter().enumerate() {
405            if i > 0 {
406                col_expr.push_str(", ");
407                arg_expr.push_str(", ");
408            }
409            col_expr.push_str(&Self::quote_column(column));
410            arg_expr.push_str(format!("${}", i + 1).as_str());
411        }
412        format!("({}) > ({})", col_expr, arg_expr)
413    }
414
    // row filter expression: (v1, v2, v3) >= ($1, $2, $3) AND (v1, v2, v3) < ($1, $2, $3)
    // Builds the WHERE clause for one snapshot split. Placeholder numbering
    // (`c`) runs continuously across both bounds: left-bound params first,
    // then right-bound params — callers must bind in the same order.
    fn split_filter_expression(
        columns: &[String],
        is_first_split: bool,
        is_last_split: bool,
    ) -> String {
        let mut left_col_expr = String::new();
        let mut left_arg_expr = String::new();
        let mut right_col_expr = String::new();
        let mut right_arg_expr = String::new();
        // Next placeholder number ($1-based).
        let mut c = 1;
        if !is_first_split {
            for (i, column) in columns.iter().enumerate() {
                if i > 0 {
                    left_col_expr.push_str(", ");
                    left_arg_expr.push_str(", ");
                }
                left_col_expr.push_str(&Self::quote_column(column));
                left_arg_expr.push_str(format!("${}", c).as_str());
                c += 1;
            }
        }
        if !is_last_split {
            for (i, column) in columns.iter().enumerate() {
                if i > 0 {
                    right_col_expr.push_str(", ");
                    right_arg_expr.push_str(", ");
                }
                right_col_expr.push_str(&Self::quote_column(column));
                right_arg_expr.push_str(format!("${}", c).as_str());
                c += 1;
            }
        }
        // First+last split means a single unbounded split: match all rows.
        if is_first_split && is_last_split {
            "1 = 1".to_owned()
        } else if is_first_split {
            format!("({}) < ({})", right_col_expr, right_arg_expr,)
        } else if is_last_split {
            format!("({}) >= ({})", left_col_expr, left_arg_expr,)
        } else {
            format!(
                "({}) >= ({}) AND ({}) < ({})",
                left_col_expr, left_arg_expr, right_col_expr, right_arg_expr,
            )
        }
    }
461
462    fn get_order_key(primary_keys: &Vec<String>) -> String {
463        primary_keys
464            .iter()
465            .map(|col| Self::quote_column(col))
466            .join(",")
467    }
468
469    fn quote_column(column: &str) -> String {
470        format!("\"{}\"", column)
471    }
472
473    async fn min_and_max(
474        &self,
475        split_column: &Field,
476    ) -> ConnectorResult<Option<(ScalarImpl, ScalarImpl)>> {
477        let sql = format!(
478            "SELECT MIN({}), MAX({}) FROM {}",
479            split_column.name,
480            split_column.name,
481            Self::get_normalized_table_name(&self.schema_table_name),
482        );
483        let client = self.client.lock().await;
484        let rows = client.query(&sql, &[]).await?;
485        if rows.is_empty() {
486            Ok(None)
487        } else {
488            let row = &rows[0];
489            let min =
490                postgres_cell_to_scalar_impl(row, &split_column.data_type, 0, &split_column.name);
491            let max =
492                postgres_cell_to_scalar_impl(row, &split_column.data_type, 1, &split_column.name);
493            match (min, max) {
494                (Some(min), Some(max)) => Ok(Some((min, max))),
495                _ => Ok(None),
496            }
497        }
498    }
499
    /// Finds the exclusive right bound of the next split: scans up to
    /// `max_split_size` rows with split column >= `left_value` and takes the
    /// maximum value seen, or SQL NULL when that maximum already reaches
    /// `max_value` (the rest of the table fits in this split).
    ///
    /// Return values:
    /// - `Ok(Some(Some(v)))` — next right bound is `v`;
    /// - `Ok(Some(None))` — query returned NULL: remaining rows form the last split;
    /// - `Ok(None)` — query yielded no row at all.
    async fn next_split_right_bound_exclusive(
        &self,
        left_value: &ScalarImpl,
        max_value: &ScalarImpl,
        max_split_size: u64,
        split_column: &Field,
    ) -> ConnectorResult<Option<Datum>> {
        let sql = format!(
            "WITH t as (SELECT {} FROM {} WHERE {} >= $1 ORDER BY {} ASC LIMIT {}) SELECT CASE WHEN MAX({}) < $2 THEN MAX({}) ELSE NULL END FROM t",
            Self::quote_column(&split_column.name),
            Self::get_normalized_table_name(&self.schema_table_name),
            Self::quote_column(&split_column.name),
            Self::quote_column(&split_column.name),
            max_split_size,
            Self::quote_column(&split_column.name),
            Self::quote_column(&split_column.name),
        );
        let client = self.client.lock().await;
        // Prepare so $1/$2 bind with the exact Postgres parameter types.
        let prepared_stmt = client.prepare(&sql).await?;
        let params: Vec<Option<ScalarAdapter>> = vec![
            Some(ScalarAdapter::from_scalar(
                left_value.as_scalar_ref_impl(),
                &prepared_stmt.params()[0],
            )?),
            Some(ScalarAdapter::from_scalar(
                max_value.as_scalar_ref_impl(),
                &prepared_stmt.params()[1],
            )?),
        ];
        let stream = client.query_raw(&prepared_stmt, &params).await?;
        let datum_stream = stream.map(|row| {
            let row = row?;
            Ok::<_, ConnectorError>(postgres_cell_to_scalar_impl(
                &row,
                &split_column.data_type,
                0,
                &split_column.name,
            ))
        });
        pin_mut!(datum_stream);
        // Only the first row matters (the query is a single aggregate).
        #[for_await]
        for datum in datum_stream {
            let right = datum?;
            return Ok(Some(right.to_owned_datum()));
        }
        Ok(None)
    }
547
    /// Finds the smallest split-column value strictly greater than
    /// `start_offset` and strictly less than `max_value` (`Ok(Some(None))`
    /// when no such value exists, since `MIN` yields NULL).
    ///
    /// Used by `as_uneven_splits` when the probed right bound equals the left
    /// bound (a value with many duplicates) to force progress to the next
    /// distinct value.
    async fn next_greater_bound(
        &self,
        start_offset: &ScalarImpl,
        max_value: &ScalarImpl,
        split_column: &Field,
    ) -> ConnectorResult<Option<Datum>> {
        let sql = format!(
            "SELECT MIN({}) FROM {} WHERE {} > $1 AND {} <$2",
            Self::quote_column(&split_column.name),
            Self::get_normalized_table_name(&self.schema_table_name),
            Self::quote_column(&split_column.name),
            Self::quote_column(&split_column.name),
        );
        let client = self.client.lock().await;
        // Prepare so $1/$2 bind with the exact Postgres parameter types.
        let prepared_stmt = client.prepare(&sql).await?;
        let params: Vec<Option<ScalarAdapter>> = vec![
            Some(ScalarAdapter::from_scalar(
                start_offset.as_scalar_ref_impl(),
                &prepared_stmt.params()[0],
            )?),
            Some(ScalarAdapter::from_scalar(
                max_value.as_scalar_ref_impl(),
                &prepared_stmt.params()[1],
            )?),
        ];
        let stream = client.query_raw(&prepared_stmt, &params).await?;
        let datum_stream = stream.map(|row| {
            let row = row?;
            Ok::<_, ConnectorError>(postgres_cell_to_scalar_impl(
                &row,
                &split_column.data_type,
                0,
                &split_column.name,
            ))
        });
        pin_mut!(datum_stream);
        // Only the first row matters (the query is a single aggregate).
        #[for_await]
        for datum in datum_stream {
            let right = datum?;
            return Ok(Some(right));
        }
        Ok(None)
    }
591
    /// Implementation of `split_snapshot_read`: scans all rows of one split,
    /// binding the split bounds as prepared-statement parameters.
    ///
    /// Only a single split column is supported. A NULL bound means the split
    /// is unbounded on that side (first/last split), and the corresponding
    /// parameter is omitted — matching `split_filter_expression`'s numbering.
    #[try_stream(boxed, ok = OwnedRow, error = ConnectorError)]
    async fn split_snapshot_read_inner(
        &self,
        table_name: SchemaTableName,
        left: OwnedRow,
        right: OwnedRow,
        split_columns: Vec<Field>,
    ) {
        assert_eq!(
            split_columns.len(),
            1,
            "multiple split columns is not supported yet"
        );
        assert_eq!(left.len(), 1, "multiple split columns is not supported yet");
        assert_eq!(
            right.len(),
            1,
            "multiple split columns is not supported yet"
        );
        let is_first_split = left[0].is_none();
        let is_last_split = right[0].is_none();
        let split_column_names = split_columns.iter().map(|c| c.name.clone()).collect_vec();
        let client = self.client.lock().await;
        // Pin the session time zone — presumably so timestamptz values come
        // back in UTC for the row conversion; confirm against the parser.
        client.execute("set time zone '+00:00'", &[]).await?;
        // prepare the scan statement, since we may need to convert the RW data type to postgres data type
        // e.g. varchar to uuid
        let prepared_scan_stmt = {
            let scan_sql = format!(
                "SELECT {} FROM {} WHERE {}",
                self.field_names,
                Self::get_normalized_table_name(&table_name),
                Self::split_filter_expression(&split_column_names, is_first_split, is_last_split),
            );
            client.prepare(&scan_sql).await?
        };

        // Bind left-bound params first, then right-bound params — the same
        // order `split_filter_expression` numbered its placeholders.
        let mut params: Vec<Option<ScalarAdapter>> = vec![];
        if !is_first_split {
            let left_params: Vec<Option<ScalarAdapter>> = left
                .iter()
                .zip_eq_fast(prepared_scan_stmt.params().iter().take(left.len()))
                .map(|(datum, ty)| {
                    datum
                        .map(|scalar| ScalarAdapter::from_scalar(scalar, ty))
                        .transpose()
                })
                .try_collect()?;
            params.extend(left_params);
        }
        if !is_last_split {
            let right_params: Vec<Option<ScalarAdapter>> = right
                .iter()
                .zip_eq_fast(prepared_scan_stmt.params().iter().skip(params.len()))
                .map(|(datum, ty)| {
                    datum
                        .map(|scalar| ScalarAdapter::from_scalar(scalar, ty))
                        .transpose()
                })
                .try_collect()?;
            params.extend(right_params);
        }

        let stream = client.query_raw(&prepared_scan_stmt, &params).await?;
        let row_stream = stream.map(|row| {
            let row = row?;
            Ok::<_, crate::error::ConnectorError>(postgres_row_to_owned_row(row, &self.rw_schema))
        });

        pin_mut!(row_stream);
        #[for_await]
        for row in row_stream {
            let row = row?;
            yield row;
        }
    }
667
    /// Computes uneven splits by repeatedly probing the table: each split's
    /// right bound lies (roughly) `backfill_num_rows_per_split` rows past its
    /// left bound, so splits are even in row count rather than key range.
    ///
    /// The first split's left bound and the last split's right bound are
    /// emitted as NULL (unbounded). An empty/all-NULL table yields a single
    /// unbounded split.
    #[try_stream(boxed, ok = CdcTableSnapshotSplit, error = ConnectorError)]
    async fn as_uneven_splits(&self, options: CdcTableSnapshotSplitOption) {
        let split_column = self.split_column(&options);
        let mut split_id = CDC_TABLE_SPLIT_ID_START;
        let Some((min_value, max_value)) = self.min_and_max(&split_column).await? else {
            // No usable min/max: one unbounded split covers everything.
            let left_bound_row = OwnedRow::new(vec![None]);
            let right_bound_row = OwnedRow::new(vec![None]);
            let split = CdcTableSnapshotSplit {
                split_id,
                left_bound_inclusive: left_bound_row,
                right_bound_exclusive: right_bound_row,
            };
            yield split;
            return Ok(());
        };
        // left bound will never be NULL value.
        let mut next_left_bound_inclusive = min_value.clone();
        loop {
            // The very first split is emitted with an unbounded (NULL) left
            // bound even though we track `min_value` internally.
            let left_bound_inclusive: Datum = if next_left_bound_inclusive == min_value {
                None
            } else {
                Some(next_left_bound_inclusive.clone())
            };
            let right_bound_exclusive;
            let mut next_right = self
                .next_split_right_bound_exclusive(
                    &next_left_bound_inclusive,
                    &max_value,
                    options.backfill_num_rows_per_split,
                    &split_column,
                )
                .await?;
            // Right bound equal to the left bound means the left value has at
            // least `backfill_num_rows_per_split` duplicates; skip ahead to the
            // next distinct value to guarantee progress.
            if let Some(Some(ref inner)) = next_right
                && *inner == next_left_bound_inclusive
            {
                next_right = self
                    .next_greater_bound(&next_left_bound_inclusive, &max_value, &split_column)
                    .await?;
            }
            if let Some(next_right) = next_right {
                match next_right {
                    None => {
                        // NULL found.
                        right_bound_exclusive = None;
                    }
                    Some(next_right) => {
                        next_left_bound_inclusive = next_right.clone();
                        right_bound_exclusive = Some(next_right);
                    }
                }
            } else {
                // Not found.
                right_bound_exclusive = None;
            };
            let is_completed = right_bound_exclusive.is_none();
            if is_completed && left_bound_inclusive.is_none() {
                // A fully unbounded split must be the only one emitted.
                // NOTE(review): assumes `CDC_TABLE_SPLIT_ID_START == 1` — confirm.
                assert_eq!(split_id, 1);
            }
            tracing::info!(
                split_id,
                ?left_bound_inclusive,
                ?right_bound_exclusive,
                "New CDC table snapshot split."
            );
            let left_bound_row = OwnedRow::new(vec![left_bound_inclusive]);
            let right_bound_row = OwnedRow::new(vec![right_bound_exclusive]);
            let split = CdcTableSnapshotSplit {
                split_id,
                left_bound_inclusive: left_bound_row,
                right_bound_exclusive: right_bound_row,
            };
            try_increase_split_id(&mut split_id)?;
            yield split;
            if is_completed {
                break;
            }
        }
    }
746
747    #[try_stream(boxed, ok = CdcTableSnapshotSplit, error = ConnectorError)]
748    async fn as_even_splits(&self, options: CdcTableSnapshotSplitOption) {
749        let split_column = self.split_column(&options);
750        let mut split_id = 1;
751        let Some((min_value, max_value)) = self.min_and_max(&split_column).await? else {
752            let left_bound_row = OwnedRow::new(vec![None]);
753            let right_bound_row = OwnedRow::new(vec![None]);
754            let split = CdcTableSnapshotSplit {
755                split_id,
756                left_bound_inclusive: left_bound_row,
757                right_bound_exclusive: right_bound_row,
758            };
759            yield split;
760            return Ok(());
761        };
762        let min_value = min_value.as_integral();
763        let max_value = max_value.as_integral();
764        let saturated_split_max_size = options
765            .backfill_num_rows_per_split
766            .try_into()
767            .unwrap_or(i64::MAX);
768        let mut left = None;
769        let mut right = Some(min_value.saturating_add(saturated_split_max_size));
770        loop {
771            let mut is_completed = false;
772            if right.as_ref().map(|r| *r >= max_value).unwrap_or(true) {
773                right = None;
774                is_completed = true;
775            }
776            let split = CdcTableSnapshotSplit {
777                split_id,
778                left_bound_inclusive: OwnedRow::new(vec![
779                    left.map(|l| to_int_scalar(l, &split_column.data_type)),
780                ]),
781                right_bound_exclusive: OwnedRow::new(vec![
782                    right.map(|r| to_int_scalar(r, &split_column.data_type)),
783                ]),
784            };
785            try_increase_split_id(&mut split_id)?;
786            yield split;
787            if is_completed {
788                break;
789            }
790            left = right;
791            right = left.map(|l| l.saturating_add(saturated_split_max_size));
792        }
793    }
794
795    fn split_column(&self, options: &CdcTableSnapshotSplitOption) -> Field {
796        self.rw_schema.fields[self.pk_indices[options.backfill_split_pk_column_index as usize]]
797            .clone()
798    }
799}
800
801fn to_int_scalar(i: i64, data_type: &DataType) -> ScalarImpl {
802    match data_type {
803        DataType::Int16 => ScalarImpl::Int16(i.try_into().unwrap()),
804        DataType::Int32 => ScalarImpl::Int32(i.try_into().unwrap()),
805        DataType::Int64 => ScalarImpl::Int64(i),
806        _ => {
807            panic!("Can't convert int {} to ScalarImpl::{}", i, data_type)
808        }
809    }
810}
811
812fn try_increase_split_id(split_id: &mut i64) -> ConnectorResult<()> {
813    match split_id.checked_add(1) {
814        Some(s) => {
815            *split_id = s;
816            Ok(())
817        }
818        None => Err(anyhow::anyhow!("too many CDC snapshot splits").into()),
819    }
820}
821
822/// Use the first column of primary keys to split table.
823fn is_supported_even_split_data_type(data_type: &DataType) -> bool {
824    matches!(
825        data_type,
826        DataType::Int16 | DataType::Int32 | DataType::Int64
827    )
828}
829
830pub fn type_name_to_pg_type(ty_name: &str) -> Option<PgType> {
831    let ty_name_lower = ty_name.to_lowercase();
832    // Handle array types (prefixed with _)
833    if let Some(base_type) = ty_name_lower.strip_prefix('_') {
834        match base_type {
835            "int2" => Some(PgType::INT2_ARRAY),
836            "int4" => Some(PgType::INT4_ARRAY),
837            "int8" => Some(PgType::INT8_ARRAY),
838            "bit" => Some(PgType::BIT_ARRAY),
839            "float4" => Some(PgType::FLOAT4_ARRAY),
840            "float8" => Some(PgType::FLOAT8_ARRAY),
841            "numeric" => Some(PgType::NUMERIC_ARRAY),
842            "bool" => Some(PgType::BOOL_ARRAY),
843            "xml" | "macaddr" | "macaddr8" | "cidr" | "inet" | "int4range" | "int8range"
844            | "numrange" | "tsrange" | "tstzrange" | "daterange" | "citext" => {
845                Some(PgType::VARCHAR_ARRAY)
846            }
847            "varchar" => Some(PgType::VARCHAR_ARRAY),
848            "text" => Some(PgType::TEXT_ARRAY),
849            "bytea" => Some(PgType::BYTEA_ARRAY),
850            "geometry" => Some(PgType::BYTEA_ARRAY), // PostGIS geometry array
851            "date" => Some(PgType::DATE_ARRAY),
852            "time" => Some(PgType::TIME_ARRAY),
853            "timetz" => Some(PgType::TIMETZ_ARRAY),
854            "timestamp" => Some(PgType::TIMESTAMP_ARRAY),
855            "timestamptz" => Some(PgType::TIMESTAMPTZ_ARRAY),
856            "interval" => Some(PgType::INTERVAL_ARRAY),
857            "json" => Some(PgType::JSON_ARRAY),
858            "jsonb" => Some(PgType::JSONB_ARRAY),
859            "uuid" => Some(PgType::UUID_ARRAY),
860            "point" => Some(PgType::POINT_ARRAY),
861            "oid" => Some(PgType::OID_ARRAY),
862            "money" => Some(PgType::MONEY_ARRAY),
863            _ => None,
864        }
865    } else {
866        // Handle non-array types
867        match ty_name_lower.as_str() {
868            "int2" => Some(PgType::INT2),
869            "bit" => Some(PgType::BIT),
870            "int" | "int4" => Some(PgType::INT4),
871            "int8" => Some(PgType::INT8),
872            "float4" => Some(PgType::FLOAT4),
873            "float8" => Some(PgType::FLOAT8),
874            "numeric" => Some(PgType::NUMERIC),
875            "money" => Some(PgType::MONEY),
876            "boolean" | "bool" => Some(PgType::BOOL),
877            "inet" | "xml" | "varchar" | "character varying" | "int4range" | "int8range"
878            | "numrange" | "tsrange" | "tstzrange" | "daterange" | "macaddr" | "macaddr8"
879            | "cidr" => Some(PgType::VARCHAR),
880            "char" | "character" | "bpchar" => Some(PgType::BPCHAR),
881            "citext" | "text" => Some(PgType::TEXT),
882            "bytea" => Some(PgType::BYTEA),
883            "geometry" => Some(PgType::BYTEA), // PostGIS geometry type
884            "date" => Some(PgType::DATE),
885            "time" => Some(PgType::TIME),
886            "timetz" => Some(PgType::TIMETZ),
887            "timestamp" => Some(PgType::TIMESTAMP),
888            "timestamptz" => Some(PgType::TIMESTAMPTZ),
889            "interval" => Some(PgType::INTERVAL),
890            "json" => Some(PgType::JSON),
891            "jsonb" => Some(PgType::JSONB),
892            "uuid" => Some(PgType::UUID),
893            "point" => Some(PgType::POINT),
894            "oid" => Some(PgType::OID),
895            _ => None,
896        }
897    }
898}
899
900pub fn pg_type_to_rw_type(pg_type: &PgType) -> ConnectorResult<DataType> {
901    let data_type = match *pg_type {
902        PgType::BOOL => DataType::Boolean,
903        PgType::BIT => DataType::Boolean,
904        PgType::INT2 => DataType::Int16,
905        PgType::INT4 => DataType::Int32,
906        PgType::INT8 => DataType::Int64,
907        PgType::FLOAT4 => DataType::Float32,
908        PgType::FLOAT8 => DataType::Float64,
909        PgType::NUMERIC | PgType::MONEY => DataType::Decimal,
910        PgType::DATE => DataType::Date,
911        PgType::TIME => DataType::Time,
912        PgType::TIMETZ => DataType::Time,
913        PgType::POINT => DataType::Struct(risingwave_common::types::StructType::new(vec![
914            ("x", DataType::Float32),
915            ("y", DataType::Float32),
916        ])),
917        PgType::TIMESTAMP => DataType::Timestamp,
918        PgType::TIMESTAMPTZ => DataType::Timestamptz,
919        PgType::INTERVAL => DataType::Interval,
920        PgType::VARCHAR | PgType::TEXT | PgType::BPCHAR | PgType::UUID => DataType::Varchar,
921        PgType::BYTEA => DataType::Bytea,
922        PgType::JSON | PgType::JSONB => DataType::Jsonb,
923        // Array types
924        PgType::BOOL_ARRAY => DataType::Boolean.list(),
925        PgType::BIT_ARRAY => DataType::Boolean.list(),
926        PgType::INT2_ARRAY => DataType::Int16.list(),
927        PgType::INT4_ARRAY => DataType::Int32.list(),
928        PgType::INT8_ARRAY => DataType::Int64.list(),
929        PgType::FLOAT4_ARRAY => DataType::Float32.list(),
930        PgType::FLOAT8_ARRAY => DataType::Float64.list(),
931        PgType::NUMERIC_ARRAY => DataType::Decimal.list(),
932        PgType::VARCHAR_ARRAY => DataType::Varchar.list(),
933        PgType::TEXT_ARRAY => DataType::Varchar.list(),
934        PgType::BYTEA_ARRAY => DataType::Bytea.list(),
935        PgType::DATE_ARRAY => DataType::Date.list(),
936        PgType::TIME_ARRAY => DataType::Time.list(),
937        PgType::TIMESTAMP_ARRAY => DataType::Timestamp.list(),
938        PgType::TIMESTAMPTZ_ARRAY => DataType::Timestamptz.list(),
939        PgType::INTERVAL_ARRAY => DataType::Interval.list(),
940        PgType::JSON_ARRAY => DataType::Jsonb.list(),
941        PgType::JSONB_ARRAY => DataType::Jsonb.list(),
942        PgType::UUID_ARRAY => DataType::Varchar.list(),
943        PgType::OID => DataType::Int64,
944        PgType::OID_ARRAY => DataType::Int64.list(),
945        PgType::MONEY_ARRAY => DataType::Decimal.list(),
946        PgType::POINT_ARRAY => {
947            DataType::list(DataType::Struct(risingwave_common::types::StructType::new(
948                vec![("x", DataType::Float32), ("y", DataType::Float32)],
949            )))
950        }
951        _ => {
952            return Err(anyhow::anyhow!("unsupported postgres type: {}", pg_type).into());
953        }
954    };
955    Ok(data_type)
956}
957
#[cfg(test)]
mod tests {
    use std::cmp::Ordering;
    use std::collections::HashMap;

    use futures::pin_mut;
    use futures_async_stream::for_await;
    use maplit::{convert_args, hashmap};
    use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema};
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::{DataType, ScalarImpl};

    use crate::connector_common::PostgresExternalTable;
    use crate::source::cdc::external::postgres::{PostgresExternalTableReader, PostgresOffset};
    use crate::source::cdc::external::{ExternalTableConfig, ExternalTableReader, SchemaTableName};

    // Manual smoke test: connects to a locally running Postgres (hard-coded
    // credentials) and prints the discovered columns and primary keys.
    // Run explicitly with `--ignored`.
    #[ignore]
    #[tokio::test]
    async fn test_postgres_schema() {
        let config = ExternalTableConfig {
            connector: "postgres-cdc".to_owned(),
            host: "localhost".to_owned(),
            port: "8432".to_owned(),
            username: "myuser".to_owned(),
            password: "123456".to_owned(),
            database: "mydb".to_owned(),
            schema: "public".to_owned(),
            table: "mytest".to_owned(),
            ssl_mode: Default::default(),
            ssl_root_cert: None,
            encrypt: "false".to_owned(),
        };

        let table = PostgresExternalTable::connect(
            &config.username,
            &config.password,
            &config.host,
            config.port.parse::<u16>().unwrap(),
            &config.database,
            &config.schema,
            &config.table,
            &config.ssl_mode,
            &config.ssl_root_cert,
            false,
        )
        .await
        .unwrap();

        println!("columns: {:?}", &table.column_descs());
        println!("primary keys: {:?}", &table.pk_names());
    }

    // Offsets with default (absent) `lsn_commit`/`lsn_proc` fields: the
    // assertions show ordering follows `lsn` alone; `txid` does not
    // participate in the comparison.
    #[test]
    fn test_postgres_offset() {
        let off1 = PostgresOffset {
            txid: 4,
            lsn: 2,
            ..Default::default()
        };
        let off2 = PostgresOffset {
            txid: 1,
            lsn: 3,
            ..Default::default()
        };
        let off3 = PostgresOffset {
            txid: 5,
            lsn: 1,
            ..Default::default()
        };

        assert!(off1 < off2);
        assert!(off3 < off1);
        assert!(off2 > off3);
    }

    // Exercises the full comparison hierarchy of `PartialOrd`: `lsn_commit`
    // first, then `lsn_proc` as a tie-breaker, falling back to plain `lsn`
    // whenever either optional field is missing on either side.
    #[test]
    fn test_postgres_offset_partial_ord_with_lsn_commit() {
        // Test comparison with both lsn_commit and lsn_proc fields
        let off1 = PostgresOffset {
            txid: 1,
            lsn: 100,
            lsn_commit: Some(200),
            lsn_proc: Some(150),
        };
        let off2 = PostgresOffset {
            txid: 2,
            lsn: 300,
            lsn_commit: Some(250),
            lsn_proc: Some(200),
        };

        // Should compare using lsn_commit first when both have both fields
        assert!(off1 < off2);

        // Test with same lsn_commit but different lsn_proc
        let off3 = PostgresOffset {
            txid: 3,
            lsn: 500,
            lsn_commit: Some(200), // same as off1
            lsn_proc: Some(160),   // higher than off1
        };

        // Should compare lsn_proc when lsn_commit is equal
        assert!(off1 < off3);

        // Test with missing lsn_proc - should fall back to lsn comparison
        let off4 = PostgresOffset {
            txid: 4,
            lsn: 400,
            lsn_commit: Some(100), // lower than off1's lsn_commit
            lsn_proc: None,        // missing lsn_proc
        };

        // Should fall back to lsn comparison (off1.lsn=100 < off4.lsn=400)
        assert!(off1 < off4);

        // Test with missing lsn_commit - should fall back to lsn comparison
        let off5 = PostgresOffset {
            txid: 5,
            lsn: 50,             // lower than off1.lsn
            lsn_commit: None,    // missing lsn_commit
            lsn_proc: Some(300), // higher than off1's lsn_proc
        };

        // Should fall back to lsn comparison (off5.lsn=50 < off1.lsn=100)
        assert!(off5 < off1);

        // Additional test cases: equal lsn_commit values with different lsn_proc
        let off6 = PostgresOffset {
            txid: 6,
            lsn: 600,
            lsn_commit: Some(500),
            lsn_proc: Some(300),
        };
        let off7 = PostgresOffset {
            txid: 7,
            lsn: 700,
            lsn_commit: Some(500), // same as off6
            lsn_proc: Some(400),   // higher than off6
        };

        // Should compare lsn_proc since lsn_commit is equal
        assert!(off6 < off7);

        // Test reverse order
        let off8 = PostgresOffset {
            txid: 8,
            lsn: 800,
            lsn_commit: Some(500), // same as others
            lsn_proc: Some(200),   // lower than off6
        };

        assert!(off8 < off6);
        assert!(off8 < off7);

        // Test equal lsn_commit and lsn_proc
        let off9 = PostgresOffset {
            txid: 9,
            lsn: 900,
            lsn_commit: Some(500), // same as off6
            lsn_proc: Some(300),   // same as off6
        };

        // Should be equal
        assert_eq!(off6.partial_cmp(&off9), Some(Ordering::Equal));
    }

    // Parsing a Debezium offset JSON blob: `lsn_commit` and `lsn_proc` are
    // required in `sourceOffset`, and their absence must be a parse error
    // (not a silently-defaulted value).
    #[test]
    fn test_debezium_offset_parsing() {
        // Test parsing with all required fields present
        let debezium_offset_with_fields = r#"{
            "sourcePartition": {"server": "RW_CDC_1004"},
            "sourceOffset": {
                "last_snapshot_record": false,
                "lsn": 29973552,
                "txId": 1046,
                "ts_usec": 1670826189008456,
                "snapshot": true,
                "lsn_commit": 29973600,
                "lsn_proc": 29973580
            },
            "isHeartbeat": false
        }"#;

        let offset = PostgresOffset::parse_debezium_offset(debezium_offset_with_fields).unwrap();
        assert_eq!(offset.txid, 1046);
        assert_eq!(offset.lsn, 29973552);
        assert_eq!(offset.lsn_commit, Some(29973600));
        assert_eq!(offset.lsn_proc, Some(29973580));

        // Test parsing should fail when required fields are missing
        let debezium_offset_missing_fields = r#"{
            "sourcePartition": {"server": "RW_CDC_1004"},
            "sourceOffset": {
                "last_snapshot_record": false,
                "lsn": 29973552,
                "txId": 1046,
                "ts_usec": 1670826189008456,
                "snapshot": true
            },
            "isHeartbeat": false
        }"#;

        let result = PostgresOffset::parse_debezium_offset(debezium_offset_missing_fields);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("invalid postgres lsn_proc"));
    }

    // Filter expressions use quoted identifiers and row-value comparison
    // against positional parameters `$1..$n`.
    #[test]
    fn test_filter_expression() {
        let cols = vec!["v1".to_owned()];
        let expr = PostgresExternalTableReader::filter_expression(&cols);
        assert_eq!(expr, "(\"v1\") > ($1)");

        let cols = vec!["v1".to_owned(), "v2".to_owned()];
        let expr = PostgresExternalTableReader::filter_expression(&cols);
        assert_eq!(expr, "(\"v1\", \"v2\") > ($1, $2)");

        let cols = vec!["v1".to_owned(), "v2".to_owned(), "v3".to_owned()];
        let expr = PostgresExternalTableReader::filter_expression(&cols);
        assert_eq!(expr, "(\"v1\", \"v2\", \"v3\") > ($1, $2, $3)");
    }

    // Split filter expressions: the two boolean flags indicate an unbounded
    // left/right side respectively; both unbounded yields the always-true
    // predicate "1 = 1", left bounds are inclusive (>=), right exclusive (<).
    #[test]
    fn test_split_filter_expression() {
        let cols = vec!["v1".to_owned()];
        let expr = PostgresExternalTableReader::split_filter_expression(&cols, true, true);
        assert_eq!(expr, "1 = 1");

        let expr = PostgresExternalTableReader::split_filter_expression(&cols, true, false);
        assert_eq!(expr, "(\"v1\") < ($1)");

        let expr = PostgresExternalTableReader::split_filter_expression(&cols, false, true);
        assert_eq!(expr, "(\"v1\") >= ($1)");

        let expr = PostgresExternalTableReader::split_filter_expression(&cols, false, false);
        assert_eq!(expr, "(\"v1\") >= ($1) AND (\"v1\") < ($2)");
    }

    // manual test
    // Requires a local Postgres with table `public.t1`; run with `--ignored`.
    #[ignore]
    #[tokio::test]
    async fn test_pg_table_reader() {
        let columns = [
            ColumnDesc::named("v1", ColumnId::new(1), DataType::Int32),
            ColumnDesc::named("v2", ColumnId::new(2), DataType::Varchar),
            ColumnDesc::named("v3", ColumnId::new(3), DataType::Decimal),
            ColumnDesc::named("v4", ColumnId::new(4), DataType::Date),
        ];
        let rw_schema = Schema {
            fields: columns.iter().map(Field::from).collect(),
        };

        let props: HashMap<String, String> = convert_args!(hashmap!(
                "hostname" => "localhost",
                "port" => "8432",
                "username" => "myuser",
                "password" => "123456",
                "database.name" => "mydb",
                "schema.name" => "public",
                "table.name" => "t1"));

        let config =
            serde_json::from_value::<ExternalTableConfig>(serde_json::to_value(props).unwrap())
                .unwrap();
        let schema_table_name = SchemaTableName {
            schema_name: "public".to_owned(),
            table_name: "t1".to_owned(),
        };
        let reader = PostgresExternalTableReader::new(
            config,
            rw_schema,
            vec![0, 1],
            schema_table_name.clone(),
        )
        .await
        .unwrap();

        let offset = reader.current_cdc_offset().await.unwrap();
        println!("CdcOffset: {:?}", offset);

        let start_pk = OwnedRow::new(vec![Some(ScalarImpl::from(3)), Some(ScalarImpl::from("c"))]);
        let stream = reader.snapshot_read(
            schema_table_name,
            Some(start_pk),
            vec!["v1".to_owned(), "v2".to_owned()],
            1000,
        );

        pin_mut!(stream);
        #[for_await]
        for row in stream {
            println!("OwnedRow: {:?}", row);
        }
    }
}