risingwave_connector/sink/postgres.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashSet};

use anyhow::anyhow;
use async_trait::async_trait;
use itertools::Itertools;
use risingwave_common::array::{Op, StreamChunk};
use risingwave_common::bail;
use risingwave_common::catalog::Schema;
use risingwave_common::row::{Row, RowExt};
use serde_derive::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use simd_json::prelude::ArrayTrait;
use thiserror_ext::AsReport;
use tokio_postgres::types::Type as PgType;

use super::{
    LogSinker, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkLogReader,
};
use crate::connector_common::{PostgresExternalTable, SslMode, create_pg_client};
use crate::parser::scalar_adapter::{ScalarAdapter, validate_pg_type_to_rw_type};
use crate::sink::log_store::{LogStoreReadItem, TruncateOffset};
use crate::sink::{DummySinkCommitCoordinator, Result, Sink, SinkParam, SinkWriterParam};

pub const POSTGRES_SINK: &str = "postgres";

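/// Configuration for the Postgres sink, deserialized from the user-supplied
/// sink properties. `schema` defaults to `public`, `max_batch_rows` to 1024,
/// and `type` must be either `append-only` or `upsert`.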
#[serde_as]
#[derive(Clone, Debug, Deserialize)]
pub struct PostgresConfig {
    pub host: String,
    #[serde_as(as = "DisplayFromStr")]
    pub port: u16,
    pub user: String,
    pub password: String,
    pub database: String,
    pub table: String,
    #[serde(default = "default_schema")]
    pub schema: String,
    #[serde(default = "Default::default")]
    pub ssl_mode: SslMode,
    #[serde(rename = "ssl.root.cert")]
    pub ssl_root_cert: Option<String>,
    #[serde(default = "default_max_batch_rows")]
    #[serde_as(as = "DisplayFromStr")]
    pub max_batch_rows: usize,
    pub r#type: String, // accepts "append-only" or "upsert"
}

fn default_max_batch_rows() -> usize {
    1024
}

fn default_schema() -> String {
    "public".to_owned()
}

impl PostgresConfig {
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<PostgresConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }
}

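/// A Postgres sink validated against the downstream table: `validate` checks
/// that column names/types and the primary key match before data is written.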
#[derive(Debug)]
pub struct PostgresSink {
    pub config: PostgresConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl PostgresSink {
    pub fn new(
        config: PostgresConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl TryFrom<SinkParam> for PostgresSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = PostgresConfig::from_btreemap(param.properties)?;
        PostgresSink::new(
            config,
            schema,
            param.downstream_pk,
            param.sink_type.is_append_only(),
        )
    }
}

impl Sink for PostgresSink {
    type Coordinator = DummySinkCommitCoordinator;
    type LogSinker = PostgresSinkWriter;

    const SINK_NAME: &'static str = POSTGRES_SINK;

    async fn validate(&self) -> Result<()> {
        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert Postgres sink (please define it in the `primary_key` field)"
            )));
        }

        // Verify our sink schema is compatible with Postgres.
        {
            let pg_table = PostgresExternalTable::connect(
                &self.config.user,
                &self.config.password,
                &self.config.host,
                self.config.port,
                &self.config.database,
                &self.config.schema,
                &self.config.table,
                &self.config.ssl_mode,
                &self.config.ssl_root_cert,
                self.is_append_only,
            )
            .await?;

            // Check that names and types match; the order of columns doesn't matter.
            {
                let pg_columns = pg_table.column_descs();
                let sink_columns = self.schema.fields();
                if pg_columns.len() < sink_columns.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Column count mismatch: Postgres table has {} columns, but the sink schema has {} columns; the sink must not have more columns than the Postgres table",
                        pg_columns.len(),
                        sink_columns.len()
                    )));
                }

                let pg_columns_lookup = pg_columns
                    .iter()
                    .map(|c| (c.name.clone(), c.data_type.clone()))
                    .collect::<BTreeMap<_, _>>();
                for sink_column in sink_columns {
                    let pg_column = pg_columns_lookup.get(&sink_column.name);
                    match pg_column {
                        None => {
                            return Err(SinkError::Config(anyhow!(
                                "Column `{}` not found in Postgres table `{}`",
                                sink_column.name,
                                self.config.table
                            )));
                        }
                        Some(pg_column) => {
                            if !validate_pg_type_to_rw_type(pg_column, &sink_column.data_type()) {
                                return Err(SinkError::Config(anyhow!(
                                    "Column `{}` in Postgres table `{}` has type `{}`, but sink schema defines it as type `{}`",
                                    sink_column.name,
                                    self.config.table,
                                    pg_column,
                                    sink_column.data_type()
                                )));
                            }
                        }
                    }
                }
            }

            // Check that the primary keys match.
            {
                let pg_pk_names = pg_table.pk_names();
                let sink_pk_names = self
                    .pk_indices
                    .iter()
                    .map(|i| &self.schema.fields()[*i].name)
                    .collect::<HashSet<_>>();
                if pg_pk_names.len() != sink_pk_names.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Primary key mismatch: Postgres table has primary key on columns {:?}, but sink schema defines primary key on columns {:?}",
                        pg_pk_names,
                        sink_pk_names
                    )));
                }
                for name in pg_pk_names {
                    if !sink_pk_names.contains(name) {
                        return Err(SinkError::Config(anyhow!(
                            "Primary key mismatch: Postgres table has primary key on column `{}`, but sink schema does not define it as a primary key",
                            name
                        )));
                    }
                }
            }
        }

        Ok(())
    }

    async fn new_log_sinker(&self, _writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        PostgresSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await
    }
}

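/// Accumulates flattened bind parameters and splits them into multiple buffers
/// so that no single prepared statement exceeds `MAX_PARAMETERS` parameters.
/// For example, with 4 columns per row, each buffer holds at most
/// 32768 / 4 = 8192 rows.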
struct ParameterBuffer<'a> {
    /// Completed sets of parameters to be inserted/deleted.
    /// Each set is a flattened 2d-array of rows x columns.
    parameters: Vec<Vec<Option<ScalarAdapter>>>,
    /// The column dimension (fixed).
    column_length: usize,
    /// Schema types used to serialize datums into `ScalarAdapter`s.
    schema_types: &'a [PgType],
    /// Estimated number of parameters that can be sent in a single query.
    estimated_parameter_size: usize,
    /// The current parameter buffer being filled.
    current_parameter_buffer: Vec<Option<ScalarAdapter>>,
}

impl<'a> ParameterBuffer<'a> {
    /// The maximum number of parameters that can be sent in a single query.
    /// See: <https://www.postgresql.org/docs/current/limits.html>
    /// and <https://github.com/sfackler/rust-postgres/issues/356>
    const MAX_PARAMETERS: usize = 32768;

    /// `flattened_chunk_size` is the number of datums in a single chunk.
    fn new(schema_types: &'a [PgType], flattened_chunk_size: usize) -> Self {
        let estimated_parameter_size = usize::min(Self::MAX_PARAMETERS, flattened_chunk_size);
        Self {
            parameters: vec![],
            column_length: schema_types.len(),
            schema_types,
            estimated_parameter_size,
            current_parameter_buffer: Vec::with_capacity(estimated_parameter_size),
        }
    }

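    /// Appends one row's datums to the current buffer, first flushing it into
    /// `parameters` if the extra row would push it past `MAX_PARAMETERS`.
    /// Datums that fail to convert are logged and bound as NULL.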
    fn add_row(&mut self, row: impl Row) {
        if self.current_parameter_buffer.len() + self.column_length >= Self::MAX_PARAMETERS {
            self.new_buffer();
        }
        for (i, datum_ref) in row.iter().enumerate() {
            let pg_datum = datum_ref.map(|s| {
                let ty = &self.schema_types[i];
                match ScalarAdapter::from_scalar(s, ty) {
                    Ok(scalar) => Some(scalar),
                    Err(e) => {
                        tracing::error!(error=%e.as_report(), scalar=?s, "Failed to convert scalar to pg value");
                        None
                    }
                }
            });
            self.current_parameter_buffer.push(pg_datum.flatten());
        }
    }

    fn new_buffer(&mut self) {
        let filled_buffer = std::mem::replace(
            &mut self.current_parameter_buffer,
            Vec::with_capacity(self.estimated_parameter_size),
        );
        self.parameters.push(filled_buffer);
    }

    fn into_parts(self) -> (Vec<Vec<Option<ScalarAdapter>>>, Vec<Option<ScalarAdapter>>) {
        (self.parameters, self.current_parameter_buffer)
    }
}

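/// Writes stream chunks to Postgres over a `tokio_postgres` client, batching
/// rows into multi-row `INSERT`/`DELETE` statements inside a transaction.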
pub struct PostgresSinkWriter {
    config: PostgresConfig,
    pk_indices: Vec<usize>,
    is_append_only: bool,
    client: tokio_postgres::Client,
    schema_types: Vec<PgType>,
    schema: Schema,
}

impl PostgresSinkWriter {
    async fn new(
        config: PostgresConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let client = create_pg_client(
            &config.user,
            &config.password,
            &config.host,
            &config.port.to_string(),
            &config.database,
            &config.ssl_mode,
            &config.ssl_root_cert,
        )
        .await?;

        // Look up the actual Postgres column types so datums can be serialized
        // with the correct `PgType`.
        let schema_types = {
            let name_to_type = PostgresExternalTable::type_mapping(
                &config.user,
                &config.password,
                &config.host,
                config.port,
                &config.database,
                &config.schema,
                &config.table,
                &config.ssl_mode,
                &config.ssl_root_cert,
                is_append_only,
            )
            .await?;
            let mut schema_types = Vec::with_capacity(schema.fields.len());
            for field in &schema.fields {
                let field_name = &field.name;
                let actual_data_type =
                    name_to_type.get(field_name).cloned().ok_or_else(|| {
                        SinkError::Config(anyhow!(
                            "Column `{}` not found in Postgres table `{}`",
                            field_name,
                            config.table
                        ))
                    })?;
                schema_types.push(actual_data_type);
            }
            schema_types
        };

        let writer = Self {
            config,
            pk_indices,
            is_append_only,
            client,
            schema_types,
            schema,
        };
        Ok(writer)
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        // https://www.postgresql.org/docs/current/limits.html
        // The PostgreSQL protocol limits a single query to 65,535 bind parameters.
        if self.is_append_only {
            self.write_batch_append_only(chunk).await
        } else {
            self.write_batch_non_append_only(chunk).await
        }
    }

    async fn write_batch_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        let mut transaction = self.client.transaction().await?;
        // 1d flattened array of parameters to be inserted.
        let mut parameter_buffer = ParameterBuffer::new(
            &self.schema_types,
            chunk.cardinality() * chunk.data_types().len(),
        );
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    parameter_buffer.add_row(row);
                }
                Op::UpdateInsert | Op::Delete | Op::UpdateDelete => {
                    bail!(
                        "append-only sink should not receive update insert, update delete, or delete operations"
                    )
                }
            }
        }
        let (parameters, remaining) = parameter_buffer.into_parts();
        Self::execute_parameter(
            Op::Insert,
            &mut transaction,
            &self.schema,
            &self.config.table,
            &self.pk_indices,
            parameters,
            remaining,
        )
        .await?;
        transaction.commit().await?;

        Ok(())
    }

    async fn write_batch_non_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        let mut transaction = self.client.transaction().await?;
        // 1d flattened array of parameters to be inserted.
        let mut insert_parameter_buffer = ParameterBuffer::new(
            &self.schema_types,
            chunk.cardinality() * chunk.data_types().len(),
        );
        // 1d flattened array of parameters to be deleted.
        let mut delete_parameter_buffer = ParameterBuffer::new(
            &self.schema_types,
            chunk.cardinality() * self.pk_indices.len(),
        );
        for (op, row) in chunk.rows() {
            match op {
                Op::UpdateInsert | Op::Insert => {
                    insert_parameter_buffer.add_row(row);
                }
                Op::UpdateDelete | Op::Delete => {
                    delete_parameter_buffer.add_row(row.project(&self.pk_indices));
                }
            }
        }

        let (delete_parameters, delete_remaining_parameter) = delete_parameter_buffer.into_parts();
        Self::execute_parameter(
            Op::Delete,
            &mut transaction,
            &self.schema,
            &self.config.table,
            &self.pk_indices,
            delete_parameters,
            delete_remaining_parameter,
        )
        .await?;
        let (insert_parameters, insert_remaining_parameter) = insert_parameter_buffer.into_parts();
        Self::execute_parameter(
            Op::Insert,
            &mut transaction,
            &self.schema,
            &self.config.table,
            &self.pk_indices,
            insert_parameters,
            insert_remaining_parameter,
        )
        .await?;
        transaction.commit().await?;

        Ok(())
    }

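    /// Executes one batched statement per parameter buffer: the full buffers in
    /// `parameters` all bind the same number of rows and share one prepared
    /// statement; the partially filled `remaining_parameter` tail gets its own.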
    async fn execute_parameter(
        op: Op,
        transaction: &mut tokio_postgres::Transaction<'_>,
        schema: &Schema,
        table_name: &str,
        pk_indices: &[usize],
        parameters: Vec<Vec<Option<ScalarAdapter>>>,
        remaining_parameter: Vec<Option<ScalarAdapter>>,
    ) -> Result<()> {
        let column_length = match op {
            Op::Insert => schema.len(),
            Op::Delete => pk_indices.len(),
            _ => unreachable!(),
        };
        if !parameters.is_empty() {
            let parameter_length = parameters[0].len();
            let rows_length = parameter_length / column_length;
            assert_eq!(
                parameter_length % column_length,
                0,
                "flattened parameters are unaligned, parameters={:#?} columns={:#?}",
                parameters,
                schema.fields(),
            );
            let statement = match op {
                Op::Insert => create_insert_sql(schema, table_name, rows_length),
                Op::Delete => create_delete_sql(schema, table_name, pk_indices, rows_length),
                _ => unreachable!(),
            };
            let statement = transaction.prepare(&statement).await?;
            for parameter in parameters {
                transaction.execute_raw(&statement, parameter).await?;
            }
        }
        if !remaining_parameter.is_empty() {
            let rows_length = remaining_parameter.len() / column_length;
            assert_eq!(
                remaining_parameter.len() % column_length,
                0,
                "flattened parameters are unaligned"
            );
            let statement = match op {
                Op::Insert => create_insert_sql(schema, table_name, rows_length),
                Op::Delete => create_delete_sql(schema, table_name, pk_indices, rows_length),
                _ => unreachable!(),
            };
            tracing::trace!("binding statement: {:?}", statement);
            let statement = transaction.prepare(&statement).await?;
            tracing::trace!("binding parameters: {:?}", remaining_parameter);
            transaction
                .execute_raw(&statement, remaining_parameter)
                .await?;
        }
        Ok(())
    }
}

#[async_trait]
impl LogSinker for PostgresSinkWriter {
    async fn consume_log_and_sink(mut self, mut log_reader: impl SinkLogReader) -> Result<!> {
        log_reader.start_from(None).await?;
        loop {
            let (epoch, item) = log_reader.next_item().await?;
            match item {
                LogStoreReadItem::StreamChunk { chunk, chunk_id } => {
                    self.write_batch(chunk).await?;
                    log_reader.truncate(TruncateOffset::Chunk { epoch, chunk_id })?;
                }
                LogStoreReadItem::Barrier { .. } => {
                    log_reader.truncate(TruncateOffset::Barrier { epoch })?;
                }
            }
        }
    }
}

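/// Builds a multi-row `INSERT` statement with positional parameters, e.g. for
/// columns `(a, b)` and `number_of_rows = 2`:
/// `INSERT INTO t (a, b) VALUES ($1, $2), ($3, $4)`.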
fn create_insert_sql(schema: &Schema, table_name: &str, number_of_rows: usize) -> String {
    let number_of_columns = schema.len();
    let columns: String = schema
        .fields()
        .iter()
        .map(|field| field.name.clone())
        .join(", ");
    let parameters: String = (0..number_of_rows)
        .map(|i| {
            let row_parameters = (0..number_of_columns)
                .map(|j| format!("${}", i * number_of_columns + j + 1))
                .join(", ");
            format!("({row_parameters})")
        })
        .collect_vec()
        .join(", ");
    format!("INSERT INTO {table_name} ({columns}) VALUES {parameters}")
}

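/// Builds a multi-row `DELETE` statement keyed on the primary-key columns,
/// e.g. for pk columns `(a, b)` and `number_of_rows = 2`:
/// `DELETE FROM t WHERE (a, b) in (($1, $2), ($3, $4))`.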
fn create_delete_sql(
    schema: &Schema,
    table_name: &str,
    pk_indices: &[usize],
    number_of_rows: usize,
) -> String {
    let number_of_pk = pk_indices.len();
    let pk = {
        let pk_symbols = pk_indices
            .iter()
            .map(|pk_index| &schema.fields()[*pk_index].name)
            .join(", ");
        format!("({})", pk_symbols)
    };
    let parameters: String = (0..number_of_rows)
        .map(|i| {
            let row_parameters: String = (0..number_of_pk)
                .map(|j| format!("${}", i * number_of_pk + j + 1))
                .join(", ");
            format!("({row_parameters})")
        })
        .collect_vec()
        .join(", ");
    format!("DELETE FROM {table_name} WHERE {pk} in ({parameters})")
}

#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use expect_test::{Expect, expect};
    use risingwave_common::catalog::Field;
    use risingwave_common::types::DataType;

    use super::*;

    fn check(actual: impl Display, expect: Expect) {
        let actual = actual.to_string();
        expect.assert_eq(&actual);
    }

    #[test]
    fn test_create_insert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let table_name = "test_table";
        let sql = create_insert_sql(&schema, table_name, 3);
        check(
            sql,
            expect!["INSERT INTO test_table (a, b) VALUES ($1, $2), ($3, $4), ($5, $6)"],
        );
    }

    #[test]
    fn test_create_delete_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let table_name = "test_table";
        let sql = create_delete_sql(&schema, table_name, &[1], 3);
        check(
            sql,
            expect!["DELETE FROM test_table WHERE (b) in (($1), ($2), ($3))"],
        );
        let sql = create_delete_sql(&schema, table_name, &[0, 1], 3);
        check(
            sql,
            expect!["DELETE FROM test_table WHERE (a, b) in (($1, $2), ($3, $4), ($5, $6))"],
        );
    }
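
    /// A minimal sketch of config parsing; the property values here are
    /// hypothetical, while the field names and defaults come from
    /// `PostgresConfig` above.
    #[test]
    fn test_config_from_btreemap() {
        let props: BTreeMap<String, String> = [
            ("host", "localhost"),
            ("port", "5432"),
            ("user", "postgres"),
            ("password", "postgres"),
            ("database", "mydb"),
            ("table", "my_table"),
            ("type", "upsert"),
        ]
        .into_iter()
        .map(|(k, v)| (k.to_owned(), v.to_owned()))
        .collect();

        // `schema` and `max_batch_rows` fall back to their defaults.
        let config = PostgresConfig::from_btreemap(props.clone()).unwrap();
        assert_eq!(config.schema, "public");
        assert_eq!(config.max_batch_rows, 1024);

        // Any `type` other than "append-only" or "upsert" is rejected.
        let mut bad_props = props;
        bad_props.insert("type".to_owned(), "invalid".to_owned());
        assert!(PostgresConfig::from_btreemap(bad_props).is_err());
    }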
}