1use std::collections::{BTreeMap, HashSet};
16use std::sync::Arc;
17
18use anyhow::{Context, anyhow};
19use async_trait::async_trait;
20use futures::StreamExt;
21use futures::stream::FuturesUnordered;
22use itertools::Itertools;
23use phf::phf_set;
24use risingwave_common::array::{Op, StreamChunk};
25use risingwave_common::catalog::Schema;
26use risingwave_common::row::{Row, RowExt};
27use serde::Deserialize;
28use serde_with::{DisplayFromStr, serde_as};
29use simd_json::prelude::ArrayTrait;
30use thiserror_ext::AsReport;
31use tokio_postgres::types::Type as PgType;
32
33use super::{
34 LogSinker, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkLogReader,
35};
36use crate::connector_common::{PostgresExternalTable, SslMode, create_pg_client};
37use crate::enforce_secret::EnforceSecret;
38use crate::parser::scalar_adapter::{ScalarAdapter, validate_pg_type_to_rw_type};
39use crate::sink::log_store::{LogStoreReadItem, TruncateOffset};
40use crate::sink::{Result, Sink, SinkParam, SinkWriterParam};
41
/// Connector name under which this sink is registered (`connector = 'postgres'`).
pub const POSTGRES_SINK: &str = "postgres";
43
44#[serde_as]
45#[derive(Clone, Debug, Deserialize)]
46pub struct PostgresConfig {
47 pub host: String,
48 #[serde_as(as = "DisplayFromStr")]
49 pub port: u16,
50 pub user: String,
51 pub password: String,
52 pub database: String,
53 pub table: String,
54 #[serde(default = "default_schema")]
55 pub schema: String,
56 #[serde(default = "Default::default")]
57 pub ssl_mode: SslMode,
58 #[serde(rename = "ssl.root.cert")]
59 pub ssl_root_cert: Option<String>,
60 #[serde(default = "default_max_batch_rows")]
61 #[serde_as(as = "DisplayFromStr")]
62 pub max_batch_rows: usize,
63 pub r#type: String, #[serde(default, rename = "tcp.keepalive.enable")]
65 #[serde_as(as = "DisplayFromStr")]
66 pub tcp_keepalive_enable: bool,
67
68 #[serde(flatten)]
69 pub tcp_keepalive: Option<TcpKeepaliveConfig>,
70}
71
/// TCP keepalive tuning knobs, flattened into `PostgresConfig`.
/// Values are presumably in seconds (matching typical TCP keepalive
/// semantics) — TODO confirm against `create_pg_client`.
#[serde_as]
#[derive(Debug, Clone, Deserialize)]
pub struct TcpKeepaliveConfig {
    // Idle time before the first keepalive probe is sent.
    #[serde(rename = "tcp.keepalive.idle")]
    #[serde_as(as = "DisplayFromStr")]
    pub tcp_keepalive_idle: u32,
    // Interval between successive keepalive probes.
    #[serde(rename = "tcp.keepalive.interval")]
    #[serde_as(as = "DisplayFromStr")]
    pub tcp_keepalive_interval: u32,
    // Number of unanswered probes before the connection is considered dead.
    #[serde(rename = "tcp.keepalive.count")]
    #[serde_as(as = "DisplayFromStr")]
    pub tcp_keepalive_count: u32,
}
85
86impl Default for TcpKeepaliveConfig {
87 fn default() -> Self {
88 Self {
89 tcp_keepalive_idle: 10 * 60, tcp_keepalive_interval: 10,
91 tcp_keepalive_count: 3,
92 }
93 }
94}
95
impl EnforceSecret for PostgresConfig {
    /// Options that must be provided via secret references rather than as
    /// plain text in the `WITH` clause.
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf_set! {
        "password", "ssl.root.cert"
    };
}
101
/// Serde default for `PostgresConfig::max_batch_rows`.
fn default_max_batch_rows() -> usize {
    const DEFAULT_MAX_BATCH_ROWS: usize = 1024;
    DEFAULT_MAX_BATCH_ROWS
}
105
/// Serde default for `PostgresConfig::schema`: the standard `public` schema.
fn default_schema() -> String {
    String::from("public")
}
109
110impl PostgresConfig {
111 pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
112 let config =
113 serde_json::from_value::<PostgresConfig>(serde_json::to_value(properties).unwrap())
114 .map_err(|e| SinkError::Config(anyhow!(e)))?;
115 if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
116 return Err(SinkError::Config(anyhow!(
117 "`{}` must be {}, or {}",
118 SINK_TYPE_OPTION,
119 SINK_TYPE_APPEND_ONLY,
120 SINK_TYPE_UPSERT
121 )));
122 }
123 Ok(config)
124 }
125}
126
/// The Postgres sink: holds the parsed config plus the upstream schema and
/// primary-key layout needed for validation and writer construction.
#[derive(Debug)]
pub struct PostgresSink {
    pub config: PostgresConfig,
    schema: Schema,
    // Indices of primary-key columns within `schema`.
    pk_indices: Vec<usize>,
    is_append_only: bool,
}
134
impl PostgresSink {
    /// Assembles a `PostgresSink` from already-parsed parts.
    /// Currently infallible; wrapped in `Result` to keep a uniform
    /// constructor signature.
    pub fn new(
        config: PostgresConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}
150
impl EnforceSecret for PostgresSink {
    /// Checks every user-supplied property name against `PostgresConfig`'s
    /// secret-enforcement rules, failing on the first violation.
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::error::ConnectorResult<()> {
        for prop in prop_iter {
            PostgresConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}
161
162impl TryFrom<SinkParam> for PostgresSink {
163 type Error = SinkError;
164
165 fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
166 let schema = param.schema();
167 let pk_indices = param.downstream_pk_or_empty();
168 let config = PostgresConfig::from_btreemap(param.properties)?;
169 PostgresSink::new(config, schema, pk_indices, param.sink_type.is_append_only())
170 }
171}
172
impl Sink for PostgresSink {
    type LogSinker = PostgresSinkWriter;

    const SINK_NAME: &'static str = POSTGRES_SINK;

    /// Validates the sink against the live Postgres table: connectivity,
    /// column presence and type compatibility, and (for upsert sinks)
    /// primary-key agreement.
    async fn validate(&self) -> Result<()> {
        // Upsert sinks need a PK to build ON CONFLICT / DELETE statements.
        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert Postgres sink (please define in `primary_key` field)"
            )));
        }

        {
            // Connect to the target table and fetch its catalog metadata.
            let pg_table = PostgresExternalTable::connect(
                &self.config.user,
                &self.config.password,
                &self.config.host,
                self.config.port,
                &self.config.database,
                &self.config.schema,
                &self.config.table,
                &self.config.ssl_mode,
                &self.config.ssl_root_cert,
                self.is_append_only,
            )
            .await
            .context(format!(
                "failed to connect to database: {}, schema: {}, table: {}",
                &self.config.database, &self.config.schema, &self.config.table
            ))?;

            {
                // Every sink column must exist downstream with a compatible
                // type; the Postgres table may have extra columns.
                let pg_columns = pg_table.column_descs();
                let sink_columns = self.schema.fields();
                if pg_columns.len() < sink_columns.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Column count mismatch: Postgres table has {} columns, but sink schema has {} columns, sink should have less or equal columns to the Postgres table",
                        pg_columns.len(),
                        sink_columns.len()
                    )));
                }

                let pg_columns_lookup = pg_columns
                    .iter()
                    .map(|c| (c.name.clone(), c.data_type.clone()))
                    .collect::<BTreeMap<_, _>>();
                for sink_column in sink_columns {
                    let pg_column = pg_columns_lookup.get(&sink_column.name);
                    match pg_column {
                        None => {
                            return Err(SinkError::Config(anyhow!(
                                "Column `{}` not found in Postgres table `{}`",
                                sink_column.name,
                                self.config.table
                            )));
                        }
                        Some(pg_column) => {
                            if !validate_pg_type_to_rw_type(pg_column, &sink_column.data_type()) {
                                return Err(SinkError::Config(anyhow!(
                                    "Column `{}` in Postgres table `{}` has type `{}`, but sink schema defines it as type `{}`",
                                    sink_column.name,
                                    self.config.table,
                                    pg_column,
                                    sink_column.data_type()
                                )));
                            }
                        }
                    }
                }
            }

            {
                // Downstream PK set must match the sink PK set exactly:
                // same cardinality and same column names.
                let pg_pk_names = pg_table.pk_names();
                let sink_pk_names = self
                    .pk_indices
                    .iter()
                    .map(|i| &self.schema.fields()[*i].name)
                    .collect::<HashSet<_>>();
                if pg_pk_names.len() != sink_pk_names.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Primary key mismatch: Postgres table has primary key on columns {:?}, but sink schema defines primary key on columns {:?}",
                        pg_pk_names,
                        sink_pk_names
                    )));
                }
                for name in pg_pk_names {
                    if !sink_pk_names.contains(name) {
                        return Err(SinkError::Config(anyhow!(
                            "Primary key mismatch: Postgres table has primary key on column `{}`, but sink schema does not define it as a primary key",
                            name
                        )));
                    }
                }
            }
        }

        Ok(())
    }

    /// Creates the writer that performs the actual inserts/upserts/deletes.
    async fn new_log_sinker(&self, _writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        PostgresSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await
    }
}
285
/// Writer that streams chunks into Postgres via prepared statements.
pub struct PostgresSinkWriter {
    is_append_only: bool,
    client: tokio_postgres::Client,
    // Indices of primary-key columns within the sink schema.
    pk_indices: Vec<usize>,
    // Postgres types of the PK columns (parallel to `pk_indices`).
    pk_types: Vec<PgType>,
    // Postgres types of all sink columns, in schema order.
    schema_types: Vec<PgType>,
    // Raw SQL text, kept only for error messages; `Arc` so the pipelined
    // per-row write futures can share it cheaply.
    raw_insert_sql: Arc<String>,
    raw_upsert_sql: Arc<String>,
    raw_delete_sql: Arc<String>,
    // Server-side prepared statements, reused for every row.
    insert_sql: Arc<tokio_postgres::Statement>,
    delete_sql: Arc<tokio_postgres::Statement>,
    upsert_sql: Arc<tokio_postgres::Statement>,
}
299
300impl PostgresSinkWriter {
301 async fn new(
302 config: PostgresConfig,
303 schema: Schema,
304 pk_indices: Vec<usize>,
305 is_append_only: bool,
306 ) -> Result<Self> {
307 let tcp_keepalive = if config.tcp_keepalive_enable {
308 config
309 .tcp_keepalive
310 .or_else(|| Some(TcpKeepaliveConfig::default()))
311 } else {
312 None
313 };
314
315 let client = create_pg_client(
316 &config.user,
317 &config.password,
318 &config.host,
319 &config.port.to_string(),
320 &config.database,
321 &config.ssl_mode,
322 &config.ssl_root_cert,
323 tcp_keepalive,
324 )
325 .await?;
326
327 let pk_indices_lookup = pk_indices.iter().copied().collect::<HashSet<_>>();
328
329 let (pk_types, schema_types) = {
331 let name_to_type = PostgresExternalTable::type_mapping(
332 &config.user,
333 &config.password,
334 &config.host,
335 config.port,
336 &config.database,
337 &config.schema,
338 &config.table,
339 &config.ssl_mode,
340 &config.ssl_root_cert,
341 is_append_only,
342 )
343 .await?;
344 let mut schema_types = Vec::with_capacity(schema.fields.len());
345 let mut pk_types = Vec::with_capacity(pk_indices.len());
346 for (i, field) in schema.fields.iter().enumerate() {
347 let field_name = &field.name;
348 let actual_data_type = name_to_type.get(field_name).map(|t| (*t).clone());
349 let actual_data_type = actual_data_type
350 .ok_or_else(|| {
351 SinkError::Config(anyhow!(
352 "Column `{}` not found in sink schema",
353 field_name
354 ))
355 })?
356 .clone();
357 if pk_indices_lookup.contains(&i) {
358 pk_types.push(actual_data_type.clone())
359 }
360 schema_types.push(actual_data_type);
361 }
362 (pk_types, schema_types)
363 };
364
365 let raw_insert_sql = create_insert_sql(&schema, &config.schema, &config.table);
366 let raw_upsert_sql = create_upsert_sql(
367 &schema,
368 &config.schema,
369 &config.table,
370 &pk_indices,
371 &pk_indices_lookup,
372 );
373 let raw_delete_sql = create_delete_sql(&schema, &config.schema, &config.table, &pk_indices);
374
375 let insert_sql = client
376 .prepare(&raw_insert_sql)
377 .await
378 .with_context(|| format!("failed to prepare insert statement: {}", raw_insert_sql))?;
379 let upsert_sql = client
380 .prepare(&raw_upsert_sql)
381 .await
382 .with_context(|| format!("failed to prepare upsert statement: {}", raw_upsert_sql))?;
383 let delete_sql = client
384 .prepare(&raw_delete_sql)
385 .await
386 .with_context(|| format!("failed to prepare delete statement: {}", raw_delete_sql))?;
387
388 let writer = Self {
389 is_append_only,
390 client,
391 pk_indices,
392 pk_types,
393 schema_types,
394 raw_insert_sql: Arc::new(raw_insert_sql),
395 raw_upsert_sql: Arc::new(raw_upsert_sql),
396 raw_delete_sql: Arc::new(raw_delete_sql),
397 insert_sql: Arc::new(insert_sql),
398 delete_sql: Arc::new(delete_sql),
399 upsert_sql: Arc::new(upsert_sql),
400 };
401 Ok(writer)
402 }
403
404 async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
405 if self.is_append_only {
408 self.write_batch_append_only(chunk).await
409 } else {
410 self.write_batch_non_append_only(chunk).await
411 }
412 }
413
414 async fn write_batch_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
415 let transaction = Arc::new(self.client.transaction().await?);
416 let mut insert_futures = FuturesUnordered::new();
417 for (op, row) in chunk.rows() {
418 match op {
419 Op::Insert => {
420 let pg_row = convert_row_to_pg_row(row, &self.schema_types);
421 let insert_sql = self.insert_sql.clone();
422 let raw_insert_sql = self.raw_insert_sql.clone();
423 let transaction = transaction.clone();
424 let future = async move {
425 transaction
426 .execute_raw(insert_sql.as_ref(), &pg_row)
427 .await
428 .with_context(|| {
429 format!(
430 "failed to execute insert statement: {}, parameters: {:?}",
431 raw_insert_sql, pg_row
432 )
433 })
434 };
435 insert_futures.push(future);
436 }
437 _ => {
438 tracing::error!(
439 "row ignored, append-only sink should not receive update insert, update delete and delete operations"
440 );
441 }
442 }
443 }
444
445 while let Some(result) = insert_futures.next().await {
446 result?;
447 }
448 if let Some(transaction) = Arc::into_inner(transaction) {
449 transaction.commit().await?;
450 } else {
451 tracing::error!("transaction lost!");
452 }
453
454 Ok(())
455 }
456
457 async fn write_batch_non_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
458 let transaction = Arc::new(self.client.transaction().await?);
459 let mut delete_futures = FuturesUnordered::new();
460 let mut upsert_futures = FuturesUnordered::new();
461 for (op, row) in chunk.rows() {
462 match op {
463 Op::Delete | Op::UpdateDelete => {
464 let pg_row =
465 convert_row_to_pg_row(row.project(&self.pk_indices), &self.pk_types);
466 let delete_sql = self.delete_sql.clone();
467 let raw_delete_sql = self.raw_delete_sql.clone();
468 let transaction = transaction.clone();
469 let future = async move {
470 transaction
471 .execute_raw(delete_sql.as_ref(), &pg_row)
472 .await
473 .with_context(|| {
474 format!(
475 "failed to execute delete statement: {}, parameters: {:?}",
476 raw_delete_sql, pg_row
477 )
478 })
479 };
480 delete_futures.push(future);
481 }
482 Op::Insert | Op::UpdateInsert => {
483 let pg_row = convert_row_to_pg_row(row, &self.schema_types);
484 let upsert_sql = self.upsert_sql.clone();
485 let raw_upsert_sql = self.raw_upsert_sql.clone();
486 let transaction = transaction.clone();
487 let future = async move {
488 transaction
489 .execute_raw(upsert_sql.as_ref(), &pg_row)
490 .await
491 .with_context(|| {
492 format!(
493 "failed to execute upsert statement: {}, parameters: {:?}",
494 raw_upsert_sql, pg_row
495 )
496 })
497 };
498 upsert_futures.push(future);
499 }
500 }
501 }
502 while let Some(result) = delete_futures.next().await {
503 result?;
504 }
505 while let Some(result) = upsert_futures.next().await {
506 result?;
507 }
508 if let Some(transaction) = Arc::into_inner(transaction) {
509 transaction.commit().await?;
510 } else {
511 tracing::error!("transaction lost!");
512 }
513 Ok(())
514 }
515}
516
#[async_trait]
impl LogSinker for PostgresSinkWriter {
    /// Drives the sink forever: each stream chunk is written in its own
    /// transaction, then the log is truncated past that chunk; barriers only
    /// advance the truncation offset (there is no data to flush).
    async fn consume_log_and_sink(mut self, mut log_reader: impl SinkLogReader) -> Result<!> {
        log_reader.start_from(None).await?;
        loop {
            let (epoch, item) = log_reader.next_item().await?;
            match item {
                LogStoreReadItem::StreamChunk { chunk, chunk_id } => {
                    self.write_batch(chunk).await?;
                    // Truncate only after the chunk was committed downstream.
                    log_reader.truncate(TruncateOffset::Chunk { epoch, chunk_id })?;
                }
                LogStoreReadItem::Barrier { .. } => {
                    log_reader.truncate(TruncateOffset::Barrier { epoch })?;
                }
            }
        }
    }
}
535
536fn create_insert_sql(schema: &Schema, schema_name: &str, table_name: &str) -> String {
537 let normalized_table_name = format!(
538 "{}.{}",
539 quote_identifier(schema_name),
540 quote_identifier(table_name)
541 );
542 let number_of_columns = schema.len();
543 let columns: String = schema
544 .fields()
545 .iter()
546 .map(|field| quote_identifier(&field.name))
547 .join(", ");
548 let column_parameters: String = (0..number_of_columns)
549 .map(|i| format!("${}", i + 1))
550 .join(", ");
551 format!("INSERT INTO {normalized_table_name} ({columns}) VALUES ({column_parameters})")
552}
553
554fn create_delete_sql(
555 schema: &Schema,
556 schema_name: &str,
557 table_name: &str,
558 pk_indices: &[usize],
559) -> String {
560 let normalized_table_name = format!(
561 "{}.{}",
562 quote_identifier(schema_name),
563 quote_identifier(table_name)
564 );
565 let pk_indices = if pk_indices.is_empty() {
566 (0..schema.len()).collect_vec()
567 } else {
568 pk_indices.to_vec()
569 };
570 let pk = {
571 let pk_symbols = pk_indices
572 .iter()
573 .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
574 .join(", ");
575 format!("({})", pk_symbols)
576 };
577 let parameters: String = (0..pk_indices.len())
578 .map(|i| format!("${}", i + 1))
579 .join(", ");
580 format!("DELETE FROM {normalized_table_name} WHERE {pk} in (({parameters}))")
581}
582
583fn create_upsert_sql(
584 schema: &Schema,
585 schema_name: &str,
586 table_name: &str,
587 pk_indices: &[usize],
588 pk_indices_lookup: &HashSet<usize>,
589) -> String {
590 let insert_sql = create_insert_sql(schema, schema_name, table_name);
591 if pk_indices.is_empty() {
592 return insert_sql;
593 }
594 let pk_columns = pk_indices
595 .iter()
596 .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
597 .collect_vec()
598 .join(", ");
599 let update_parameters: String = (0..schema.len())
600 .filter(|i| !pk_indices_lookup.contains(i))
601 .map(|i| {
602 let column = quote_identifier(&schema.fields()[i].name);
603 format!("{column} = EXCLUDED.{column}")
604 })
605 .collect_vec()
606 .join(", ");
607 if update_parameters.is_empty() {
608 format!("{insert_sql} on conflict ({pk_columns}) do nothing")
609 } else {
610 format!("{insert_sql} on conflict ({pk_columns}) do update set {update_parameters}")
611 }
612}
613
/// Quotes a SQL identifier for Postgres: wraps it in double quotes and
/// doubles any embedded double quote.
fn quote_identifier(identifier: &str) -> String {
    let escaped = identifier.replace('"', "\"\"");
    format!("\"{escaped}\"")
}
618
/// A single Postgres bind parameter; `None` is sent as SQL NULL.
type PgDatum = Option<ScalarAdapter>;
/// One row of bind parameters, in sink-schema column order.
type PgRow = Vec<PgDatum>;
621
622fn convert_row_to_pg_row(row: impl Row, schema_types: &[PgType]) -> PgRow {
623 let mut buffer = Vec::with_capacity(row.len());
624 for (i, datum_ref) in row.iter().enumerate() {
625 let pg_datum = datum_ref.map(|s| {
626 match ScalarAdapter::from_scalar(s, &schema_types[i]) {
627 Ok(scalar) => Some(scalar),
628 Err(e) => {
629 tracing::error!(error=%e.as_report(), scalar=?s, "Failed to convert scalar to pg value");
630 None
631 }
632 }
633 });
634 buffer.push(pg_datum.flatten());
635 }
636 buffer
637}
638
#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use expect_test::{Expect, expect};
    use risingwave_common::catalog::Field;
    use risingwave_common::types::DataType;

    use super::*;

    /// Builds a schema of `Int32` columns with the given names.
    fn int32_schema(names: &[&str]) -> Schema {
        Schema::new(
            names
                .iter()
                .map(|name| Field {
                    data_type: DataType::Int32,
                    name: (*name).to_owned(),
                })
                .collect(),
        )
    }

    fn check(actual: impl Display, expect: Expect) {
        expect.assert_eq(&actual.to_string());
    }

    #[test]
    fn test_create_insert_sql() {
        let schema = int32_schema(&["a", "b"]);
        check(
            create_insert_sql(&schema, "test_schema", "test_table"),
            expect![[r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2)"#]],
        );
    }

    #[test]
    fn test_create_delete_sql() {
        let schema = int32_schema(&["a", "b"]);
        check(
            create_delete_sql(&schema, "test_schema", "test_table", &[1]),
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("b") in (($1))"#]],
        );
        check(
            create_delete_sql(&schema, "test_schema", "test_table", &[0, 1]),
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("a", "b") in (($1, $2))"#]],
        );
    }

    #[test]
    fn test_create_upsert_sql() {
        let schema = int32_schema(&["a", "b"]);
        let pk_indices_lookup = HashSet::from_iter([1]);
        check(
            create_upsert_sql(&schema, "test_schema", "test_table", &[1], &pk_indices_lookup),
            expect![[
                r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2) on conflict ("b") do update set "a" = EXCLUDED."a""#
            ]],
        );
    }

    #[test]
    fn test_create_upsert_sql_all_columns_are_primary_keys() {
        let schema = int32_schema(&["user_id", "client_id"]);
        let pk_indices_lookup = HashSet::from_iter([0, 1]);
        check(
            create_upsert_sql(
                &schema,
                "test_schema",
                "test_table",
                &[0, 1],
                &pk_indices_lookup,
            ),
            expect![[
                r#"INSERT INTO "test_schema"."test_table" ("user_id", "client_id") VALUES ($1, $2) on conflict ("user_id", "client_id") do nothing"#
            ]],
        );
    }
}