risingwave_connector/sink/postgres.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashSet};
use std::sync::Arc;

use anyhow::{Context, anyhow};
use async_trait::async_trait;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use itertools::Itertools;
use phf::phf_set;
use risingwave_common::array::{Op, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::row::{Row, RowExt};
use serde_derive::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use simd_json::prelude::ArrayTrait;
use thiserror_ext::AsReport;
use tokio_postgres::types::Type as PgType;

use super::{
    LogSinker, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkLogReader,
};
use crate::connector_common::{PostgresExternalTable, SslMode, create_pg_client};
use crate::enforce_secret::EnforceSecret;
use crate::parser::scalar_adapter::{ScalarAdapter, validate_pg_type_to_rw_type};
use crate::sink::log_store::{LogStoreReadItem, TruncateOffset};
use crate::sink::{DummySinkCommitCoordinator, Result, Sink, SinkParam, SinkWriterParam};

pub const POSTGRES_SINK: &str = "postgres";

#[serde_as]
#[derive(Clone, Debug, Deserialize)]
pub struct PostgresConfig {
    pub host: String,
    #[serde_as(as = "DisplayFromStr")]
    pub port: u16,
    pub user: String,
    pub password: String,
    pub database: String,
    pub table: String,
    #[serde(default = "default_schema")]
    pub schema: String,
    #[serde(default = "Default::default")]
    pub ssl_mode: SslMode,
    #[serde(rename = "ssl.root.cert")]
    pub ssl_root_cert: Option<String>,
    #[serde(default = "default_max_batch_rows")]
    #[serde_as(as = "DisplayFromStr")]
    pub max_batch_rows: usize,
    pub r#type: String, // accept "append-only" or "upsert"
}

impl EnforceSecret for PostgresConfig {
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf_set! {
        "password", "ssl.root.cert"
    };
}

fn default_max_batch_rows() -> usize {
    1024
}

fn default_schema() -> String {
    "public".to_owned()
}

impl PostgresConfig {
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<PostgresConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be `{}` or `{}`",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }
}

#[derive(Debug)]
pub struct PostgresSink {
    pub config: PostgresConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl PostgresSink {
    pub fn new(
        config: PostgresConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl EnforceSecret for PostgresSink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::error::ConnectorResult<()> {
        for prop in prop_iter {
            PostgresConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl TryFrom<SinkParam> for PostgresSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = PostgresConfig::from_btreemap(param.properties)?;
        PostgresSink::new(
            config,
            schema,
            param.downstream_pk,
            param.sink_type.is_append_only(),
        )
    }
}

impl Sink for PostgresSink {
    type Coordinator = DummySinkCommitCoordinator;
    type LogSinker = PostgresSinkWriter;

    const SINK_NAME: &'static str = POSTGRES_SINK;

    async fn validate(&self) -> Result<()> {
        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert Postgres sink (please define in `primary_key` field)"
            )));
        }

        // Verify our sink schema is compatible with Postgres
        {
            let pg_table = PostgresExternalTable::connect(
                &self.config.user,
                &self.config.password,
                &self.config.host,
                self.config.port,
                &self.config.database,
                &self.config.schema,
                &self.config.table,
                &self.config.ssl_mode,
                &self.config.ssl_root_cert,
                self.is_append_only,
            )
            .await
            .context(format!(
                "failed to connect to database: {}, schema: {}, table: {}",
                &self.config.database, &self.config.schema, &self.config.table
            ))?;

            // Check that names and types match, order of columns doesn't matter.
            {
                let pg_columns = pg_table.column_descs();
                let sink_columns = self.schema.fields();
                if pg_columns.len() < sink_columns.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Column count mismatch: Postgres table has {} columns, but sink schema has {} columns; the sink schema must not have more columns than the Postgres table",
                        pg_columns.len(),
                        sink_columns.len()
                    )));
                }

                let pg_columns_lookup = pg_columns
                    .iter()
                    .map(|c| (c.name.clone(), c.data_type.clone()))
                    .collect::<BTreeMap<_, _>>();
                for sink_column in sink_columns {
                    let pg_column = pg_columns_lookup.get(&sink_column.name);
                    match pg_column {
                        None => {
                            return Err(SinkError::Config(anyhow!(
                                "Column `{}` not found in Postgres table `{}`",
                                sink_column.name,
                                self.config.table
                            )));
                        }
                        Some(pg_column) => {
                            if !validate_pg_type_to_rw_type(pg_column, &sink_column.data_type()) {
                                return Err(SinkError::Config(anyhow!(
                                    "Column `{}` in Postgres table `{}` has type `{}`, but sink schema defines it as type `{}`",
                                    sink_column.name,
                                    self.config.table,
                                    pg_column,
                                    sink_column.data_type()
                                )));
                            }
                        }
                    }
                }
            }

            // check that pk matches
            {
                let pg_pk_names = pg_table.pk_names();
                let sink_pk_names = self
                    .pk_indices
                    .iter()
                    .map(|i| &self.schema.fields()[*i].name)
                    .collect::<HashSet<_>>();
                if pg_pk_names.len() != sink_pk_names.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Primary key mismatch: Postgres table has primary key on columns {:?}, but sink schema defines primary key on columns {:?}",
                        pg_pk_names,
                        sink_pk_names
                    )));
                }
                for name in pg_pk_names {
                    if !sink_pk_names.contains(name) {
                        return Err(SinkError::Config(anyhow!(
                            "Primary key mismatch: Postgres table has primary key on column `{}`, but sink schema does not define it as a primary key",
                            name
                        )));
                    }
                }
            }
        }

        Ok(())
    }

    async fn new_log_sinker(&self, _writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        PostgresSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await
    }
}

pub struct PostgresSinkWriter {
    is_append_only: bool,
    client: tokio_postgres::Client,
    pk_indices: Vec<usize>,
    pk_types: Vec<PgType>,
    schema_types: Vec<PgType>,
    raw_insert_sql: Arc<String>,
    raw_upsert_sql: Arc<String>,
    raw_delete_sql: Arc<String>,
    insert_sql: Arc<tokio_postgres::Statement>,
    delete_sql: Arc<tokio_postgres::Statement>,
    upsert_sql: Arc<tokio_postgres::Statement>,
}

impl PostgresSinkWriter {
    async fn new(
        config: PostgresConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let client = create_pg_client(
            &config.user,
            &config.password,
            &config.host,
            &config.port.to_string(),
            &config.database,
            &config.ssl_mode,
            &config.ssl_root_cert,
        )
        .await?;

        let pk_indices_lookup = pk_indices.iter().copied().collect::<HashSet<_>>();

        // Rewrite schema types for serialization
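        // The Postgres-side column types are fetched from the downstream table so that
        // each datum can later be serialized with the exact type tokio-postgres expects
        // (see `convert_row_to_pg_row`).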
        let (pk_types, schema_types) = {
            let name_to_type = PostgresExternalTable::type_mapping(
                &config.user,
                &config.password,
                &config.host,
                config.port,
                &config.database,
                &config.schema,
                &config.table,
                &config.ssl_mode,
                &config.ssl_root_cert,
                is_append_only,
            )
            .await?;
            let mut schema_types = Vec::with_capacity(schema.fields.len());
            let mut pk_types = Vec::with_capacity(pk_indices.len());
            for (i, field) in schema.fields.iter().enumerate() {
                let field_name = &field.name;
                let actual_data_type = name_to_type.get(field_name).cloned().ok_or_else(|| {
                    SinkError::Config(anyhow!(
                        "Column `{}` not found in downstream Postgres table",
                        field_name
                    ))
                })?;
                if pk_indices_lookup.contains(&i) {
                    pk_types.push(actual_data_type.clone())
                }
                schema_types.push(actual_data_type);
            }
            (pk_types, schema_types)
        };

        let raw_insert_sql = create_insert_sql(&schema, &config.schema, &config.table);
        let raw_upsert_sql = create_upsert_sql(
            &schema,
            &config.schema,
            &config.table,
            &pk_indices,
            &pk_indices_lookup,
        );
        let raw_delete_sql = create_delete_sql(&schema, &config.schema, &config.table, &pk_indices);

        let insert_sql = client
            .prepare(&raw_insert_sql)
            .await
            .with_context(|| format!("failed to prepare insert statement: {}", raw_insert_sql))?;
        let upsert_sql = client
            .prepare(&raw_upsert_sql)
            .await
            .with_context(|| format!("failed to prepare upsert statement: {}", raw_upsert_sql))?;
        let delete_sql = client
            .prepare(&raw_delete_sql)
            .await
            .with_context(|| format!("failed to prepare delete statement: {}", raw_delete_sql))?;

        let writer = Self {
            is_append_only,
            client,
            pk_indices,
            pk_types,
            schema_types,
            raw_insert_sql: Arc::new(raw_insert_sql),
            raw_upsert_sql: Arc::new(raw_upsert_sql),
            raw_delete_sql: Arc::new(raw_delete_sql),
            insert_sql: Arc::new(insert_sql),
            delete_sql: Arc::new(delete_sql),
            upsert_sql: Arc::new(upsert_sql),
        };
        Ok(writer)
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        // https://www.postgresql.org/docs/current/limits.html
        // We have a limit of 65,535 parameters in a single query, as restricted by the PostgreSQL protocol.
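        // Each row below is bound to its own prepared statement, so a single execution
        // only ever carries one row's worth of parameters and stays well under that limit.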
        if self.is_append_only {
            self.write_batch_append_only(chunk).await
        } else {
            self.write_batch_non_append_only(chunk).await
        }
    }

    async fn write_batch_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        let transaction = Arc::new(self.client.transaction().await?);
        let mut insert_futures = FuturesUnordered::new();
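        // The transaction and prepared statements are wrapped in `Arc` so each per-row
        // future can hold its own handle; the futures are polled concurrently through
        // `FuturesUnordered` before the transaction commits.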
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    let pg_row = convert_row_to_pg_row(row, &self.schema_types);
                    let insert_sql = self.insert_sql.clone();
                    let raw_insert_sql = self.raw_insert_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(insert_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute insert statement: {}, parameters: {:?}",
                                    raw_insert_sql, pg_row
                                )
                            })
                    };
                    insert_futures.push(future);
                }
                _ => {
                    tracing::error!(
                        "row ignored, append-only sink should not receive update insert, update delete and delete operations"
                    );
                }
            }
        }

        while let Some(result) = insert_futures.next().await {
            result?;
        }
        if let Some(transaction) = Arc::into_inner(transaction) {
            transaction.commit().await?;
        } else {
            tracing::error!("transaction lost!");
        }

        Ok(())
    }

    async fn write_batch_non_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        let transaction = Arc::new(self.client.transaction().await?);
        let mut delete_futures = FuturesUnordered::new();
        let mut upsert_futures = FuturesUnordered::new();
        for (op, row) in chunk.rows() {
            match op {
                Op::Delete | Op::UpdateDelete => {
                    let pg_row =
                        convert_row_to_pg_row(row.project(&self.pk_indices), &self.pk_types);
                    let delete_sql = self.delete_sql.clone();
                    let raw_delete_sql = self.raw_delete_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(delete_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute delete statement: {}, parameters: {:?}",
                                    raw_delete_sql, pg_row
                                )
                            })
                    };
                    delete_futures.push(future);
                }
                Op::Insert | Op::UpdateInsert => {
                    let pg_row = convert_row_to_pg_row(row, &self.schema_types);
                    let upsert_sql = self.upsert_sql.clone();
                    let raw_upsert_sql = self.raw_upsert_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(upsert_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute upsert statement: {}, parameters: {:?}",
                                    raw_upsert_sql, pg_row
                                )
                            })
                    };
                    upsert_futures.push(future);
                }
            }
        }
        while let Some(result) = delete_futures.next().await {
            result?;
        }
        while let Some(result) = upsert_futures.next().await {
            result?;
        }
        if let Some(transaction) = Arc::into_inner(transaction) {
            transaction.commit().await?;
        } else {
            tracing::error!("transaction lost!");
        }
        Ok(())
    }
}

#[async_trait]
impl LogSinker for PostgresSinkWriter {
    async fn consume_log_and_sink(mut self, mut log_reader: impl SinkLogReader) -> Result<!> {
        log_reader.start_from(None).await?;
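        // Each chunk is flushed in its own transaction; once it succeeds, the log store
        // can be truncated up to that chunk. Barriers carry no data here and only
        // advance the truncation offset.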
        loop {
            let (epoch, item) = log_reader.next_item().await?;
            match item {
                LogStoreReadItem::StreamChunk { chunk, chunk_id } => {
                    self.write_batch(chunk).await?;
                    log_reader.truncate(TruncateOffset::Chunk { epoch, chunk_id })?;
                }
                LogStoreReadItem::Barrier { .. } => {
                    log_reader.truncate(TruncateOffset::Barrier { epoch })?;
                }
            }
        }
    }
}

fn create_insert_sql(schema: &Schema, schema_name: &str, table_name: &str) -> String {
    let normalized_table_name = format!(
        "{}.{}",
        quote_identifier(schema_name),
        quote_identifier(table_name)
    );
    let number_of_columns = schema.len();
    let columns: String = schema
        .fields()
        .iter()
        .map(|field| quote_identifier(&field.name))
        .join(", ");
    let column_parameters: String = (0..number_of_columns)
        .map(|i| format!("${}", i + 1))
        .join(", ");
    format!("INSERT INTO {normalized_table_name} ({columns}) VALUES ({column_parameters})")
}

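/// Build a `DELETE FROM <table> WHERE (<pk columns>) in ((<params>))` statement keyed
/// on the primary key columns; if no primary key is defined, all columns are used.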
fn create_delete_sql(
    schema: &Schema,
    schema_name: &str,
    table_name: &str,
    pk_indices: &[usize],
) -> String {
    let normalized_table_name = format!(
        "{}.{}",
        quote_identifier(schema_name),
        quote_identifier(table_name)
    );
    let pk_indices = if pk_indices.is_empty() {
        (0..schema.len()).collect_vec()
    } else {
        pk_indices.to_vec()
    };
    let pk = {
        let pk_symbols = pk_indices
            .iter()
            .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
            .join(", ");
        format!("({})", pk_symbols)
    };
    let parameters: String = (0..pk_indices.len())
        .map(|i| format!("${}", i + 1))
        .join(", ");
    format!("DELETE FROM {normalized_table_name} WHERE {pk} in (({parameters}))")
}

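/// Build an `INSERT ... on conflict (<pk columns>) do update set ...` statement that
/// updates every non-key column from `EXCLUDED` on conflict. Without primary key
/// indices this degenerates to the plain insert statement.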
fn create_upsert_sql(
    schema: &Schema,
    schema_name: &str,
    table_name: &str,
    pk_indices: &[usize],
    pk_indices_lookup: &HashSet<usize>,
) -> String {
    let insert_sql = create_insert_sql(schema, schema_name, table_name);
    if pk_indices.is_empty() {
        return insert_sql;
    }
    let pk_columns = pk_indices
        .iter()
        .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
        .collect_vec()
        .join(", ");
    let update_parameters: String = (0..schema.len())
        .filter(|i| !pk_indices_lookup.contains(i))
        .map(|i| {
            let column = quote_identifier(&schema.fields()[i].name);
            format!("{column} = EXCLUDED.{column}")
        })
        .collect_vec()
        .join(", ");
    format!("{insert_sql} on conflict ({pk_columns}) do update set {update_parameters}")
}

/// Quote an identifier for PostgreSQL.
fn quote_identifier(identifier: &str) -> String {
    format!("\"{}\"", identifier.replace("\"", "\"\""))
}

type PgDatum = Option<ScalarAdapter>;
type PgRow = Vec<PgDatum>;

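/// Convert a RisingWave row into Postgres-compatible values using the provided Postgres
/// column types. A datum that fails to convert is logged and written as NULL instead of
/// failing the whole batch.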
fn convert_row_to_pg_row(row: impl Row, schema_types: &[PgType]) -> PgRow {
    let mut buffer = Vec::with_capacity(row.len());
    for (i, datum_ref) in row.iter().enumerate() {
        let pg_datum = datum_ref.map(|s| {
            match ScalarAdapter::from_scalar(s, &schema_types[i]) {
                Ok(scalar) => Some(scalar),
                Err(e) => {
                    tracing::error!(error=%e.as_report(), scalar=?s, "Failed to convert scalar to pg value");
                    None
                }
            }
        });
        buffer.push(pg_datum.flatten());
    }
    buffer
}

#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use expect_test::{Expect, expect};
    use risingwave_common::catalog::Field;
    use risingwave_common::types::DataType;

    use super::*;

    fn check(actual: impl Display, expect: Expect) {
        let actual = actual.to_string();
        expect.assert_eq(&actual);
    }

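    // A small sketch of config parsing with illustrative connection values (not taken
    // from the original source): every property arrives as a string, `schema` and
    // `max_batch_rows` fall back to their defaults, and an unsupported `type` is rejected.
    #[test]
    fn test_postgres_config_from_btreemap() {
        let props = BTreeMap::from([
            ("host".to_owned(), "localhost".to_owned()),
            ("port".to_owned(), "5432".to_owned()),
            ("user".to_owned(), "postgres".to_owned()),
            ("password".to_owned(), "postgres".to_owned()),
            ("database".to_owned(), "db".to_owned()),
            ("table".to_owned(), "t".to_owned()),
            ("type".to_owned(), "upsert".to_owned()),
        ]);

        let config = PostgresConfig::from_btreemap(props.clone()).unwrap();
        assert_eq!(config.schema, "public");
        assert_eq!(config.max_batch_rows, 1024);

        let mut bad_props = props;
        bad_props.insert("type".to_owned(), "debezium".to_owned());
        assert!(PostgresConfig::from_btreemap(bad_props).is_err());
    }
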
    #[test]
    fn test_create_insert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let sql = create_insert_sql(&schema, schema_name, table_name);
        check(
            sql,
            expect![[r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2)"#]],
        );
    }

    #[test]
    fn test_create_delete_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let sql = create_delete_sql(&schema, schema_name, table_name, &[1]);
        check(
            sql,
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("b") in (($1))"#]],
        );
        let table_name = "test_table";
        let sql = create_delete_sql(&schema, schema_name, table_name, &[0, 1]);
        check(
            sql,
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("a", "b") in (($1, $2))"#]],
        );
    }

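    // Sketch of the fallback path in `create_delete_sql`: with no primary key indices
    // the predicate is keyed on every column.
    #[test]
    fn test_create_delete_sql_without_pk() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let sql = create_delete_sql(&schema, "test_schema", "test_table", &[]);
        check(
            sql,
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("a", "b") in (($1, $2))"#]],
        );
    }
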
    #[test]
    fn test_create_upsert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let pk_indices_lookup = HashSet::from_iter([1]);
        let sql = create_upsert_sql(&schema, schema_name, table_name, &[1], &pk_indices_lookup);
        check(
            sql,
            expect![[
                r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2) on conflict ("b") do update set "a" = EXCLUDED."a""#
            ]],
        );
    }
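
    // Sketch of the no-primary-key fallback in `create_upsert_sql`: without key columns
    // there is nothing to conflict on, so the statement is a plain insert.
    #[test]
    fn test_create_upsert_sql_without_pk() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let sql = create_upsert_sql(&schema, "test_schema", "test_table", &[], &HashSet::new());
        check(
            sql,
            expect![[r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2)"#]],
        );
    }

    // Illustrative check of `quote_identifier`: embedded double quotes are escaped by
    // doubling, matching PostgreSQL's rules for quoted identifiers.
    #[test]
    fn test_quote_identifier() {
        assert_eq!(quote_identifier("simple"), r#""simple""#);
        assert_eq!(quote_identifier(r#"quo"ted"#), r#""quo""ted""#);
    }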
}