use std::collections::{BTreeMap, HashSet};
use std::sync::Arc;

use anyhow::{Context, anyhow};
use async_trait::async_trait;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use itertools::Itertools;
use phf::phf_set;
use risingwave_common::array::{Op, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::row::{Row, RowExt};
use serde::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use simd_json::prelude::ArrayTrait;
use thiserror_ext::AsReport;
use tokio_postgres::types::Type as PgType;

use super::{
    LogSinker, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkLogReader,
};
use crate::connector_common::{PostgresExternalTable, SslMode, create_pg_client};
use crate::enforce_secret::EnforceSecret;
use crate::parser::scalar_adapter::{ScalarAdapter, validate_pg_type_to_rw_type};
use crate::sink::log_store::{LogStoreReadItem, TruncateOffset};
use crate::sink::{Result, Sink, SinkParam, SinkWriterParam};

pub const POSTGRES_SINK: &str = "postgres";

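/// Options for the Postgres sink, deserialized from the sink's `WITH` properties via
/// [`PostgresConfig::from_btreemap`]. Property keys follow the field names below (only
/// `ssl.root.cert` is renamed); `schema` defaults to `public` and `max_batch_rows` to 1024.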
#[serde_as]
#[derive(Clone, Debug, Deserialize)]
pub struct PostgresConfig {
    pub host: String,
    #[serde_as(as = "DisplayFromStr")]
    pub port: u16,
    pub user: String,
    pub password: String,
    pub database: String,
    pub table: String,
    #[serde(default = "default_schema")]
    pub schema: String,
    #[serde(default = "Default::default")]
    pub ssl_mode: SslMode,
    #[serde(rename = "ssl.root.cert")]
    pub ssl_root_cert: Option<String>,
    #[serde(default = "default_max_batch_rows")]
    #[serde_as(as = "DisplayFromStr")]
    pub max_batch_rows: usize,
    pub r#type: String,
}

impl EnforceSecret for PostgresConfig {
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf_set! {
        "password", "ssl.root.cert"
    };
}

fn default_max_batch_rows() -> usize {
    1024
}

fn default_schema() -> String {
    "public".to_owned()
}

impl PostgresConfig {
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<PostgresConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }
}

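/// The Postgres sink as registered with the sink framework. Construction only captures the
/// configuration; connectivity and schema compatibility are checked in [`Sink::validate`].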
#[derive(Debug)]
pub struct PostgresSink {
    pub config: PostgresConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl PostgresSink {
    pub fn new(
        config: PostgresConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl EnforceSecret for PostgresSink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::error::ConnectorResult<()> {
        for prop in prop_iter {
            PostgresConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl TryFrom<SinkParam> for PostgresSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let pk_indices = param.downstream_pk_or_empty();
        let config = PostgresConfig::from_btreemap(param.properties)?;
        PostgresSink::new(config, schema, pk_indices, param.sink_type.is_append_only())
    }
}

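// `validate` runs when the sink is created: it connects to the downstream Postgres table and
// checks column compatibility and primary-key alignment before any data is written.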
impl Sink for PostgresSink {
    type LogSinker = PostgresSinkWriter;

    const SINK_NAME: &'static str = POSTGRES_SINK;

    async fn validate(&self) -> Result<()> {
        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert Postgres sink (please define in `primary_key` field)"
            )));
        }

        {
            // Check connectivity and fetch the downstream table's metadata.
            let pg_table = PostgresExternalTable::connect(
                &self.config.user,
                &self.config.password,
                &self.config.host,
                self.config.port,
                &self.config.database,
                &self.config.schema,
                &self.config.table,
                &self.config.ssl_mode,
                &self.config.ssl_root_cert,
                self.is_append_only,
            )
            .await
            .context(format!(
                "failed to connect to database: {}, schema: {}, table: {}",
                &self.config.database, &self.config.schema, &self.config.table
            ))?;

            {
                // Check that the sink schema is compatible with the Postgres table schema.
                let pg_columns = pg_table.column_descs();
                let sink_columns = self.schema.fields();
                if pg_columns.len() < sink_columns.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Column count mismatch: Postgres table has {} columns, but sink schema has {} columns; the sink must not have more columns than the Postgres table",
                        pg_columns.len(),
                        sink_columns.len()
                    )));
                }

                let pg_columns_lookup = pg_columns
                    .iter()
                    .map(|c| (c.name.clone(), c.data_type.clone()))
                    .collect::<BTreeMap<_, _>>();
                for sink_column in sink_columns {
                    let pg_column = pg_columns_lookup.get(&sink_column.name);
                    match pg_column {
                        None => {
                            return Err(SinkError::Config(anyhow!(
                                "Column `{}` not found in Postgres table `{}`",
                                sink_column.name,
                                self.config.table
                            )));
                        }
                        Some(pg_column) => {
                            if !validate_pg_type_to_rw_type(pg_column, &sink_column.data_type()) {
                                return Err(SinkError::Config(anyhow!(
                                    "Column `{}` in Postgres table `{}` has type `{}`, but sink schema defines it as type `{}`",
                                    sink_column.name,
                                    self.config.table,
                                    pg_column,
                                    sink_column.data_type()
                                )));
                            }
                        }
                    }
                }
            }

            {
                // Check that the sink's primary key matches the Postgres table's primary key.
                let pg_pk_names = pg_table.pk_names();
                let sink_pk_names = self
                    .pk_indices
                    .iter()
                    .map(|i| &self.schema.fields()[*i].name)
                    .collect::<HashSet<_>>();
                if pg_pk_names.len() != sink_pk_names.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Primary key mismatch: Postgres table has primary key on columns {:?}, but sink schema defines primary key on columns {:?}",
                        pg_pk_names,
                        sink_pk_names
                    )));
                }
                for name in pg_pk_names {
                    if !sink_pk_names.contains(name) {
                        return Err(SinkError::Config(anyhow!(
                            "Primary key mismatch: Postgres table has primary key on column `{}`, but sink schema does not define it as a primary key",
                            name
                        )));
                    }
                }
            }
        }

        Ok(())
    }

    async fn new_log_sinker(&self, _writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        PostgresSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await
    }
}

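/// Writer that streams chunks into Postgres. Insert, upsert, and delete statements are
/// prepared once at construction; the raw SQL strings are kept alongside them purely for
/// error reporting.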
pub struct PostgresSinkWriter {
    is_append_only: bool,
    client: tokio_postgres::Client,
    pk_indices: Vec<usize>,
    pk_types: Vec<PgType>,
    schema_types: Vec<PgType>,
    raw_insert_sql: Arc<String>,
    raw_upsert_sql: Arc<String>,
    raw_delete_sql: Arc<String>,
    insert_sql: Arc<tokio_postgres::Statement>,
    delete_sql: Arc<tokio_postgres::Statement>,
    upsert_sql: Arc<tokio_postgres::Statement>,
}

impl PostgresSinkWriter {
    async fn new(
        config: PostgresConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let client = create_pg_client(
            &config.user,
            &config.password,
            &config.host,
            &config.port.to_string(),
            &config.database,
            &config.ssl_mode,
            &config.ssl_root_cert,
        )
        .await?;

        let pk_indices_lookup = pk_indices.iter().copied().collect::<HashSet<_>>();

        let (pk_types, schema_types) = {
            // Query the Postgres column types so row values can be encoded correctly.
            let name_to_type = PostgresExternalTable::type_mapping(
                &config.user,
                &config.password,
                &config.host,
                config.port,
                &config.database,
                &config.schema,
                &config.table,
                &config.ssl_mode,
                &config.ssl_root_cert,
                is_append_only,
            )
            .await?;
            let mut schema_types = Vec::with_capacity(schema.fields.len());
            let mut pk_types = Vec::with_capacity(pk_indices.len());
            for (i, field) in schema.fields.iter().enumerate() {
                let field_name = &field.name;
                let actual_data_type = name_to_type.get(field_name).cloned().ok_or_else(|| {
                    SinkError::Config(anyhow!(
                        "Column `{}` of the sink schema not found in the Postgres table",
                        field_name
                    ))
                })?;
                if pk_indices_lookup.contains(&i) {
                    pk_types.push(actual_data_type.clone())
                }
                schema_types.push(actual_data_type);
            }
            (pk_types, schema_types)
        };

        let raw_insert_sql = create_insert_sql(&schema, &config.schema, &config.table);
        let raw_upsert_sql = create_upsert_sql(
            &schema,
            &config.schema,
            &config.table,
            &pk_indices,
            &pk_indices_lookup,
        );
        let raw_delete_sql = create_delete_sql(&schema, &config.schema, &config.table, &pk_indices);

        let insert_sql = client
            .prepare(&raw_insert_sql)
            .await
            .with_context(|| format!("failed to prepare insert statement: {}", raw_insert_sql))?;
        let upsert_sql = client
            .prepare(&raw_upsert_sql)
            .await
            .with_context(|| format!("failed to prepare upsert statement: {}", raw_upsert_sql))?;
        let delete_sql = client
            .prepare(&raw_delete_sql)
            .await
            .with_context(|| format!("failed to prepare delete statement: {}", raw_delete_sql))?;

        let writer = Self {
            is_append_only,
            client,
            pk_indices,
            pk_types,
            schema_types,
            raw_insert_sql: Arc::new(raw_insert_sql),
            raw_upsert_sql: Arc::new(raw_upsert_sql),
            raw_delete_sql: Arc::new(raw_delete_sql),
            insert_sql: Arc::new(insert_sql),
            delete_sql: Arc::new(delete_sql),
            upsert_sql: Arc::new(upsert_sql),
        };
        Ok(writer)
    }

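    /// Write one stream chunk, dispatching to the append-only or upsert path. All statements
    /// for a chunk are executed inside a single transaction.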
    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        if self.is_append_only {
            self.write_batch_append_only(chunk).await
        } else {
            self.write_batch_non_append_only(chunk).await
        }
    }

    async fn write_batch_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        let transaction = Arc::new(self.client.transaction().await?);
        let mut insert_futures = FuturesUnordered::new();
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    let pg_row = convert_row_to_pg_row(row, &self.schema_types);
                    let insert_sql = self.insert_sql.clone();
                    let raw_insert_sql = self.raw_insert_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(insert_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute insert statement: {}, parameters: {:?}",
                                    raw_insert_sql, pg_row
                                )
                            })
                    };
                    insert_futures.push(future);
                }
                _ => {
                    tracing::error!(
                        "row ignored: append-only sink should not receive update or delete operations"
                    );
                }
            }
        }

        while let Some(result) = insert_futures.next().await {
            result?;
        }
        if let Some(transaction) = Arc::into_inner(transaction) {
            transaction.commit().await?;
        } else {
            tracing::error!("transaction lost!");
        }

        Ok(())
    }

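    /// Upsert path: all deletes for the chunk are executed and awaited before the upserts, so a
    /// delete and a re-insert of the same key within one chunk cannot be reordered.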
    async fn write_batch_non_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        let transaction = Arc::new(self.client.transaction().await?);
        let mut delete_futures = FuturesUnordered::new();
        let mut upsert_futures = FuturesUnordered::new();
        for (op, row) in chunk.rows() {
            match op {
                Op::Delete | Op::UpdateDelete => {
                    let pg_row =
                        convert_row_to_pg_row(row.project(&self.pk_indices), &self.pk_types);
                    let delete_sql = self.delete_sql.clone();
                    let raw_delete_sql = self.raw_delete_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(delete_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute delete statement: {}, parameters: {:?}",
                                    raw_delete_sql, pg_row
                                )
                            })
                    };
                    delete_futures.push(future);
                }
                Op::Insert | Op::UpdateInsert => {
                    let pg_row = convert_row_to_pg_row(row, &self.schema_types);
                    let upsert_sql = self.upsert_sql.clone();
                    let raw_upsert_sql = self.raw_upsert_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(upsert_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute upsert statement: {}, parameters: {:?}",
                                    raw_upsert_sql, pg_row
                                )
                            })
                    };
                    upsert_futures.push(future);
                }
            }
        }
        while let Some(result) = delete_futures.next().await {
            result?;
        }
        while let Some(result) = upsert_futures.next().await {
            result?;
        }
        if let Some(transaction) = Arc::into_inner(transaction) {
            transaction.commit().await?;
        } else {
            tracing::error!("transaction lost!");
        }
        Ok(())
    }
}

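// The log sinker drains the sink's log store forever: each chunk is written and then truncated
// from the log, and barriers only advance the truncation offset (no extra flush is needed since
// every chunk is committed in its own transaction).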
#[async_trait]
impl LogSinker for PostgresSinkWriter {
    async fn consume_log_and_sink(mut self, mut log_reader: impl SinkLogReader) -> Result<!> {
        log_reader.start_from(None).await?;
        loop {
            let (epoch, item) = log_reader.next_item().await?;
            match item {
                LogStoreReadItem::StreamChunk { chunk, chunk_id } => {
                    self.write_batch(chunk).await?;
                    log_reader.truncate(TruncateOffset::Chunk { epoch, chunk_id })?;
                }
                LogStoreReadItem::Barrier { .. } => {
                    log_reader.truncate(TruncateOffset::Barrier { epoch })?;
                }
            }
        }
    }
}

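/// Builds the parameterized `INSERT` statement for the sink's target table. For the two-column
/// schema used in the tests below, this produces:
/// `INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2)`.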
fn create_insert_sql(schema: &Schema, schema_name: &str, table_name: &str) -> String {
    let normalized_table_name = format!(
        "{}.{}",
        quote_identifier(schema_name),
        quote_identifier(table_name)
    );
    let number_of_columns = schema.len();
    let columns: String = schema
        .fields()
        .iter()
        .map(|field| quote_identifier(&field.name))
        .join(", ");
    let column_parameters: String = (0..number_of_columns)
        .map(|i| format!("${}", i + 1))
        .join(", ");
    format!("INSERT INTO {normalized_table_name} ({columns}) VALUES ({column_parameters})")
}

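/// Builds the parameterized `DELETE` statement keyed on the sink's primary key (or on all
/// columns when no primary key is configured). For the tests below this produces e.g.
/// `DELETE FROM "test_schema"."test_table" WHERE ("b") in (($1))`.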
fn create_delete_sql(
    schema: &Schema,
    schema_name: &str,
    table_name: &str,
    pk_indices: &[usize],
) -> String {
    let normalized_table_name = format!(
        "{}.{}",
        quote_identifier(schema_name),
        quote_identifier(table_name)
    );
    let pk_indices = if pk_indices.is_empty() {
        (0..schema.len()).collect_vec()
    } else {
        pk_indices.to_vec()
    };
    let pk = {
        let pk_symbols = pk_indices
            .iter()
            .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
            .join(", ");
        format!("({})", pk_symbols)
    };
    let parameters: String = (0..pk_indices.len())
        .map(|i| format!("${}", i + 1))
        .join(", ");
    format!("DELETE FROM {normalized_table_name} WHERE {pk} in (({parameters}))")
}

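/// Builds the upsert statement: the insert statement plus an `ON CONFLICT ... DO UPDATE` clause
/// that overwrites every non-primary-key column. With no primary key it degenerates to a plain
/// insert. For the tests below this produces e.g.
/// `INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2) on conflict ("b") do update set "a" = EXCLUDED."a"`.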
fn create_upsert_sql(
    schema: &Schema,
    schema_name: &str,
    table_name: &str,
    pk_indices: &[usize],
    pk_indices_lookup: &HashSet<usize>,
) -> String {
    let insert_sql = create_insert_sql(schema, schema_name, table_name);
    if pk_indices.is_empty() {
        return insert_sql;
    }
    let pk_columns = pk_indices
        .iter()
        .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
        .collect_vec()
        .join(", ");
    let update_parameters: String = (0..schema.len())
        .filter(|i| !pk_indices_lookup.contains(i))
        .map(|i| {
            let column = quote_identifier(&schema.fields()[i].name);
            format!("{column} = EXCLUDED.{column}")
        })
        .collect_vec()
        .join(", ");
    format!("{insert_sql} on conflict ({pk_columns}) do update set {update_parameters}")
}

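/// Quotes a Postgres identifier by wrapping it in double quotes and doubling any embedded
/// double quotes, e.g. `my"col` becomes `"my""col"`.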
fn quote_identifier(identifier: &str) -> String {
    format!("\"{}\"", identifier.replace("\"", "\"\""))
}

type PgDatum = Option<ScalarAdapter>;
type PgRow = Vec<PgDatum>;

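/// Converts a RisingWave row into the parameter vector passed to `execute_raw`, using the
/// Postgres types discovered at writer construction. Values that fail to convert are logged
/// and written as NULL rather than failing the whole batch.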
fn convert_row_to_pg_row(row: impl Row, schema_types: &[PgType]) -> PgRow {
    let mut buffer = Vec::with_capacity(row.len());
    for (i, datum_ref) in row.iter().enumerate() {
        let pg_datum = datum_ref.map(|s| {
            match ScalarAdapter::from_scalar(s, &schema_types[i]) {
                Ok(scalar) => Some(scalar),
                Err(e) => {
                    tracing::error!(error=%e.as_report(), scalar=?s, "Failed to convert scalar to pg value");
                    None
                }
            }
        });
        buffer.push(pg_datum.flatten());
    }
    buffer
}

#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use expect_test::{Expect, expect};
    use risingwave_common::catalog::Field;
    use risingwave_common::types::DataType;

    use super::*;

    fn check(actual: impl Display, expect: Expect) {
        let actual = actual.to_string();
        expect.assert_eq(&actual);
    }

    #[test]
    fn test_create_insert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let sql = create_insert_sql(&schema, schema_name, table_name);
        check(
            sql,
            expect![[r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2)"#]],
        );
    }

    #[test]
    fn test_create_delete_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let sql = create_delete_sql(&schema, schema_name, table_name, &[1]);
        check(
            sql,
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("b") in (($1))"#]],
        );
        let table_name = "test_table";
        let sql = create_delete_sql(&schema, schema_name, table_name, &[0, 1]);
        check(
            sql,
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("a", "b") in (($1, $2))"#]],
        );
    }

    #[test]
    fn test_create_upsert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let pk_indices_lookup = HashSet::from_iter([1]);
        let sql = create_upsert_sql(&schema, schema_name, table_name, &[1], &pk_indices_lookup);
        check(
            sql,
            expect![[
                r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2) on conflict ("b") do update set "a" = EXCLUDED."a""#
            ]],
        );
    }
}