risingwave_connector/sink/postgres.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashSet};
use std::sync::Arc;

use anyhow::{Context, anyhow};
use async_trait::async_trait;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use itertools::Itertools;
use phf::phf_set;
use risingwave_common::array::{Op, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::row::{Row, RowExt};
use serde::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use simd_json::prelude::ArrayTrait;
use thiserror_ext::AsReport;
use tokio_postgres::types::Type as PgType;

use super::{
    LogSinker, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkLogReader,
};
use crate::connector_common::{PostgresExternalTable, SslMode, create_pg_client};
use crate::enforce_secret::EnforceSecret;
use crate::parser::scalar_adapter::{ScalarAdapter, validate_pg_type_to_rw_type};
use crate::sink::log_store::{LogStoreReadItem, TruncateOffset};
use crate::sink::{Result, Sink, SinkParam, SinkWriterParam};

pub const POSTGRES_SINK: &str = "postgres";

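/// User-facing configuration of the Postgres sink, deserialized from the sink's string properties.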
#[serde_as]
#[derive(Clone, Debug, Deserialize)]
pub struct PostgresConfig {
    pub host: String,
    #[serde_as(as = "DisplayFromStr")]
    pub port: u16,
    pub user: String,
    pub password: String,
    pub database: String,
    pub table: String,
    #[serde(default = "default_schema")]
    pub schema: String,
    #[serde(default = "Default::default")]
    pub ssl_mode: SslMode,
    #[serde(rename = "ssl.root.cert")]
    pub ssl_root_cert: Option<String>,
    #[serde(default = "default_max_batch_rows")]
    #[serde_as(as = "DisplayFromStr")]
    pub max_batch_rows: usize,
    pub r#type: String, // accept "append-only" or "upsert"
}

impl EnforceSecret for PostgresConfig {
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf_set! {
        "password", "ssl.root.cert"
    };
}

fn default_max_batch_rows() -> usize {
    1024
}

fn default_schema() -> String {
    "public".to_owned()
}

impl PostgresConfig {
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<PostgresConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }
}

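/// The Postgres sink. It validates that the sink schema is compatible with the downstream
/// Postgres table and creates a [`PostgresSinkWriter`] to write the log.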
#[derive(Debug)]
pub struct PostgresSink {
    pub config: PostgresConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl PostgresSink {
    pub fn new(
        config: PostgresConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl EnforceSecret for PostgresSink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::error::ConnectorResult<()> {
        for prop in prop_iter {
            PostgresConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl TryFrom<SinkParam> for PostgresSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let pk_indices = param.downstream_pk_or_empty();
        let config = PostgresConfig::from_btreemap(param.properties)?;
        PostgresSink::new(config, schema, pk_indices, param.sink_type.is_append_only())
    }
}

impl Sink for PostgresSink {
    type LogSinker = PostgresSinkWriter;

    const SINK_NAME: &'static str = POSTGRES_SINK;

    async fn validate(&self) -> Result<()> {
        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert Postgres sink (please define in `primary_key` field)"
            )));
        }

        // Verify our sink schema is compatible with Postgres
        {
            let pg_table = PostgresExternalTable::connect(
                &self.config.user,
                &self.config.password,
                &self.config.host,
                self.config.port,
                &self.config.database,
                &self.config.schema,
                &self.config.table,
                &self.config.ssl_mode,
                &self.config.ssl_root_cert,
                self.is_append_only,
            )
            .await
            .context(format!(
                "failed to connect to database: {}, schema: {}, table: {}",
                &self.config.database, &self.config.schema, &self.config.table
            ))?;

            // Check that names and types match, order of columns doesn't matter.
            {
                let pg_columns = pg_table.column_descs();
                let sink_columns = self.schema.fields();
                if pg_columns.len() < sink_columns.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Column count mismatch: Postgres table has {} columns, but sink schema has {} columns; the sink must not have more columns than the Postgres table",
                        pg_columns.len(),
                        sink_columns.len()
                    )));
                }

                let pg_columns_lookup = pg_columns
                    .iter()
                    .map(|c| (c.name.clone(), c.data_type.clone()))
                    .collect::<BTreeMap<_, _>>();
                for sink_column in sink_columns {
                    let pg_column = pg_columns_lookup.get(&sink_column.name);
                    match pg_column {
                        None => {
                            return Err(SinkError::Config(anyhow!(
                                "Column `{}` not found in Postgres table `{}`",
                                sink_column.name,
                                self.config.table
                            )));
                        }
                        Some(pg_column) => {
                            if !validate_pg_type_to_rw_type(pg_column, &sink_column.data_type()) {
                                return Err(SinkError::Config(anyhow!(
                                    "Column `{}` in Postgres table `{}` has type `{}`, but sink schema defines it as type `{}`",
                                    sink_column.name,
                                    self.config.table,
                                    pg_column,
                                    sink_column.data_type()
                                )));
                            }
                        }
                    }
                }
            }

            // check that pk matches
            {
                let pg_pk_names = pg_table.pk_names();
                let sink_pk_names = self
                    .pk_indices
                    .iter()
                    .map(|i| &self.schema.fields()[*i].name)
                    .collect::<HashSet<_>>();
                if pg_pk_names.len() != sink_pk_names.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Primary key mismatch: Postgres table has primary key on columns {:?}, but sink schema defines primary key on columns {:?}",
                        pg_pk_names,
                        sink_pk_names
                    )));
                }
                for name in pg_pk_names {
                    if !sink_pk_names.contains(name) {
                        return Err(SinkError::Config(anyhow!(
                            "Primary key mismatch: Postgres table has primary key on column `{}`, but sink schema does not define it as a primary key",
                            name
                        )));
                    }
                }
            }
        }

        Ok(())
    }

    async fn new_log_sinker(&self, _writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        PostgresSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await
    }
}

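/// Writer for the Postgres sink. It owns a client connection together with the SQL text and
/// prepared statements for insert, upsert and delete, plus the Postgres types used to encode
/// row values as statement parameters.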
pub struct PostgresSinkWriter {
    is_append_only: bool,
    client: tokio_postgres::Client,
    pk_indices: Vec<usize>,
    pk_types: Vec<PgType>,
    schema_types: Vec<PgType>,
    raw_insert_sql: Arc<String>,
    raw_upsert_sql: Arc<String>,
    raw_delete_sql: Arc<String>,
    insert_sql: Arc<tokio_postgres::Statement>,
    delete_sql: Arc<tokio_postgres::Statement>,
    upsert_sql: Arc<tokio_postgres::Statement>,
}

impl PostgresSinkWriter {
    async fn new(
        config: PostgresConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let client = create_pg_client(
            &config.user,
            &config.password,
            &config.host,
            &config.port.to_string(),
            &config.database,
            &config.ssl_mode,
            &config.ssl_root_cert,
        )
        .await?;

        let pk_indices_lookup = pk_indices.iter().copied().collect::<HashSet<_>>();

        // Rewrite schema types for serialization
        let (pk_types, schema_types) = {
            let name_to_type = PostgresExternalTable::type_mapping(
                &config.user,
                &config.password,
                &config.host,
                config.port,
                &config.database,
                &config.schema,
                &config.table,
                &config.ssl_mode,
                &config.ssl_root_cert,
                is_append_only,
            )
            .await?;
            let mut schema_types = Vec::with_capacity(schema.fields.len());
            let mut pk_types = Vec::with_capacity(pk_indices.len());
            for (i, field) in schema.fields.iter().enumerate() {
                let field_name = &field.name;
                let actual_data_type = name_to_type
                    .get(field_name)
                    .cloned()
                    .ok_or_else(|| {
                        SinkError::Config(anyhow!(
                            "Column `{}` not found in Postgres table",
                            field_name
                        ))
                    })?;
                if pk_indices_lookup.contains(&i) {
                    pk_types.push(actual_data_type.clone())
                }
                schema_types.push(actual_data_type);
            }
            (pk_types, schema_types)
        };

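        // Build the SQL text once and prepare the three statements up front, so every
        // subsequent chunk reuses the same prepared statements.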
        let raw_insert_sql = create_insert_sql(&schema, &config.schema, &config.table);
        let raw_upsert_sql = create_upsert_sql(
            &schema,
            &config.schema,
            &config.table,
            &pk_indices,
            &pk_indices_lookup,
        );
        let raw_delete_sql = create_delete_sql(&schema, &config.schema, &config.table, &pk_indices);

        let insert_sql = client
            .prepare(&raw_insert_sql)
            .await
            .with_context(|| format!("failed to prepare insert statement: {}", raw_insert_sql))?;
        let upsert_sql = client
            .prepare(&raw_upsert_sql)
            .await
            .with_context(|| format!("failed to prepare upsert statement: {}", raw_upsert_sql))?;
        let delete_sql = client
            .prepare(&raw_delete_sql)
            .await
            .with_context(|| format!("failed to prepare delete statement: {}", raw_delete_sql))?;

        let writer = Self {
            is_append_only,
            client,
            pk_indices,
            pk_types,
            schema_types,
            raw_insert_sql: Arc::new(raw_insert_sql),
            raw_upsert_sql: Arc::new(raw_upsert_sql),
            raw_delete_sql: Arc::new(raw_delete_sql),
            insert_sql: Arc::new(insert_sql),
            delete_sql: Arc::new(delete_sql),
            upsert_sql: Arc::new(upsert_sql),
        };
        Ok(writer)
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        // https://www.postgresql.org/docs/current/limits.html
        // We have a limit of 65,535 parameters in a single query, as restricted by the PostgreSQL protocol.
        if self.is_append_only {
            self.write_batch_append_only(chunk).await
        } else {
            self.write_batch_non_append_only(chunk).await
        }
    }

    async fn write_batch_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
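        // All inserts of the chunk are pipelined concurrently on a single transaction; the
        // transaction is committed only after every insert has succeeded.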
        let transaction = Arc::new(self.client.transaction().await?);
        let mut insert_futures = FuturesUnordered::new();
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    let pg_row = convert_row_to_pg_row(row, &self.schema_types);
                    let insert_sql = self.insert_sql.clone();
                    let raw_insert_sql = self.raw_insert_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(insert_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute insert statement: {}, parameters: {:?}",
                                    raw_insert_sql, pg_row
                                )
                            })
                    };
                    insert_futures.push(future);
                }
                _ => {
                    tracing::error!(
                        "row ignored: append-only sink should not receive update or delete operations"
                    );
                }
            }
        }

        while let Some(result) = insert_futures.next().await {
            result?;
        }
        if let Some(transaction) = Arc::into_inner(transaction) {
            transaction.commit().await?;
        } else {
            tracing::error!("transaction lost!");
        }

        Ok(())
    }

    async fn write_batch_non_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
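        // Deletes and upserts are collected separately; all deletes are executed before any
        // upsert, so an `UpdateDelete` cannot remove the row written by its paired
        // `UpdateInsert`. Everything runs on one transaction, committed only if every
        // statement succeeds.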
        let transaction = Arc::new(self.client.transaction().await?);
        let mut delete_futures = FuturesUnordered::new();
        let mut upsert_futures = FuturesUnordered::new();
        for (op, row) in chunk.rows() {
            match op {
                Op::Delete | Op::UpdateDelete => {
                    let pg_row =
                        convert_row_to_pg_row(row.project(&self.pk_indices), &self.pk_types);
                    let delete_sql = self.delete_sql.clone();
                    let raw_delete_sql = self.raw_delete_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(delete_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute delete statement: {}, parameters: {:?}",
                                    raw_delete_sql, pg_row
                                )
                            })
                    };
                    delete_futures.push(future);
                }
                Op::Insert | Op::UpdateInsert => {
                    let pg_row = convert_row_to_pg_row(row, &self.schema_types);
                    let upsert_sql = self.upsert_sql.clone();
                    let raw_upsert_sql = self.raw_upsert_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(upsert_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute upsert statement: {}, parameters: {:?}",
                                    raw_upsert_sql, pg_row
                                )
                            })
                    };
                    upsert_futures.push(future);
                }
            }
        }
        while let Some(result) = delete_futures.next().await {
            result?;
        }
        while let Some(result) = upsert_futures.next().await {
            result?;
        }
        if let Some(transaction) = Arc::into_inner(transaction) {
            transaction.commit().await?;
        } else {
            tracing::error!("transaction lost!");
        }
        Ok(())
    }
}

#[async_trait]
impl LogSinker for PostgresSinkWriter {
    async fn consume_log_and_sink(mut self, mut log_reader: impl SinkLogReader) -> Result<!> {
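        // Each chunk is written and committed in its own transaction, so both chunks and
        // barriers can be truncated from the log store as soon as they are handled.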
        log_reader.start_from(None).await?;
        loop {
            let (epoch, item) = log_reader.next_item().await?;
            match item {
                LogStoreReadItem::StreamChunk { chunk, chunk_id } => {
                    self.write_batch(chunk).await?;
                    log_reader.truncate(TruncateOffset::Chunk { epoch, chunk_id })?;
                }
                LogStoreReadItem::Barrier { .. } => {
                    log_reader.truncate(TruncateOffset::Barrier { epoch })?;
                }
            }
        }
    }
}

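/// Builds the parameterized insert statement, e.g.
/// `INSERT INTO "schema"."table" ("a", "b") VALUES ($1, $2)`.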
fn create_insert_sql(schema: &Schema, schema_name: &str, table_name: &str) -> String {
    let normalized_table_name = format!(
        "{}.{}",
        quote_identifier(schema_name),
        quote_identifier(table_name)
    );
    let number_of_columns = schema.len();
    let columns: String = schema
        .fields()
        .iter()
        .map(|field| quote_identifier(&field.name))
        .join(", ");
    let column_parameters: String = (0..number_of_columns)
        .map(|i| format!("${}", i + 1))
        .join(", ");
    format!("INSERT INTO {normalized_table_name} ({columns}) VALUES ({column_parameters})")
}

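/// Builds the parameterized delete statement keyed on the primary key columns (or on all
/// columns when no primary key is defined), e.g.
/// `DELETE FROM "schema"."table" WHERE ("a", "b") in (($1, $2))`.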
fn create_delete_sql(
    schema: &Schema,
    schema_name: &str,
    table_name: &str,
    pk_indices: &[usize],
) -> String {
    let normalized_table_name = format!(
        "{}.{}",
        quote_identifier(schema_name),
        quote_identifier(table_name)
    );
    let pk_indices = if pk_indices.is_empty() {
        (0..schema.len()).collect_vec()
    } else {
        pk_indices.to_vec()
    };
    let pk = {
        let pk_symbols = pk_indices
            .iter()
            .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
            .join(", ");
        format!("({})", pk_symbols)
    };
    let parameters: String = (0..pk_indices.len())
        .map(|i| format!("${}", i + 1))
        .join(", ");
    format!("DELETE FROM {normalized_table_name} WHERE {pk} in (({parameters}))")
}

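/// Builds the parameterized upsert statement as `INSERT ... ON CONFLICT (pk columns) DO UPDATE
/// SET ...` over the non-primary-key columns; falls back to a plain insert when no primary key
/// is defined.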
fn create_upsert_sql(
    schema: &Schema,
    schema_name: &str,
    table_name: &str,
    pk_indices: &[usize],
    pk_indices_lookup: &HashSet<usize>,
) -> String {
    let insert_sql = create_insert_sql(schema, schema_name, table_name);
    if pk_indices.is_empty() {
        return insert_sql;
    }
    let pk_columns = pk_indices
        .iter()
        .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
        .collect_vec()
        .join(", ");
    let update_parameters: String = (0..schema.len())
        .filter(|i| !pk_indices_lookup.contains(i))
        .map(|i| {
            let column = quote_identifier(&schema.fields()[i].name);
            format!("{column} = EXCLUDED.{column}")
        })
        .collect_vec()
        .join(", ");
    format!("{insert_sql} on conflict ({pk_columns}) do update set {update_parameters}")
}

/// Quote an identifier for PostgreSQL.
fn quote_identifier(identifier: &str) -> String {
    format!("\"{}\"", identifier.replace("\"", "\"\""))
}

type PgDatum = Option<ScalarAdapter>;
type PgRow = Vec<PgDatum>;

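/// Converts a RisingWave row into the parameter vector expected by the prepared statements.
/// Values that cannot be converted to the target Postgres type are logged and written as NULL.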
fn convert_row_to_pg_row(row: impl Row, schema_types: &[PgType]) -> PgRow {
    let mut buffer = Vec::with_capacity(row.len());
    for (i, datum_ref) in row.iter().enumerate() {
        let pg_datum = datum_ref.map(|s| {
            match ScalarAdapter::from_scalar(s, &schema_types[i]) {
                Ok(scalar) => Some(scalar),
                Err(e) => {
                    tracing::error!(error=%e.as_report(), scalar=?s, "Failed to convert scalar to pg value");
                    None
                }
            }
        });
        buffer.push(pg_datum.flatten());
    }
    buffer
}

#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use expect_test::{Expect, expect};
    use risingwave_common::catalog::Field;
    use risingwave_common::types::DataType;

    use super::*;

    fn check(actual: impl Display, expect: Expect) {
        let actual = actual.to_string();
        expect.assert_eq(&actual);
    }

    #[test]
    fn test_create_insert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let sql = create_insert_sql(&schema, schema_name, table_name);
        check(
            sql,
            expect![[r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2)"#]],
        );
    }

    #[test]
    fn test_create_delete_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let sql = create_delete_sql(&schema, schema_name, table_name, &[1]);
        check(
            sql,
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("b") in (($1))"#]],
        );
        let table_name = "test_table";
        let sql = create_delete_sql(&schema, schema_name, table_name, &[0, 1]);
        check(
            sql,
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("a", "b") in (($1, $2))"#]],
        );
    }

    #[test]
    fn test_create_upsert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let pk_indices_lookup = HashSet::from_iter([1]);
        let sql = create_upsert_sql(&schema, schema_name, table_name, &[1], &pk_indices_lookup);
        check(
            sql,
            expect![[
                r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2) on conflict ("b") do update set "a" = EXCLUDED."a""#
            ]],
        );
    }
}