risingwave_connector/source/cdc/external/
postgres.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::Ordering;
16use std::sync::LazyLock;
17
18use anyhow::Context;
19use futures::stream::BoxStream;
20use futures::{StreamExt, pin_mut};
21use futures_async_stream::{for_await, try_stream};
22use itertools::Itertools;
23use risingwave_common::catalog::{Field, Schema};
24use risingwave_common::log::LogSuppressor;
25use risingwave_common::row::{OwnedRow, Row};
26use risingwave_common::types::{DataType, Datum, ScalarImpl, ToOwnedDatum};
27use risingwave_common::util::iter_util::ZipEqFast;
28use serde::{Deserialize, Serialize};
29use tokio_postgres::types::{PgLsn, Type as PgType};
30
31use crate::connector_common::create_pg_client;
32use crate::error::{ConnectorError, ConnectorResult};
33use crate::parser::scalar_adapter::ScalarAdapter;
34use crate::parser::{postgres_cell_to_scalar_impl, postgres_row_to_owned_row};
35use crate::source::CdcTableSnapshotSplit;
36use crate::source::cdc::external::{
37    CDC_TABLE_SPLIT_ID_START, CdcOffset, CdcOffsetParseFunc, CdcTableSnapshotSplitOption,
38    DebeziumOffset, ExternalTableConfig, ExternalTableReader, SchemaTableName,
39};
40
/// CDC offset of a PostgreSQL source: a position in the write-ahead log (WAL)
/// plus the transaction id that produced it.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PostgresOffset {
    // Upstream transaction id (read via `txid_current()` in `current_cdc_offset`).
    pub txid: i64,
    // In postgres, an LSN is a 64-bit integer, representing a byte position in the write-ahead log stream.
    // It is printed as two hexadecimal numbers of up to 8 digits each, separated by a slash; for example, 16/B374D848
    pub lsn: u64,
    // Additional LSN fields for improved tracking
    // LSN of the transaction's commit record; may be absent, e.g. for the first
    // transaction in a Debezium offset (see `parse_debezium_offset`).
    #[serde(default)]
    pub lsn_commit: Option<u64>,
    // LSN of the event currently being processed. `#[serde(default)]` keeps
    // deserialization backward-compatible with offsets persisted before this
    // field existed.
    #[serde(default)]
    pub lsn_proc: Option<u64>,
}
53
54impl PartialOrd for PostgresOffset {
55    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
56        Some(self.cmp(other))
57    }
58}
59
60impl Eq for PostgresOffset {}
61impl PartialEq for PostgresOffset {
62    fn eq(&self, other: &Self) -> bool {
63        match (
64            self.lsn_commit,
65            self.lsn_proc,
66            other.lsn_commit,
67            other.lsn_proc,
68        ) {
69            (_, Some(_), _, Some(_)) => {
70                self.lsn_commit == other.lsn_commit && self.lsn_proc == other.lsn_proc
71            }
72            _ => self.lsn == other.lsn,
73        }
74    }
75}
76
// Total order over offsets: prefer the (`lsn_commit`, `lsn_proc`) pair when
// both sides carry `lsn_proc`, otherwise fall back to the plain `lsn` field.
// Note that `txid` deliberately does not participate in the ordering.
impl Ord for PostgresOffset {
    fn cmp(&self, other: &Self) -> Ordering {
        match (
            self.lsn_commit,
            self.lsn_proc,
            other.lsn_commit,
            other.lsn_proc,
        ) {
            (_, Some(self_proc), _, Some(other_proc)) => {
                // if both have `lsn_commit` and `lsn_proc`, compare `lsn_commit` first, then `lsn_proc`
                // if `lsn_commit` is None, fall back to `lsn_proc`
                // (a `None` `lsn_commit` sorts before `Some` via `Option`'s derived order)
                match self.lsn_commit.cmp(&other.lsn_commit) {
                    Ordering::Equal => self_proc.cmp(&other_proc),
                    other_result => other_result,
                }
            }
            _ => {
                // Fall back to lsn comparison when either lsn_commit or lsn_proc is missing
                // NOTE(review): `LogSuppressor` presumably rate-limits this warning,
                // since this path may be hit on every comparison — confirm semantics.
                static LOG_SUPPRESSOR: LazyLock<LogSuppressor> =
                    LazyLock::new(LogSuppressor::default);
                if let Ok(suppressed_count) = LOG_SUPPRESSOR.check() {
                    tracing::warn!(
                        suppressed_count,
                        self_lsn = self.lsn,
                        other_lsn = other.lsn,
                        "lsn_commit and lsn_proc are missing, fall back to lsn comparison"
                    );
                }
                self.lsn.cmp(&other.lsn)
            }
        }
    }
}
111
impl PostgresOffset {
    /// Parses a Debezium offset JSON string into a [`PostgresOffset`].
    ///
    /// # Errors
    /// Fails when the string is not valid JSON, or when `lsn`, `lsn_proc` or
    /// `txid` is missing from the source offset. Only `lsn_commit` is
    /// optional, since it may be absent for the first transaction.
    pub fn parse_debezium_offset(offset: &str) -> ConnectorResult<Self> {
        let dbz_offset: DebeziumOffset = serde_json::from_str(offset)
            .with_context(|| format!("invalid upstream offset: {}", offset))?;

        let lsn = dbz_offset
            .source_offset
            .lsn
            .context("invalid postgres lsn")?;

        // `lsn_commit` may not be present in the offset for the first tx.
        let lsn_commit = dbz_offset.source_offset.lsn_commit;

        let lsn_proc = dbz_offset
            .source_offset
            .lsn_proc
            .context("invalid postgres lsn_proc")?;

        Ok(Self {
            txid: dbz_offset
                .source_offset
                .txid
                .context("invalid postgres txid")?,
            lsn,
            lsn_commit,
            lsn_proc: Some(lsn_proc),
        })
    }
}
141
/// Reads snapshots of an external PostgreSQL table for CDC backfill.
pub struct PostgresExternalTableReader {
    // RisingWave-side schema of the external table.
    rw_schema: Schema,
    // Pre-rendered, comma-separated list of quoted column names used as the
    // SELECT projection (built once in `new`).
    field_names: String,
    // Indices of the primary-key columns within `rw_schema.fields`.
    pk_indices: Vec<usize>,
    // Mutex serializes access to the single upstream connection.
    client: tokio::sync::Mutex<tokio_postgres::Client>,
    schema_table_name: SchemaTableName,
}
149
impl ExternalTableReader for PostgresExternalTableReader {
    /// Reads the current WAL position (`pg_current_wal_lsn()`) and transaction
    /// id (`txid_current()`) from upstream as a [`CdcOffset::Postgres`].
    async fn current_cdc_offset(&self) -> ConnectorResult<CdcOffset> {
        let mut client = self.client.lock().await;
        // start a transaction to read current lsn and txid
        let trxn = client.transaction().await?;
        let row = trxn.query_one("SELECT pg_current_wal_lsn()", &[]).await?;
        let mut pg_offset = PostgresOffset::default();
        let pg_lsn = row.get::<_, PgLsn>(0);
        tracing::debug!("current lsn: {}", pg_lsn);
        pg_offset.lsn = pg_lsn.into();

        let txid_row = trxn.query_one("SELECT txid_current()", &[]).await?;
        let txid: i64 = txid_row.get::<_, i64>(0);
        pg_offset.txid = txid;

        // commit the transaction
        trxn.commit().await?;

        Ok(CdcOffset::Postgres(pg_offset))
    }

    /// Streams at most `limit` rows in primary-key order, resuming strictly
    /// after `start_pk` when provided.
    fn snapshot_read(
        &self,
        table_name: SchemaTableName,
        start_pk: Option<OwnedRow>,
        primary_keys: Vec<String>,
        limit: u32,
    ) -> BoxStream<'_, ConnectorResult<OwnedRow>> {
        assert_eq!(table_name, self.schema_table_name);
        self.snapshot_read_inner(table_name, start_pk, primary_keys, limit)
    }

    /// Computes parallel snapshot splits for backfill. Uses evenly-sized
    /// arithmetic splits when enabled and the split column is an integer type,
    /// otherwise falls back to scan-derived uneven splits.
    #[try_stream(boxed, ok = CdcTableSnapshotSplit, error = ConnectorError)]
    async fn get_parallel_cdc_splits(&self, options: CdcTableSnapshotSplitOption) {
        let backfill_num_rows_per_split = options.backfill_num_rows_per_split;
        if backfill_num_rows_per_split == 0 {
            return Err(anyhow::anyhow!(
                "invalid backfill_num_rows_per_split, must be greater than 0"
            )
            .into());
        }
        // The split column must be one of the primary-key columns.
        if options.backfill_split_pk_column_index as usize >= self.pk_indices.len() {
            return Err(anyhow::anyhow!(format!(
                "invalid backfill_split_pk_column_index {}, out of bound",
                options.backfill_split_pk_column_index
            ))
            .into());
        }
        let split_column = self.split_column(&options);
        let row_stream = if options.backfill_as_even_splits
            && is_supported_even_split_data_type(&split_column.data_type)
        {
            // For certain types, use evenly-sized partition to optimize performance.
            tracing::info!(?self.schema_table_name, ?self.rw_schema, ?self.pk_indices, ?split_column, "Get parallel cdc table snapshot even splits.");
            self.as_even_splits(options)
        } else {
            tracing::info!(?self.schema_table_name, ?self.rw_schema, ?self.pk_indices, ?split_column, "Get parallel cdc table snapshot uneven splits.");
            self.as_uneven_splits(options)
        };
        pin_mut!(row_stream);
        #[for_await]
        for row in row_stream {
            let row = row?;
            yield row;
        }
    }

    /// Streams every row of the split `[left, right)` over the split columns;
    /// a NULL bound means the split is unbounded on that side.
    fn split_snapshot_read(
        &self,
        table_name: SchemaTableName,
        left: OwnedRow,
        right: OwnedRow,
        split_columns: Vec<Field>,
    ) -> BoxStream<'_, ConnectorResult<OwnedRow>> {
        assert_eq!(table_name, self.schema_table_name);
        self.split_snapshot_read_inner(table_name, left, right, split_columns)
    }
}
228
229impl PostgresExternalTableReader {
    /// Connects to the upstream PostgreSQL instance and builds the reader.
    ///
    /// The comma-separated, quoted projection list (`field_names`) is rendered
    /// once here so per-query SQL construction stays cheap.
    ///
    /// # Errors
    /// Fails if the connection cannot be established.
    pub async fn new(
        config: ExternalTableConfig,
        rw_schema: Schema,
        pk_indices: Vec<usize>,
        schema_table_name: SchemaTableName,
    ) -> ConnectorResult<Self> {
        tracing::info!(
            ?rw_schema,
            ?pk_indices,
            "create postgres external table reader"
        );
        let client = create_pg_client(
            &config.username,
            &config.password,
            &config.host,
            &config.port,
            &config.database,
            &config.ssl_mode,
            &config.ssl_root_cert,
            None, // No TCP keepalive for CDC source
        )
        .await?;
        let field_names = rw_schema
            .fields
            .iter()
            .map(|f| Self::quote_column(&f.name))
            .join(",");

        Ok(Self {
            rw_schema,
            field_names,
            pk_indices,
            client: tokio::sync::Mutex::new(client),
            schema_table_name,
        })
    }
266
267    pub fn get_normalized_table_name(table_name: &SchemaTableName) -> String {
268        format!(
269            "\"{}\".\"{}\"",
270            table_name.schema_name, table_name.table_name
271        )
272    }
273
274    pub fn get_cdc_offset_parser() -> CdcOffsetParseFunc {
275        Box::new(move |offset| {
276            Ok(CdcOffset::Postgres(PostgresOffset::parse_debezium_offset(
277                offset,
278            )?))
279        })
280    }
281
    /// Inner implementation of `snapshot_read`: streams up to `scan_limit`
    /// rows in primary-key order, optionally resuming strictly after
    /// `start_pk_row`.
    #[try_stream(boxed, ok = OwnedRow, error = ConnectorError)]
    async fn snapshot_read_inner(
        &self,
        table_name: SchemaTableName,
        start_pk_row: Option<OwnedRow>,
        primary_keys: Vec<String>,
        scan_limit: u32,
    ) {
        let order_key = Self::get_order_key(&primary_keys);
        let client = self.client.lock().await;
        // Pin the session time zone so time values are read consistently as UTC.
        client.execute("set time zone '+00:00'", &[]).await?;

        let stream = match start_pk_row {
            Some(ref pk_row) => {
                // prepare the scan statement, since we may need to convert the RW data type to postgres data type
                // e.g. varchar to uuid
                let prepared_scan_stmt = {
                    let primary_keys = self
                        .pk_indices
                        .iter()
                        .map(|i| self.rw_schema.fields[*i].name.clone())
                        .collect_vec();

                    let order_key = Self::get_order_key(&primary_keys);
                    let scan_sql = format!(
                        "SELECT {} FROM {} WHERE {} ORDER BY {} LIMIT {scan_limit}",
                        self.field_names,
                        Self::get_normalized_table_name(&table_name),
                        Self::filter_expression(&primary_keys),
                        order_key,
                    );
                    client.prepare(&scan_sql).await?
                };

                // Bind the resume-point PK values, adapting each RW scalar to the
                // corresponding postgres parameter type of the prepared statement.
                let params: Vec<Option<ScalarAdapter>> = pk_row
                    .iter()
                    .zip_eq_fast(prepared_scan_stmt.params())
                    .map(|(datum, ty)| {
                        datum
                            .map(|scalar| ScalarAdapter::from_scalar(scalar, ty))
                            .transpose()
                    })
                    .try_collect()?;

                client.query_raw(&prepared_scan_stmt, &params).await?
            }
            None => {
                // No resume point: scan from the beginning of the table.
                let sql = format!(
                    "SELECT {} FROM {} ORDER BY {} LIMIT {scan_limit}",
                    self.field_names,
                    Self::get_normalized_table_name(&table_name),
                    order_key,
                );
                let params: Vec<Option<ScalarAdapter>> = vec![];
                client.query_raw(&sql, &params).await?
            }
        };

        // Convert each postgres row into an `OwnedRow` following `rw_schema`.
        let row_stream = stream.map(|row| {
            let row = row?;
            Ok::<_, crate::error::ConnectorError>(postgres_row_to_owned_row(row, &self.rw_schema))
        });

        pin_mut!(row_stream);
        #[for_await]
        for row in row_stream {
            let row = row?;
            yield row;
        }
    }
352
353    // row filter expression: (v1, v2, v3) > ($1, $2, $3)
354    fn filter_expression(columns: &[String]) -> String {
355        let mut col_expr = String::new();
356        let mut arg_expr = String::new();
357        for (i, column) in columns.iter().enumerate() {
358            if i > 0 {
359                col_expr.push_str(", ");
360                arg_expr.push_str(", ");
361            }
362            col_expr.push_str(&Self::quote_column(column));
363            arg_expr.push_str(format!("${}", i + 1).as_str());
364        }
365        format!("({}) > ({})", col_expr, arg_expr)
366    }
367
368    // row filter expression: (v1, v2, v3) >= ($1, $2, $3) AND (v1, v2, v3) < ($1, $2, $3)
369    fn split_filter_expression(
370        columns: &[String],
371        is_first_split: bool,
372        is_last_split: bool,
373    ) -> String {
374        let mut left_col_expr = String::new();
375        let mut left_arg_expr = String::new();
376        let mut right_col_expr = String::new();
377        let mut right_arg_expr = String::new();
378        let mut c = 1;
379        if !is_first_split {
380            for (i, column) in columns.iter().enumerate() {
381                if i > 0 {
382                    left_col_expr.push_str(", ");
383                    left_arg_expr.push_str(", ");
384                }
385                left_col_expr.push_str(&Self::quote_column(column));
386                left_arg_expr.push_str(format!("${}", c).as_str());
387                c += 1;
388            }
389        }
390        if !is_last_split {
391            for (i, column) in columns.iter().enumerate() {
392                if i > 0 {
393                    right_col_expr.push_str(", ");
394                    right_arg_expr.push_str(", ");
395                }
396                right_col_expr.push_str(&Self::quote_column(column));
397                right_arg_expr.push_str(format!("${}", c).as_str());
398                c += 1;
399            }
400        }
401        if is_first_split && is_last_split {
402            "1 = 1".to_owned()
403        } else if is_first_split {
404            format!("({}) < ({})", right_col_expr, right_arg_expr,)
405        } else if is_last_split {
406            format!("({}) >= ({})", left_col_expr, left_arg_expr,)
407        } else {
408            format!(
409                "({}) >= ({}) AND ({}) < ({})",
410                left_col_expr, left_arg_expr, right_col_expr, right_arg_expr,
411            )
412        }
413    }
414
415    fn get_order_key(primary_keys: &Vec<String>) -> String {
416        primary_keys
417            .iter()
418            .map(|col| Self::quote_column(col))
419            .join(",")
420    }
421
422    fn quote_column(column: &str) -> String {
423        format!("\"{}\"", column)
424    }
425
426    async fn min_and_max(
427        &self,
428        split_column: &Field,
429    ) -> ConnectorResult<Option<(ScalarImpl, ScalarImpl)>> {
430        let sql = format!(
431            "SELECT MIN({}), MAX({}) FROM {}",
432            split_column.name,
433            split_column.name,
434            Self::get_normalized_table_name(&self.schema_table_name),
435        );
436        let client = self.client.lock().await;
437        let rows = client.query(&sql, &[]).await?;
438        if rows.is_empty() {
439            Ok(None)
440        } else {
441            let row = &rows[0];
442            let min =
443                postgres_cell_to_scalar_impl(row, &split_column.data_type, 0, &split_column.name);
444            let max =
445                postgres_cell_to_scalar_impl(row, &split_column.data_type, 1, &split_column.name);
446            match (min, max) {
447                (Some(min), Some(max)) => Ok(Some((min, max))),
448                _ => Ok(None),
449            }
450        }
451    }
452
    /// Computes the exclusive right bound of the next uneven split: take the
    /// MAX of the next `max_split_size` rows at/after `left_value`, but yield
    /// SQL NULL when that MAX has already reached `max_value` (the remainder
    /// of the table fits in this split).
    ///
    /// Returns `Ok(Some(datum))` from the first result row — `datum` is
    /// `None` exactly in the NULL case above — or `Ok(None)` if the query
    /// produced no rows at all (not expected for an aggregate query).
    async fn next_split_right_bound_exclusive(
        &self,
        left_value: &ScalarImpl,
        max_value: &ScalarImpl,
        max_split_size: u64,
        split_column: &Field,
    ) -> ConnectorResult<Option<Datum>> {
        let sql = format!(
            "WITH t as (SELECT {} FROM {} WHERE {} >= $1 ORDER BY {} ASC LIMIT {}) SELECT CASE WHEN MAX({}) < $2 THEN MAX({}) ELSE NULL END FROM t",
            Self::quote_column(&split_column.name),
            Self::get_normalized_table_name(&self.schema_table_name),
            Self::quote_column(&split_column.name),
            Self::quote_column(&split_column.name),
            max_split_size,
            Self::quote_column(&split_column.name),
            Self::quote_column(&split_column.name),
        );
        let client = self.client.lock().await;
        // Prepare so RW scalars can be adapted to the postgres parameter types.
        let prepared_stmt = client.prepare(&sql).await?;
        let params: Vec<Option<ScalarAdapter>> = vec![
            Some(ScalarAdapter::from_scalar(
                left_value.as_scalar_ref_impl(),
                &prepared_stmt.params()[0],
            )?),
            Some(ScalarAdapter::from_scalar(
                max_value.as_scalar_ref_impl(),
                &prepared_stmt.params()[1],
            )?),
        ];
        let stream = client.query_raw(&prepared_stmt, &params).await?;
        let datum_stream = stream.map(|row| {
            let row = row?;
            Ok::<_, ConnectorError>(postgres_cell_to_scalar_impl(
                &row,
                &split_column.data_type,
                0,
                &split_column.name,
            ))
        });
        pin_mut!(datum_stream);
        // The aggregate query yields a single row; return its first column.
        #[for_await]
        for datum in datum_stream {
            let right = datum?;
            return Ok(Some(right.to_owned_datum()));
        }
        Ok(None)
    }
500
    /// Finds the smallest split-column value strictly greater than
    /// `start_offset` and below `max_value`. Used by `as_uneven_splits` to
    /// guarantee forward progress when a scan window is saturated with
    /// duplicates of the left bound.
    ///
    /// The inner datum is `None` when no such value exists (MIN over an
    /// empty set is SQL NULL).
    async fn next_greater_bound(
        &self,
        start_offset: &ScalarImpl,
        max_value: &ScalarImpl,
        split_column: &Field,
    ) -> ConnectorResult<Option<Datum>> {
        let sql = format!(
            "SELECT MIN({}) FROM {} WHERE {} > $1 AND {} <$2",
            Self::quote_column(&split_column.name),
            Self::get_normalized_table_name(&self.schema_table_name),
            Self::quote_column(&split_column.name),
            Self::quote_column(&split_column.name),
        );
        let client = self.client.lock().await;
        // Prepare so RW scalars can be adapted to the postgres parameter types.
        let prepared_stmt = client.prepare(&sql).await?;
        let params: Vec<Option<ScalarAdapter>> = vec![
            Some(ScalarAdapter::from_scalar(
                start_offset.as_scalar_ref_impl(),
                &prepared_stmt.params()[0],
            )?),
            Some(ScalarAdapter::from_scalar(
                max_value.as_scalar_ref_impl(),
                &prepared_stmt.params()[1],
            )?),
        ];
        let stream = client.query_raw(&prepared_stmt, &params).await?;
        let datum_stream = stream.map(|row| {
            let row = row?;
            Ok::<_, ConnectorError>(postgres_cell_to_scalar_impl(
                &row,
                &split_column.data_type,
                0,
                &split_column.name,
            ))
        });
        pin_mut!(datum_stream);
        // The aggregate query yields a single row; return its first column.
        #[for_await]
        for datum in datum_stream {
            let right = datum?;
            return Ok(Some(right));
        }
        Ok(None)
    }
544
    /// Streams all rows of the split `[left, right)` on the (single) split
    /// column. A NULL bound means the split is unbounded on that side.
    #[try_stream(boxed, ok = OwnedRow, error = ConnectorError)]
    async fn split_snapshot_read_inner(
        &self,
        table_name: SchemaTableName,
        left: OwnedRow,
        right: OwnedRow,
        split_columns: Vec<Field>,
    ) {
        assert_eq!(
            split_columns.len(),
            1,
            "multiple split columns is not supported yet"
        );
        assert_eq!(left.len(), 1, "multiple split columns is not supported yet");
        assert_eq!(
            right.len(),
            1,
            "multiple split columns is not supported yet"
        );
        // NULL bounds mark the first/last split (see `split_filter_expression`).
        let is_first_split = left[0].is_none();
        let is_last_split = right[0].is_none();
        let split_column_names = split_columns.iter().map(|c| c.name.clone()).collect_vec();
        let client = self.client.lock().await;
        // Pin the session time zone so time values are read consistently as UTC.
        client.execute("set time zone '+00:00'", &[]).await?;
        // prepare the scan statement, since we may need to convert the RW data type to postgres data type
        // e.g. varchar to uuid
        let prepared_scan_stmt = {
            let scan_sql = format!(
                "SELECT {} FROM {} WHERE {}",
                self.field_names,
                Self::get_normalized_table_name(&table_name),
                Self::split_filter_expression(&split_column_names, is_first_split, is_last_split),
            );
            client.prepare(&scan_sql).await?
        };

        // Bind lower-bound params first, then upper-bound params, matching the
        // placeholder numbering produced by `split_filter_expression`.
        let mut params: Vec<Option<ScalarAdapter>> = vec![];
        if !is_first_split {
            let left_params: Vec<Option<ScalarAdapter>> = left
                .iter()
                .zip_eq_fast(prepared_scan_stmt.params().iter().take(left.len()))
                .map(|(datum, ty)| {
                    datum
                        .map(|scalar| ScalarAdapter::from_scalar(scalar, ty))
                        .transpose()
                })
                .try_collect()?;
            params.extend(left_params);
        }
        if !is_last_split {
            let right_params: Vec<Option<ScalarAdapter>> = right
                .iter()
                .zip_eq_fast(prepared_scan_stmt.params().iter().skip(params.len()))
                .map(|(datum, ty)| {
                    datum
                        .map(|scalar| ScalarAdapter::from_scalar(scalar, ty))
                        .transpose()
                })
                .try_collect()?;
            params.extend(right_params);
        }

        let stream = client.query_raw(&prepared_scan_stmt, &params).await?;
        // Convert each postgres row into an `OwnedRow` following `rw_schema`.
        let row_stream = stream.map(|row| {
            let row = row?;
            Ok::<_, crate::error::ConnectorError>(postgres_row_to_owned_row(row, &self.rw_schema))
        });

        pin_mut!(row_stream);
        #[for_await]
        for row in row_stream {
            let row = row?;
            yield row;
        }
    }
620
    /// Generates snapshot splits by repeatedly scanning forward up to
    /// `backfill_num_rows_per_split` rows from the current left bound and
    /// taking the largest value seen as the next (exclusive) right bound, so
    /// split sizes follow the actual data distribution.
    #[try_stream(boxed, ok = CdcTableSnapshotSplit, error = ConnectorError)]
    async fn as_uneven_splits(&self, options: CdcTableSnapshotSplitOption) {
        let split_column = self.split_column(&options);
        let mut split_id = CDC_TABLE_SPLIT_ID_START;
        let Some((min_value, max_value)) = self.min_and_max(&split_column).await? else {
            // Empty table (or all-NULL split column): emit one unbounded split.
            let left_bound_row = OwnedRow::new(vec![None]);
            let right_bound_row = OwnedRow::new(vec![None]);
            let split = CdcTableSnapshotSplit {
                split_id,
                left_bound_inclusive: left_bound_row,
                right_bound_exclusive: right_bound_row,
            };
            yield split;
            return Ok(());
        };
        // left bound will never be NULL value.
        let mut next_left_bound_inclusive = min_value.clone();
        loop {
            // The very first split is emitted with a NULL (unbounded) left bound.
            let left_bound_inclusive: Datum = if next_left_bound_inclusive == min_value {
                None
            } else {
                Some(next_left_bound_inclusive.clone())
            };
            let right_bound_exclusive;
            let mut next_right = self
                .next_split_right_bound_exclusive(
                    &next_left_bound_inclusive,
                    &max_value,
                    options.backfill_num_rows_per_split,
                    &split_column,
                )
                .await?;
            // If the window's MAX equals the left bound (many duplicates),
            // advance to the next strictly greater value to guarantee progress.
            if let Some(Some(ref inner)) = next_right
                && *inner == next_left_bound_inclusive
            {
                next_right = self
                    .next_greater_bound(&next_left_bound_inclusive, &max_value, &split_column)
                    .await?;
            }
            if let Some(next_right) = next_right {
                match next_right {
                    None => {
                        // NULL found.
                        right_bound_exclusive = None;
                    }
                    Some(next_right) => {
                        next_left_bound_inclusive = next_right.clone();
                        right_bound_exclusive = Some(next_right);
                    }
                }
            } else {
                // Not found.
                right_bound_exclusive = None;
            };
            let is_completed = right_bound_exclusive.is_none();
            if is_completed && left_bound_inclusive.is_none() {
                // A fully unbounded split can only ever be the first one
                // (implies `CDC_TABLE_SPLIT_ID_START == 1`).
                assert_eq!(split_id, 1);
            }
            tracing::info!(
                split_id,
                ?left_bound_inclusive,
                ?right_bound_exclusive,
                "New CDC table snapshot split."
            );
            let left_bound_row = OwnedRow::new(vec![left_bound_inclusive]);
            let right_bound_row = OwnedRow::new(vec![right_bound_exclusive]);
            let split = CdcTableSnapshotSplit {
                split_id,
                left_bound_inclusive: left_bound_row,
                right_bound_exclusive: right_bound_row,
            };
            try_increase_split_id(&mut split_id)?;
            yield split;
            if is_completed {
                break;
            }
        }
    }
699
700    #[try_stream(boxed, ok = CdcTableSnapshotSplit, error = ConnectorError)]
701    async fn as_even_splits(&self, options: CdcTableSnapshotSplitOption) {
702        let split_column = self.split_column(&options);
703        let mut split_id = 1;
704        let Some((min_value, max_value)) = self.min_and_max(&split_column).await? else {
705            let left_bound_row = OwnedRow::new(vec![None]);
706            let right_bound_row = OwnedRow::new(vec![None]);
707            let split = CdcTableSnapshotSplit {
708                split_id,
709                left_bound_inclusive: left_bound_row,
710                right_bound_exclusive: right_bound_row,
711            };
712            yield split;
713            return Ok(());
714        };
715        let min_value = min_value.as_integral();
716        let max_value = max_value.as_integral();
717        let saturated_split_max_size = options
718            .backfill_num_rows_per_split
719            .try_into()
720            .unwrap_or(i64::MAX);
721        let mut left = None;
722        let mut right = Some(min_value.saturating_add(saturated_split_max_size));
723        loop {
724            let mut is_completed = false;
725            if right.as_ref().map(|r| *r >= max_value).unwrap_or(true) {
726                right = None;
727                is_completed = true;
728            }
729            let split = CdcTableSnapshotSplit {
730                split_id,
731                left_bound_inclusive: OwnedRow::new(vec![
732                    left.map(|l| to_int_scalar(l, &split_column.data_type)),
733                ]),
734                right_bound_exclusive: OwnedRow::new(vec![
735                    right.map(|r| to_int_scalar(r, &split_column.data_type)),
736                ]),
737            };
738            try_increase_split_id(&mut split_id)?;
739            yield split;
740            if is_completed {
741                break;
742            }
743            left = right;
744            right = left.map(|l| l.saturating_add(saturated_split_max_size));
745        }
746    }
747
748    fn split_column(&self, options: &CdcTableSnapshotSplitOption) -> Field {
749        self.rw_schema.fields[self.pk_indices[options.backfill_split_pk_column_index as usize]]
750            .clone()
751    }
752}
753
/// Converts an `i64` split bound back into the scalar type of the split
/// column.
///
/// # Panics
/// Panics if the value does not fit the target integer width, or if
/// `data_type` is not one of the supported integer types (callers gate on
/// `is_supported_even_split_data_type`).
fn to_int_scalar(i: i64, data_type: &DataType) -> ScalarImpl {
    match data_type {
        DataType::Int16 => ScalarImpl::Int16(i.try_into().unwrap()),
        DataType::Int32 => ScalarImpl::Int32(i.try_into().unwrap()),
        DataType::Int64 => ScalarImpl::Int64(i),
        _ => {
            panic!("Can't convert int {} to ScalarImpl::{}", i, data_type)
        }
    }
}
764
765fn try_increase_split_id(split_id: &mut i64) -> ConnectorResult<()> {
766    match split_id.checked_add(1) {
767        Some(s) => {
768            *split_id = s;
769            Ok(())
770        }
771        None => Err(anyhow::anyhow!("too many CDC snapshot splits").into()),
772    }
773}
774
/// Use the first column of primary keys to split table.
/// Only fixed-width integer types can be partitioned into evenly-sized
/// ranges arithmetically; see `as_even_splits`.
fn is_supported_even_split_data_type(data_type: &DataType) -> bool {
    matches!(
        data_type,
        DataType::Int16 | DataType::Int32 | DataType::Int64
    )
}
782
783pub fn type_name_to_pg_type(ty_name: &str) -> Option<PgType> {
784    let ty_name_lower = ty_name.to_lowercase();
785    // Handle array types (prefixed with _)
786    if let Some(base_type) = ty_name_lower.strip_prefix('_') {
787        match base_type {
788            "int2" => Some(PgType::INT2_ARRAY),
789            "int4" => Some(PgType::INT4_ARRAY),
790            "int8" => Some(PgType::INT8_ARRAY),
791            "bit" => Some(PgType::BIT_ARRAY),
792            "float4" => Some(PgType::FLOAT4_ARRAY),
793            "float8" => Some(PgType::FLOAT8_ARRAY),
794            "numeric" => Some(PgType::NUMERIC_ARRAY),
795            "bool" => Some(PgType::BOOL_ARRAY),
796            "xml" | "macaddr" | "macaddr8" | "cidr" | "inet" | "int4range" | "int8range"
797            | "numrange" | "tsrange" | "tstzrange" | "daterange" | "citext" => {
798                Some(PgType::VARCHAR_ARRAY)
799            }
800            "varchar" => Some(PgType::VARCHAR_ARRAY),
801            "text" => Some(PgType::TEXT_ARRAY),
802            "bytea" => Some(PgType::BYTEA_ARRAY),
803            "geometry" => Some(PgType::BYTEA_ARRAY), // PostGIS geometry array
804            "date" => Some(PgType::DATE_ARRAY),
805            "time" => Some(PgType::TIME_ARRAY),
806            "timetz" => Some(PgType::TIMETZ_ARRAY),
807            "timestamp" => Some(PgType::TIMESTAMP_ARRAY),
808            "timestamptz" => Some(PgType::TIMESTAMPTZ_ARRAY),
809            "interval" => Some(PgType::INTERVAL_ARRAY),
810            "json" => Some(PgType::JSON_ARRAY),
811            "jsonb" => Some(PgType::JSONB_ARRAY),
812            "uuid" => Some(PgType::UUID_ARRAY),
813            "point" => Some(PgType::POINT_ARRAY),
814            "oid" => Some(PgType::OID_ARRAY),
815            "money" => Some(PgType::MONEY_ARRAY),
816            _ => None,
817        }
818    } else {
819        // Handle non-array types
820        match ty_name_lower.as_str() {
821            "int2" => Some(PgType::INT2),
822            "bit" => Some(PgType::BIT),
823            "int" | "int4" => Some(PgType::INT4),
824            "int8" => Some(PgType::INT8),
825            "float4" => Some(PgType::FLOAT4),
826            "float8" => Some(PgType::FLOAT8),
827            "numeric" => Some(PgType::NUMERIC),
828            "money" => Some(PgType::MONEY),
829            "boolean" | "bool" => Some(PgType::BOOL),
830            "inet" | "xml" | "varchar" | "character varying" | "int4range" | "int8range"
831            | "numrange" | "tsrange" | "tstzrange" | "daterange" | "macaddr" | "macaddr8"
832            | "cidr" => Some(PgType::VARCHAR),
833            "char" | "character" | "bpchar" => Some(PgType::BPCHAR),
834            "citext" | "text" => Some(PgType::TEXT),
835            "bytea" => Some(PgType::BYTEA),
836            "geometry" => Some(PgType::BYTEA), // PostGIS geometry type
837            "date" => Some(PgType::DATE),
838            "time" => Some(PgType::TIME),
839            "timetz" => Some(PgType::TIMETZ),
840            "timestamp" => Some(PgType::TIMESTAMP),
841            "timestamptz" => Some(PgType::TIMESTAMPTZ),
842            "interval" => Some(PgType::INTERVAL),
843            "json" => Some(PgType::JSON),
844            "jsonb" => Some(PgType::JSONB),
845            "uuid" => Some(PgType::UUID),
846            "point" => Some(PgType::POINT),
847            "oid" => Some(PgType::OID),
848            _ => None,
849        }
850    }
851}
852
853pub fn pg_type_to_rw_type(pg_type: &PgType) -> ConnectorResult<DataType> {
854    let data_type = match *pg_type {
855        PgType::BOOL => DataType::Boolean,
856        PgType::BIT => DataType::Boolean,
857        PgType::INT2 => DataType::Int16,
858        PgType::INT4 => DataType::Int32,
859        PgType::INT8 => DataType::Int64,
860        PgType::FLOAT4 => DataType::Float32,
861        PgType::FLOAT8 => DataType::Float64,
862        PgType::NUMERIC | PgType::MONEY => DataType::Decimal,
863        PgType::DATE => DataType::Date,
864        PgType::TIME => DataType::Time,
865        PgType::TIMETZ => DataType::Time,
866        PgType::POINT => DataType::Struct(risingwave_common::types::StructType::new(vec![
867            ("x", DataType::Float32),
868            ("y", DataType::Float32),
869        ])),
870        PgType::TIMESTAMP => DataType::Timestamp,
871        PgType::TIMESTAMPTZ => DataType::Timestamptz,
872        PgType::INTERVAL => DataType::Interval,
873        PgType::VARCHAR | PgType::TEXT | PgType::BPCHAR | PgType::UUID => DataType::Varchar,
874        PgType::BYTEA => DataType::Bytea,
875        PgType::JSON | PgType::JSONB => DataType::Jsonb,
876        // Array types
877        PgType::BOOL_ARRAY => DataType::Boolean.list(),
878        PgType::BIT_ARRAY => DataType::Boolean.list(),
879        PgType::INT2_ARRAY => DataType::Int16.list(),
880        PgType::INT4_ARRAY => DataType::Int32.list(),
881        PgType::INT8_ARRAY => DataType::Int64.list(),
882        PgType::FLOAT4_ARRAY => DataType::Float32.list(),
883        PgType::FLOAT8_ARRAY => DataType::Float64.list(),
884        PgType::NUMERIC_ARRAY => DataType::Decimal.list(),
885        PgType::VARCHAR_ARRAY => DataType::Varchar.list(),
886        PgType::TEXT_ARRAY => DataType::Varchar.list(),
887        PgType::BYTEA_ARRAY => DataType::Bytea.list(),
888        PgType::DATE_ARRAY => DataType::Date.list(),
889        PgType::TIME_ARRAY => DataType::Time.list(),
890        PgType::TIMESTAMP_ARRAY => DataType::Timestamp.list(),
891        PgType::TIMESTAMPTZ_ARRAY => DataType::Timestamptz.list(),
892        PgType::INTERVAL_ARRAY => DataType::Interval.list(),
893        PgType::JSON_ARRAY => DataType::Jsonb.list(),
894        PgType::JSONB_ARRAY => DataType::Jsonb.list(),
895        PgType::UUID_ARRAY => DataType::Varchar.list(),
896        PgType::OID => DataType::Int64,
897        PgType::OID_ARRAY => DataType::Int64.list(),
898        PgType::MONEY_ARRAY => DataType::Decimal.list(),
899        PgType::POINT_ARRAY => {
900            DataType::list(DataType::Struct(risingwave_common::types::StructType::new(
901                vec![("x", DataType::Float32), ("y", DataType::Float32)],
902            )))
903        }
904        _ => {
905            return Err(anyhow::anyhow!("unsupported postgres type: {}", pg_type).into());
906        }
907    };
908    Ok(data_type)
909}
910
#[cfg(test)]
mod tests {
    use std::cmp::Ordering;
    use std::collections::HashMap;

    use futures::pin_mut;
    use futures_async_stream::for_await;
    use maplit::{convert_args, hashmap};
    use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema};
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::{DataType, ScalarImpl};

    use crate::connector_common::PostgresExternalTable;
    use crate::source::cdc::external::postgres::{PostgresExternalTableReader, PostgresOffset};
    use crate::source::cdc::external::{ExternalTableConfig, ExternalTableReader, SchemaTableName};

    // Manual test: requires a live Postgres at localhost:8432 with the
    // credentials below, hence `#[ignore]`. Connects and prints the discovered
    // column descriptors and primary-key names for visual inspection.
    #[ignore]
    #[tokio::test]
    async fn test_postgres_schema() {
        let config = ExternalTableConfig {
            connector: "postgres-cdc".to_owned(),
            host: "localhost".to_owned(),
            port: "8432".to_owned(),
            username: "myuser".to_owned(),
            password: "123456".to_owned(),
            database: "mydb".to_owned(),
            schema: "public".to_owned(),
            table: "mytest".to_owned(),
            ssl_mode: Default::default(),
            ssl_root_cert: None,
            encrypt: "false".to_owned(),
        };

        let table = PostgresExternalTable::connect(
            &config.username,
            &config.password,
            &config.host,
            config.port.parse::<u16>().unwrap(),
            &config.database,
            &config.schema,
            &config.table,
            &config.ssl_mode,
            &config.ssl_root_cert,
            false,
        )
        .await
        .unwrap();

        println!("columns: {:?}", &table.column_descs());
        println!("primary keys: {:?}", &table.pk_names());
    }

    // With `lsn_commit`/`lsn_proc` left as `None` (via `..Default::default()`),
    // the assertions below show ordering follows `lsn` alone; `txid` does not
    // participate (e.g. off1.txid=4 > off2.txid=1 yet off1 < off2).
    #[test]
    fn test_postgres_offset() {
        let off1 = PostgresOffset {
            txid: 4,
            lsn: 2,
            ..Default::default()
        };
        let off2 = PostgresOffset {
            txid: 1,
            lsn: 3,
            ..Default::default()
        };
        let off3 = PostgresOffset {
            txid: 5,
            lsn: 1,
            ..Default::default()
        };

        assert!(off1 < off2);
        assert!(off3 < off1);
        assert!(off2 > off3);
    }

    // Exercises the `PartialOrd` precedence when the optional LSN fields are
    // present: `lsn_commit` first, then `lsn_proc` as tie-breaker, falling
    // back to plain `lsn` whenever either optional field is missing.
    #[test]
    fn test_postgres_offset_partial_ord_with_lsn_commit() {
        // Test comparison with both lsn_commit and lsn_proc fields
        let off1 = PostgresOffset {
            txid: 1,
            lsn: 100,
            lsn_commit: Some(200),
            lsn_proc: Some(150),
        };
        let off2 = PostgresOffset {
            txid: 2,
            lsn: 300,
            lsn_commit: Some(250),
            lsn_proc: Some(200),
        };

        // Should compare using lsn_commit first when both have both fields
        assert!(off1 < off2);

        // Test with same lsn_commit but different lsn_proc
        let off3 = PostgresOffset {
            txid: 3,
            lsn: 500,
            lsn_commit: Some(200), // same as off1
            lsn_proc: Some(160),   // higher than off1
        };

        // Should compare lsn_proc when lsn_commit is equal
        assert!(off1 < off3);

        // Test with missing lsn_proc - should fall back to lsn comparison
        let off4 = PostgresOffset {
            txid: 4,
            lsn: 400,
            lsn_commit: Some(100), // lower than off1's lsn_commit
            lsn_proc: None,        // missing lsn_proc
        };

        // Should fall back to lsn comparison (off1.lsn=100 < off4.lsn=400)
        assert!(off1 < off4);

        // Test with missing lsn_commit - should fall back to lsn comparison
        let off5 = PostgresOffset {
            txid: 5,
            lsn: 50,             // lower than off1.lsn
            lsn_commit: None,    // missing lsn_commit
            lsn_proc: Some(300), // higher than off1's lsn_proc
        };

        // Should fall back to lsn comparison (off5.lsn=50 < off1.lsn=100)
        assert!(off5 < off1);

        // Additional test cases: equal lsn_commit values with different lsn_proc
        let off6 = PostgresOffset {
            txid: 6,
            lsn: 600,
            lsn_commit: Some(500),
            lsn_proc: Some(300),
        };
        let off7 = PostgresOffset {
            txid: 7,
            lsn: 700,
            lsn_commit: Some(500), // same as off6
            lsn_proc: Some(400),   // higher than off6
        };

        // Should compare lsn_proc since lsn_commit is equal
        assert!(off6 < off7);

        // Test reverse order
        let off8 = PostgresOffset {
            txid: 8,
            lsn: 800,
            lsn_commit: Some(500), // same as others
            lsn_proc: Some(200),   // lower than off6
        };

        assert!(off8 < off6);
        assert!(off8 < off7);

        // Test equal lsn_commit and lsn_proc
        let off9 = PostgresOffset {
            txid: 9,
            lsn: 900,
            lsn_commit: Some(500), // same as off6
            lsn_proc: Some(300),   // same as off6
        };

        // Should be equal
        assert_eq!(off6.partial_cmp(&off9), Some(Ordering::Equal));
    }

    // Verifies Debezium offset JSON parsing: a complete `sourceOffset` yields
    // all fields, while a payload missing `lsn_commit`/`lsn_proc` must fail
    // with an "invalid postgres lsn_proc" error rather than parse partially.
    #[test]
    fn test_debezium_offset_parsing() {
        // Test parsing with all required fields present
        let debezium_offset_with_fields = r#"{
            "sourcePartition": {"server": "RW_CDC_1004"},
            "sourceOffset": {
                "last_snapshot_record": false,
                "lsn": 29973552,
                "txId": 1046,
                "ts_usec": 1670826189008456,
                "snapshot": true,
                "lsn_commit": 29973600,
                "lsn_proc": 29973580
            },
            "isHeartbeat": false
        }"#;

        let offset = PostgresOffset::parse_debezium_offset(debezium_offset_with_fields).unwrap();
        assert_eq!(offset.txid, 1046);
        assert_eq!(offset.lsn, 29973552);
        assert_eq!(offset.lsn_commit, Some(29973600));
        assert_eq!(offset.lsn_proc, Some(29973580));

        // Test parsing should fail when required fields are missing
        let debezium_offset_missing_fields = r#"{
            "sourcePartition": {"server": "RW_CDC_1004"},
            "sourceOffset": {
                "last_snapshot_record": false,
                "lsn": 29973552,
                "txId": 1046,
                "ts_usec": 1670826189008456,
                "snapshot": true
            },
            "isHeartbeat": false
        }"#;

        let result = PostgresOffset::parse_debezium_offset(debezium_offset_missing_fields);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("invalid postgres lsn_proc"));
    }

    // Checks the SQL predicate built for resuming a snapshot scan past a given
    // primary key: quoted column tuple compared against positional parameters.
    #[test]
    fn test_filter_expression() {
        let cols = vec!["v1".to_owned()];
        let expr = PostgresExternalTableReader::filter_expression(&cols);
        assert_eq!(expr, "(\"v1\") > ($1)");

        let cols = vec!["v1".to_owned(), "v2".to_owned()];
        let expr = PostgresExternalTableReader::filter_expression(&cols);
        assert_eq!(expr, "(\"v1\", \"v2\") > ($1, $2)");

        let cols = vec!["v1".to_owned(), "v2".to_owned(), "v3".to_owned()];
        let expr = PostgresExternalTableReader::filter_expression(&cols);
        assert_eq!(expr, "(\"v1\", \"v2\", \"v3\") > ($1, $2, $3)");
    }

    // Checks the split-range predicate for the four boundary combinations
    // (is_first_split, is_last_split): an unbounded split degenerates to a
    // tautology, and bounded sides use `>=` (inclusive lower) / `<`
    // (exclusive upper).
    #[test]
    fn test_split_filter_expression() {
        let cols = vec!["v1".to_owned()];
        let expr = PostgresExternalTableReader::split_filter_expression(&cols, true, true);
        assert_eq!(expr, "1 = 1");

        let expr = PostgresExternalTableReader::split_filter_expression(&cols, true, false);
        assert_eq!(expr, "(\"v1\") < ($1)");

        let expr = PostgresExternalTableReader::split_filter_expression(&cols, false, true);
        assert_eq!(expr, "(\"v1\") >= ($1)");

        let expr = PostgresExternalTableReader::split_filter_expression(&cols, false, false);
        assert_eq!(expr, "(\"v1\") >= ($1) AND (\"v1\") < ($2)");
    }

    // manual test
    // Requires a live Postgres at localhost:8432 (hence `#[ignore]`); reads the
    // current CDC offset, then snapshot-reads rows after a given start pk and
    // prints them for visual inspection.
    #[ignore]
    #[tokio::test]
    async fn test_pg_table_reader() {
        let columns = [
            ColumnDesc::named("v1", ColumnId::new(1), DataType::Int32),
            ColumnDesc::named("v2", ColumnId::new(2), DataType::Varchar),
            ColumnDesc::named("v3", ColumnId::new(3), DataType::Decimal),
            ColumnDesc::named("v4", ColumnId::new(4), DataType::Date),
        ];
        let rw_schema = Schema {
            fields: columns.iter().map(Field::from).collect(),
        };

        let props: HashMap<String, String> = convert_args!(hashmap!(
                "hostname" => "localhost",
                "port" => "8432",
                "username" => "myuser",
                "password" => "123456",
                "database.name" => "mydb",
                "schema.name" => "public",
                "table.name" => "t1"));

        let config =
            serde_json::from_value::<ExternalTableConfig>(serde_json::to_value(props).unwrap())
                .unwrap();
        let schema_table_name = SchemaTableName {
            schema_name: "public".to_owned(),
            table_name: "t1".to_owned(),
        };
        let reader = PostgresExternalTableReader::new(
            config,
            rw_schema,
            vec![0, 1],
            schema_table_name.clone(),
        )
        .await
        .unwrap();

        let offset = reader.current_cdc_offset().await.unwrap();
        println!("CdcOffset: {:?}", offset);

        let start_pk = OwnedRow::new(vec![Some(ScalarImpl::from(3)), Some(ScalarImpl::from("c"))]);
        let stream = reader.snapshot_read(
            schema_table_name,
            Some(start_pk),
            vec!["v1".to_owned(), "v2".to_owned()],
            1000,
        );

        pin_mut!(stream);
        #[for_await]
        for row in stream {
            println!("OwnedRow: {:?}", row);
        }
    }
}
1207}