risingwave_connector/sink/postgres.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashSet};
16use std::sync::Arc;
17
18use anyhow::{Context, anyhow};
19use async_trait::async_trait;
20use futures::StreamExt;
21use futures::stream::FuturesUnordered;
22use itertools::Itertools;
23use phf::phf_set;
24use risingwave_common::array::{Op, StreamChunk};
25use risingwave_common::catalog::Schema;
26use risingwave_common::row::{Row, RowExt};
27use serde::Deserialize;
28use serde_with::{DisplayFromStr, serde_as};
29use simd_json::prelude::ArrayTrait;
30use thiserror_ext::AsReport;
31use tokio_postgres::types::Type as PgType;
32
33use super::{
34    LogSinker, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkLogReader,
35};
36use crate::connector_common::{PostgresExternalTable, SslMode, create_pg_client};
37use crate::enforce_secret::EnforceSecret;
38use crate::parser::scalar_adapter::{ScalarAdapter, validate_pg_type_to_rw_type};
39use crate::sink::log_store::{LogStoreReadItem, TruncateOffset};
40use crate::sink::{Result, Sink, SinkParam, SinkWriterParam};
41
/// Connector name users specify to select this sink (`connector = 'postgres'`).
pub const POSTGRES_SINK: &str = "postgres";
43
/// User-facing configuration of the Postgres sink, deserialized from the
/// `WITH (...)` properties of `CREATE SINK`.
#[serde_as]
#[derive(Clone, Debug, Deserialize)]
pub struct PostgresConfig {
    /// Hostname or IP address of the downstream Postgres server.
    pub host: String,
    #[serde_as(as = "DisplayFromStr")]
    pub port: u16,
    pub user: String,
    pub password: String,
    /// Target database name.
    pub database: String,
    /// Target table name (unqualified; qualified by `schema` below).
    pub table: String,
    /// Target schema; defaults to Postgres' `public` schema.
    #[serde(default = "default_schema")]
    pub schema: String,
    /// SSL negotiation mode; `SslMode::default()` when unspecified.
    #[serde(default = "Default::default")]
    pub ssl_mode: SslMode,
    /// Optional root certificate used for SSL verification.
    #[serde(rename = "ssl.root.cert")]
    pub ssl_root_cert: Option<String>,
    /// Maximum rows per batch; defaults to 1024.
    // NOTE(review): not read anywhere in this file — the writer batches per
    // incoming chunk; confirm whether this option is consumed elsewhere.
    #[serde(default = "default_max_batch_rows")]
    #[serde_as(as = "DisplayFromStr")]
    pub max_batch_rows: usize,
    pub r#type: String, // accept "append-only" or "upsert"
    /// Whether to enable TCP keepalive probes on the sink connection.
    #[serde(default, rename = "tcp.keepalive.enable")]
    #[serde_as(as = "DisplayFromStr")]
    pub tcp_keepalive_enable: bool,

    /// Keepalive tuning; only consulted when `tcp_keepalive_enable` is set.
    #[serde(flatten)]
    pub tcp_keepalive: Option<TcpKeepaliveConfig>,
}
71
/// TCP keepalive tuning knobs, flattened into [`PostgresConfig`].
#[serde_as]
#[derive(Debug, Clone, Deserialize)]
pub struct TcpKeepaliveConfig {
    /// Seconds of idle time before the first keepalive probe is sent.
    #[serde(rename = "tcp.keepalive.idle")]
    #[serde_as(as = "DisplayFromStr")]
    pub tcp_keepalive_idle: u32,
    /// Seconds between successive keepalive probes.
    #[serde(rename = "tcp.keepalive.interval")]
    #[serde_as(as = "DisplayFromStr")]
    pub tcp_keepalive_interval: u32,
    /// Number of unanswered probes before the connection is considered dead.
    #[serde(rename = "tcp.keepalive.count")]
    #[serde_as(as = "DisplayFromStr")]
    pub tcp_keepalive_count: u32,
}
85
86impl Default for TcpKeepaliveConfig {
87    fn default() -> Self {
88        Self {
89            tcp_keepalive_idle: 10 * 60, // 10 minutes,
90            tcp_keepalive_interval: 10,
91            tcp_keepalive_count: 3,
92        }
93    }
94}
95
impl EnforceSecret for PostgresConfig {
    /// Properties that must be supplied via secret references rather than
    /// plaintext when secret enforcement is enabled.
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf_set! {
        "password", "ssl.root.cert"
    };
}
101
/// Serde default for `PostgresConfig::max_batch_rows`.
fn default_max_batch_rows() -> usize {
    1024
}
105
/// Serde default for `PostgresConfig::schema`: Postgres' default schema.
fn default_schema() -> String {
    String::from("public")
}
109
110impl PostgresConfig {
111    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
112        let config =
113            serde_json::from_value::<PostgresConfig>(serde_json::to_value(properties).unwrap())
114                .map_err(|e| SinkError::Config(anyhow!(e)))?;
115        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
116            return Err(SinkError::Config(anyhow!(
117                "`{}` must be {}, or {}",
118                SINK_TYPE_OPTION,
119                SINK_TYPE_APPEND_ONLY,
120                SINK_TYPE_UPSERT
121            )));
122        }
123        Ok(config)
124    }
125}
126
/// Postgres sink: validates the configuration against the downstream table
/// and constructs the log sinker that performs the actual writes.
#[derive(Debug)]
pub struct PostgresSink {
    pub config: PostgresConfig,
    // Sink schema as declared on the RisingWave side.
    schema: Schema,
    // Indices into `schema` of the downstream primary-key columns.
    pk_indices: Vec<usize>,
    // True for `type = 'append-only'`; otherwise upsert semantics apply.
    is_append_only: bool,
}
134
135impl PostgresSink {
136    pub fn new(
137        config: PostgresConfig,
138        schema: Schema,
139        pk_indices: Vec<usize>,
140        is_append_only: bool,
141    ) -> Result<Self> {
142        Ok(Self {
143            config,
144            schema,
145            pk_indices,
146            is_append_only,
147        })
148    }
149}
150
151impl EnforceSecret for PostgresSink {
152    fn enforce_secret<'a>(
153        prop_iter: impl Iterator<Item = &'a str>,
154    ) -> crate::error::ConnectorResult<()> {
155        for prop in prop_iter {
156            PostgresConfig::enforce_one(prop)?;
157        }
158        Ok(())
159    }
160}
161
162impl TryFrom<SinkParam> for PostgresSink {
163    type Error = SinkError;
164
165    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
166        let schema = param.schema();
167        let pk_indices = param.downstream_pk_or_empty();
168        let config = PostgresConfig::from_btreemap(param.properties)?;
169        PostgresSink::new(config, schema, pk_indices, param.sink_type.is_append_only())
170    }
171}
172
impl Sink for PostgresSink {
    type LogSinker = PostgresSinkWriter;

    const SINK_NAME: &'static str = POSTGRES_SINK;

    /// Validate the sink before creation: an upsert sink must declare a
    /// primary key, and the declared schema must be compatible with the live
    /// Postgres table (column names/types and pk alignment).
    async fn validate(&self) -> Result<()> {
        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert Postgres sink (please define in `primary_key` field)"
            )));
        }

        // Verify our sink schema is compatible with Postgres
        {
            let pg_table = PostgresExternalTable::connect(
                &self.config.user,
                &self.config.password,
                &self.config.host,
                self.config.port,
                &self.config.database,
                &self.config.schema,
                &self.config.table,
                &self.config.ssl_mode,
                &self.config.ssl_root_cert,
                self.is_append_only,
            )
            .await
            .context(format!(
                "failed to connect to database: {}, schema: {}, table: {}",
                &self.config.database, &self.config.schema, &self.config.table
            ))?;

            // Check that names and types match, order of columns doesn't matter.
            // The sink may cover a subset of the table's columns, but every sink
            // column must exist downstream with a compatible type.
            {
                let pg_columns = pg_table.column_descs();
                let sink_columns = self.schema.fields();
                if pg_columns.len() < sink_columns.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Column count mismatch: Postgres table has {} columns, but sink schema has {} columns, sink should have less or equal columns to the Postgres table",
                        pg_columns.len(),
                        sink_columns.len()
                    )));
                }

                // Index Postgres columns by name for order-independent lookup.
                let pg_columns_lookup = pg_columns
                    .iter()
                    .map(|c| (c.name.clone(), c.data_type.clone()))
                    .collect::<BTreeMap<_, _>>();
                for sink_column in sink_columns {
                    let pg_column = pg_columns_lookup.get(&sink_column.name);
                    match pg_column {
                        None => {
                            return Err(SinkError::Config(anyhow!(
                                "Column `{}` not found in Postgres table `{}`",
                                sink_column.name,
                                self.config.table
                            )));
                        }
                        Some(pg_column) => {
                            if !validate_pg_type_to_rw_type(pg_column, &sink_column.data_type()) {
                                return Err(SinkError::Config(anyhow!(
                                    "Column `{}` in Postgres table `{}` has type `{}`, but sink schema defines it as type `{}`",
                                    sink_column.name,
                                    self.config.table,
                                    pg_column,
                                    sink_column.data_type()
                                )));
                            }
                        }
                    }
                }
            }

            // check that pk matches
            // Compared as a set: the pk columns must coincide exactly, though
            // their order may differ between the two sides.
            {
                let pg_pk_names = pg_table.pk_names();
                let sink_pk_names = self
                    .pk_indices
                    .iter()
                    .map(|i| &self.schema.fields()[*i].name)
                    .collect::<HashSet<_>>();
                if pg_pk_names.len() != sink_pk_names.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Primary key mismatch: Postgres table has primary key on columns {:?}, but sink schema defines primary key on columns {:?}",
                        pg_pk_names,
                        sink_pk_names
                    )));
                }
                for name in pg_pk_names {
                    if !sink_pk_names.contains(name) {
                        return Err(SinkError::Config(anyhow!(
                            "Primary key mismatch: Postgres table has primary key on column `{}`, but sink schema does not define it as a primary key",
                            name
                        )));
                    }
                }
            }
        }

        Ok(())
    }

    /// Create the writer that consumes the log and writes to Postgres.
    async fn new_log_sinker(&self, _writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        PostgresSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await
    }
}
285
/// Writer that streams chunks into Postgres via prepared statements, one
/// transaction per chunk.
pub struct PostgresSinkWriter {
    // True for append-only sinks: only `Insert` ops are executed.
    is_append_only: bool,
    client: tokio_postgres::Client,
    // Indices of pk columns within the sink schema.
    pk_indices: Vec<usize>,
    // Postgres types of the pk columns, in `pk_indices` order.
    pk_types: Vec<PgType>,
    // Postgres types of all sink columns, in schema order.
    schema_types: Vec<PgType>,
    // Raw SQL text, kept only for error reporting; `Arc` so per-row futures
    // can carry a cheap clone into their error context.
    raw_insert_sql: Arc<String>,
    raw_upsert_sql: Arc<String>,
    raw_delete_sql: Arc<String>,
    // Statements prepared once at construction and reused for every row.
    insert_sql: Arc<tokio_postgres::Statement>,
    delete_sql: Arc<tokio_postgres::Statement>,
    upsert_sql: Arc<tokio_postgres::Statement>,
}
299
300impl PostgresSinkWriter {
301    async fn new(
302        config: PostgresConfig,
303        schema: Schema,
304        pk_indices: Vec<usize>,
305        is_append_only: bool,
306    ) -> Result<Self> {
307        let tcp_keepalive = if config.tcp_keepalive_enable {
308            config
309                .tcp_keepalive
310                .or_else(|| Some(TcpKeepaliveConfig::default()))
311        } else {
312            None
313        };
314
315        let client = create_pg_client(
316            &config.user,
317            &config.password,
318            &config.host,
319            &config.port.to_string(),
320            &config.database,
321            &config.ssl_mode,
322            &config.ssl_root_cert,
323            tcp_keepalive,
324        )
325        .await?;
326
327        let pk_indices_lookup = pk_indices.iter().copied().collect::<HashSet<_>>();
328
329        // Rewrite schema types for serialization
330        let (pk_types, schema_types) = {
331            let name_to_type = PostgresExternalTable::type_mapping(
332                &config.user,
333                &config.password,
334                &config.host,
335                config.port,
336                &config.database,
337                &config.schema,
338                &config.table,
339                &config.ssl_mode,
340                &config.ssl_root_cert,
341                is_append_only,
342            )
343            .await?;
344            let mut schema_types = Vec::with_capacity(schema.fields.len());
345            let mut pk_types = Vec::with_capacity(pk_indices.len());
346            for (i, field) in schema.fields.iter().enumerate() {
347                let field_name = &field.name;
348                let actual_data_type = name_to_type.get(field_name).map(|t| (*t).clone());
349                let actual_data_type = actual_data_type
350                    .ok_or_else(|| {
351                        SinkError::Config(anyhow!(
352                            "Column `{}` not found in sink schema",
353                            field_name
354                        ))
355                    })?
356                    .clone();
357                if pk_indices_lookup.contains(&i) {
358                    pk_types.push(actual_data_type.clone())
359                }
360                schema_types.push(actual_data_type);
361            }
362            (pk_types, schema_types)
363        };
364
365        let raw_insert_sql = create_insert_sql(&schema, &config.schema, &config.table);
366        let raw_upsert_sql = create_upsert_sql(
367            &schema,
368            &config.schema,
369            &config.table,
370            &pk_indices,
371            &pk_indices_lookup,
372        );
373        let raw_delete_sql = create_delete_sql(&schema, &config.schema, &config.table, &pk_indices);
374
375        let insert_sql = client
376            .prepare(&raw_insert_sql)
377            .await
378            .with_context(|| format!("failed to prepare insert statement: {}", raw_insert_sql))?;
379        let upsert_sql = client
380            .prepare(&raw_upsert_sql)
381            .await
382            .with_context(|| format!("failed to prepare upsert statement: {}", raw_upsert_sql))?;
383        let delete_sql = client
384            .prepare(&raw_delete_sql)
385            .await
386            .with_context(|| format!("failed to prepare delete statement: {}", raw_delete_sql))?;
387
388        let writer = Self {
389            is_append_only,
390            client,
391            pk_indices,
392            pk_types,
393            schema_types,
394            raw_insert_sql: Arc::new(raw_insert_sql),
395            raw_upsert_sql: Arc::new(raw_upsert_sql),
396            raw_delete_sql: Arc::new(raw_delete_sql),
397            insert_sql: Arc::new(insert_sql),
398            delete_sql: Arc::new(delete_sql),
399            upsert_sql: Arc::new(upsert_sql),
400        };
401        Ok(writer)
402    }
403
404    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
405        // https://www.postgresql.org/docs/current/limits.html
406        // We have a limit of 65,535 parameters in a single query, as restricted by the PostgreSQL protocol.
407        if self.is_append_only {
408            self.write_batch_append_only(chunk).await
409        } else {
410            self.write_batch_non_append_only(chunk).await
411        }
412    }
413
414    async fn write_batch_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
415        let transaction = Arc::new(self.client.transaction().await?);
416        let mut insert_futures = FuturesUnordered::new();
417        for (op, row) in chunk.rows() {
418            match op {
419                Op::Insert => {
420                    let pg_row = convert_row_to_pg_row(row, &self.schema_types);
421                    let insert_sql = self.insert_sql.clone();
422                    let raw_insert_sql = self.raw_insert_sql.clone();
423                    let transaction = transaction.clone();
424                    let future = async move {
425                        transaction
426                            .execute_raw(insert_sql.as_ref(), &pg_row)
427                            .await
428                            .with_context(|| {
429                                format!(
430                                    "failed to execute insert statement: {}, parameters: {:?}",
431                                    raw_insert_sql, pg_row
432                                )
433                            })
434                    };
435                    insert_futures.push(future);
436                }
437                _ => {
438                    tracing::error!(
439                        "row ignored, append-only sink should not receive update insert, update delete and delete operations"
440                    );
441                }
442            }
443        }
444
445        while let Some(result) = insert_futures.next().await {
446            result?;
447        }
448        if let Some(transaction) = Arc::into_inner(transaction) {
449            transaction.commit().await?;
450        } else {
451            tracing::error!("transaction lost!");
452        }
453
454        Ok(())
455    }
456
457    async fn write_batch_non_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
458        let transaction = Arc::new(self.client.transaction().await?);
459        let mut delete_futures = FuturesUnordered::new();
460        let mut upsert_futures = FuturesUnordered::new();
461        for (op, row) in chunk.rows() {
462            match op {
463                Op::Delete | Op::UpdateDelete => {
464                    let pg_row =
465                        convert_row_to_pg_row(row.project(&self.pk_indices), &self.pk_types);
466                    let delete_sql = self.delete_sql.clone();
467                    let raw_delete_sql = self.raw_delete_sql.clone();
468                    let transaction = transaction.clone();
469                    let future = async move {
470                        transaction
471                            .execute_raw(delete_sql.as_ref(), &pg_row)
472                            .await
473                            .with_context(|| {
474                                format!(
475                                    "failed to execute delete statement: {}, parameters: {:?}",
476                                    raw_delete_sql, pg_row
477                                )
478                            })
479                    };
480                    delete_futures.push(future);
481                }
482                Op::Insert | Op::UpdateInsert => {
483                    let pg_row = convert_row_to_pg_row(row, &self.schema_types);
484                    let upsert_sql = self.upsert_sql.clone();
485                    let raw_upsert_sql = self.raw_upsert_sql.clone();
486                    let transaction = transaction.clone();
487                    let future = async move {
488                        transaction
489                            .execute_raw(upsert_sql.as_ref(), &pg_row)
490                            .await
491                            .with_context(|| {
492                                format!(
493                                    "failed to execute upsert statement: {}, parameters: {:?}",
494                                    raw_upsert_sql, pg_row
495                                )
496                            })
497                    };
498                    upsert_futures.push(future);
499                }
500            }
501        }
502        while let Some(result) = delete_futures.next().await {
503            result?;
504        }
505        while let Some(result) = upsert_futures.next().await {
506            result?;
507        }
508        if let Some(transaction) = Arc::into_inner(transaction) {
509            transaction.commit().await?;
510        } else {
511            tracing::error!("transaction lost!");
512        }
513        Ok(())
514    }
515}
516
#[async_trait]
impl LogSinker for PostgresSinkWriter {
    /// Drain the log reader forever: write each chunk to Postgres, then
    /// truncate the log up to the consumed chunk/barrier so the log store can
    /// reclaim it. Returns only by propagating an error (never type `!`).
    async fn consume_log_and_sink(mut self, mut log_reader: impl SinkLogReader) -> Result<!> {
        log_reader.start_from(None).await?;
        loop {
            let (epoch, item) = log_reader.next_item().await?;
            match item {
                LogStoreReadItem::StreamChunk { chunk, chunk_id } => {
                    self.write_batch(chunk).await?;
                    log_reader.truncate(TruncateOffset::Chunk { epoch, chunk_id })?;
                }
                LogStoreReadItem::Barrier { .. } => {
                    log_reader.truncate(TruncateOffset::Barrier { epoch })?;
                }
            }
        }
    }
}
535
536fn create_insert_sql(schema: &Schema, schema_name: &str, table_name: &str) -> String {
537    let normalized_table_name = format!(
538        "{}.{}",
539        quote_identifier(schema_name),
540        quote_identifier(table_name)
541    );
542    let number_of_columns = schema.len();
543    let columns: String = schema
544        .fields()
545        .iter()
546        .map(|field| quote_identifier(&field.name))
547        .join(", ");
548    let column_parameters: String = (0..number_of_columns)
549        .map(|i| format!("${}", i + 1))
550        .join(", ");
551    format!("INSERT INTO {normalized_table_name} ({columns}) VALUES ({column_parameters})")
552}
553
554fn create_delete_sql(
555    schema: &Schema,
556    schema_name: &str,
557    table_name: &str,
558    pk_indices: &[usize],
559) -> String {
560    let normalized_table_name = format!(
561        "{}.{}",
562        quote_identifier(schema_name),
563        quote_identifier(table_name)
564    );
565    let pk_indices = if pk_indices.is_empty() {
566        (0..schema.len()).collect_vec()
567    } else {
568        pk_indices.to_vec()
569    };
570    let pk = {
571        let pk_symbols = pk_indices
572            .iter()
573            .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
574            .join(", ");
575        format!("({})", pk_symbols)
576    };
577    let parameters: String = (0..pk_indices.len())
578        .map(|i| format!("${}", i + 1))
579        .join(", ");
580    format!("DELETE FROM {normalized_table_name} WHERE {pk} in (({parameters}))")
581}
582
583fn create_upsert_sql(
584    schema: &Schema,
585    schema_name: &str,
586    table_name: &str,
587    pk_indices: &[usize],
588    pk_indices_lookup: &HashSet<usize>,
589) -> String {
590    let insert_sql = create_insert_sql(schema, schema_name, table_name);
591    if pk_indices.is_empty() {
592        return insert_sql;
593    }
594    let pk_columns = pk_indices
595        .iter()
596        .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
597        .collect_vec()
598        .join(", ");
599    let update_parameters: String = (0..schema.len())
600        .filter(|i| !pk_indices_lookup.contains(i))
601        .map(|i| {
602            let column = quote_identifier(&schema.fields()[i].name);
603            format!("{column} = EXCLUDED.{column}")
604        })
605        .collect_vec()
606        .join(", ");
607    format!("{insert_sql} on conflict ({pk_columns}) do update set {update_parameters}")
608}
609
/// Quote an identifier for PostgreSQL.
///
/// Wraps the name in double quotes and doubles any embedded quote character,
/// per SQL identifier-quoting rules.
fn quote_identifier(identifier: &str) -> String {
    let escaped = identifier.replace('"', "\"\"");
    format!("\"{escaped}\"")
}
614
/// A single Postgres-bound parameter value; `None` maps to SQL NULL.
type PgDatum = Option<ScalarAdapter>;
/// One row of Postgres-bound parameters.
type PgRow = Vec<PgDatum>;
617
618fn convert_row_to_pg_row(row: impl Row, schema_types: &[PgType]) -> PgRow {
619    let mut buffer = Vec::with_capacity(row.len());
620    for (i, datum_ref) in row.iter().enumerate() {
621        let pg_datum = datum_ref.map(|s| {
622            match ScalarAdapter::from_scalar(s, &schema_types[i]) {
623                Ok(scalar) => Some(scalar),
624                Err(e) => {
625                    tracing::error!(error=%e.as_report(), scalar=?s, "Failed to convert scalar to pg value");
626                    None
627                }
628            }
629        });
630        buffer.push(pg_datum.flatten());
631    }
632    buffer
633}
634
#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use expect_test::{Expect, expect};
    use risingwave_common::catalog::Field;
    use risingwave_common::types::DataType;

    use super::*;

    /// Snapshot-compare `actual` against an `expect!` literal.
    fn check(actual: impl Display, expect: Expect) {
        let actual = actual.to_string();
        expect.assert_eq(&actual);
    }

    #[test]
    fn test_create_insert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let sql = create_insert_sql(&schema, schema_name, table_name);
        check(
            sql,
            expect![[r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2)"#]],
        );
    }

    #[test]
    fn test_create_delete_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        // Single-column pk.
        let sql = create_delete_sql(&schema, schema_name, table_name, &[1]);
        check(
            sql,
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("b") in (($1))"#]],
        );
        // Composite pk.
        let table_name = "test_table";
        let sql = create_delete_sql(&schema, schema_name, table_name, &[0, 1]);
        check(
            sql,
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("a", "b") in (($1, $2))"#]],
        );
    }

    #[test]
    fn test_create_upsert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        // Pk on "b" only; "a" is updated on conflict.
        let pk_indices_lookup = HashSet::from_iter([1]);
        let sql = create_upsert_sql(&schema, schema_name, table_name, &[1], &pk_indices_lookup);
        check(
            sql,
            expect![[
                r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2) on conflict ("b") do update set "a" = EXCLUDED."a""#
            ]],
        );
    }
}