use std::collections::{BTreeMap, HashSet};
use std::sync::Arc;

use anyhow::{Context, anyhow};
use async_trait::async_trait;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use itertools::Itertools;
use phf::phf_set;
use risingwave_common::array::{Op, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::row::{Row, RowExt};
use serde::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use simd_json::prelude::ArrayTrait;
use thiserror_ext::AsReport;
use tokio_postgres::types::Type as PgType;

use super::{
    LogSinker, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkLogReader,
};
use crate::connector_common::{PostgresExternalTable, SslMode, create_pg_client};
use crate::enforce_secret::EnforceSecret;
use crate::parser::scalar_adapter::{ScalarAdapter, validate_pg_type_to_rw_type};
use crate::sink::log_store::{LogStoreReadItem, TruncateOffset};
use crate::sink::{Result, Sink, SinkParam, SinkWriterParam};

pub const POSTGRES_SINK: &str = "postgres";

#[serde_as]
#[derive(Clone, Debug, Deserialize)]
pub struct PostgresConfig {
    pub host: String,
    #[serde_as(as = "DisplayFromStr")]
    pub port: u16,
    pub user: String,
    pub password: String,
    pub database: String,
    pub table: String,
    #[serde(default = "default_schema")]
    pub schema: String,
    #[serde(default = "Default::default")]
    pub ssl_mode: SslMode,
    #[serde(rename = "ssl.root.cert")]
    pub ssl_root_cert: Option<String>,
    #[serde(default = "default_max_batch_rows")]
    #[serde_as(as = "DisplayFromStr")]
    pub max_batch_rows: usize,
    pub r#type: String,
    #[serde(default, rename = "tcp.keepalive.enable")]
    #[serde_as(as = "DisplayFromStr")]
    pub tcp_keepalive_enable: bool,

    #[serde(flatten)]
    pub tcp_keepalive: Option<TcpKeepaliveConfig>,
}

#[serde_as]
#[derive(Debug, Clone, Deserialize)]
pub struct TcpKeepaliveConfig {
    #[serde(rename = "tcp.keepalive.idle")]
    #[serde_as(as = "DisplayFromStr")]
    pub tcp_keepalive_idle: u32,
    #[serde(rename = "tcp.keepalive.interval")]
    #[serde_as(as = "DisplayFromStr")]
    pub tcp_keepalive_interval: u32,
    #[serde(rename = "tcp.keepalive.count")]
    #[serde_as(as = "DisplayFromStr")]
    pub tcp_keepalive_count: u32,
}

impl Default for TcpKeepaliveConfig {
    fn default() -> Self {
        Self {
            tcp_keepalive_idle: 10 * 60,
            tcp_keepalive_interval: 10,
            tcp_keepalive_count: 3,
        }
    }
}

impl EnforceSecret for PostgresConfig {
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf_set! {
        "password", "ssl.root.cert"
    };
}

fn default_max_batch_rows() -> usize {
    1024
}

fn default_schema() -> String {
    "public".to_owned()
}

impl PostgresConfig {
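    /// Parses a config from the flattened `WITH` options of a `CREATE SINK`
    /// statement. A minimal sketch (the values below are illustrative, not
    /// defaults):
    ///
    /// ```ignore
    /// let props = BTreeMap::from([
    ///     ("host".to_owned(), "localhost".to_owned()),
    ///     ("port".to_owned(), "5432".to_owned()),
    ///     ("user".to_owned(), "postgres".to_owned()),
    ///     ("password".to_owned(), "secret".to_owned()),
    ///     ("database".to_owned(), "db".to_owned()),
    ///     ("table".to_owned(), "t".to_owned()),
    ///     ("type".to_owned(), "upsert".to_owned()),
    /// ]);
    /// let config = PostgresConfig::from_btreemap(props)?;
    /// assert_eq!(config.schema, "public"); // `schema` defaults to "public"
    /// ```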
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<PostgresConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be `{}` or `{}`",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }
}

#[derive(Debug)]
pub struct PostgresSink {
    pub config: PostgresConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl PostgresSink {
    pub fn new(
        config: PostgresConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl EnforceSecret for PostgresSink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::error::ConnectorResult<()> {
        for prop in prop_iter {
            PostgresConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl TryFrom<SinkParam> for PostgresSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let pk_indices = param.downstream_pk_or_empty();
        let config = PostgresConfig::from_btreemap(param.properties)?;
        PostgresSink::new(config, schema, pk_indices, param.sink_type.is_append_only())
    }
}

impl Sink for PostgresSink {
    type LogSinker = PostgresSinkWriter;

    const SINK_NAME: &'static str = POSTGRES_SINK;

    async fn validate(&self) -> Result<()> {
        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert Postgres sink (please define in `primary_key` field)"
            )));
        }

        {
            // Check connectivity and that the target table exists.
            let pg_table = PostgresExternalTable::connect(
                &self.config.user,
                &self.config.password,
                &self.config.host,
                self.config.port,
                &self.config.database,
                &self.config.schema,
                &self.config.table,
                &self.config.ssl_mode,
                &self.config.ssl_root_cert,
                self.is_append_only,
            )
            .await
            .context(format!(
                "failed to connect to database: {}, schema: {}, table: {}",
                &self.config.database, &self.config.schema, &self.config.table
            ))?;

            {
                // Check that each sink column exists in the Postgres table with a compatible type.
                let pg_columns = pg_table.column_descs();
                let sink_columns = self.schema.fields();
                if pg_columns.len() < sink_columns.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Column count mismatch: Postgres table has {} columns, but sink schema has {} columns; the sink must have at most as many columns as the Postgres table",
                        pg_columns.len(),
                        sink_columns.len()
                    )));
                }

                let pg_columns_lookup = pg_columns
                    .iter()
                    .map(|c| (c.name.clone(), c.data_type.clone()))
                    .collect::<BTreeMap<_, _>>();
                for sink_column in sink_columns {
                    let pg_column = pg_columns_lookup.get(&sink_column.name);
                    match pg_column {
                        None => {
                            return Err(SinkError::Config(anyhow!(
                                "Column `{}` not found in Postgres table `{}`",
                                sink_column.name,
                                self.config.table
                            )));
                        }
                        Some(pg_column) => {
                            if !validate_pg_type_to_rw_type(pg_column, &sink_column.data_type()) {
                                return Err(SinkError::Config(anyhow!(
                                    "Column `{}` in Postgres table `{}` has type `{}`, but sink schema defines it as type `{}`",
                                    sink_column.name,
                                    self.config.table,
                                    pg_column,
                                    sink_column.data_type()
                                )));
                            }
                        }
                    }
                }
            }

            {
                // Check that the primary key columns match.
                let pg_pk_names = pg_table.pk_names();
                let sink_pk_names = self
                    .pk_indices
                    .iter()
                    .map(|i| &self.schema.fields()[*i].name)
                    .collect::<HashSet<_>>();
                if pg_pk_names.len() != sink_pk_names.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Primary key mismatch: Postgres table has primary key on columns {:?}, but sink schema defines primary key on columns {:?}",
                        pg_pk_names,
                        sink_pk_names
                    )));
                }
                for name in pg_pk_names {
                    if !sink_pk_names.contains(name) {
                        return Err(SinkError::Config(anyhow!(
                            "Primary key mismatch: Postgres table has primary key on column `{}`, but sink schema does not define it as a primary key",
                            name
                        )));
                    }
                }
            }
        }

        Ok(())
    }

    async fn new_log_sinker(&self, _writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        PostgresSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await
    }
}

pub struct PostgresSinkWriter {
    is_append_only: bool,
    client: tokio_postgres::Client,
    pk_indices: Vec<usize>,
    pk_types: Vec<PgType>,
    schema_types: Vec<PgType>,
    raw_insert_sql: Arc<String>,
    raw_upsert_sql: Arc<String>,
    raw_delete_sql: Arc<String>,
    insert_sql: Arc<tokio_postgres::Statement>,
    delete_sql: Arc<tokio_postgres::Statement>,
    upsert_sql: Arc<tokio_postgres::Statement>,
}

impl PostgresSinkWriter {
    async fn new(
        config: PostgresConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let tcp_keepalive = if config.tcp_keepalive_enable {
            config
                .tcp_keepalive
                .or_else(|| Some(TcpKeepaliveConfig::default()))
        } else {
            None
        };

        let client = create_pg_client(
            &config.user,
            &config.password,
            &config.host,
            &config.port.to_string(),
            &config.database,
            &config.ssl_mode,
            &config.ssl_root_cert,
            tcp_keepalive,
        )
        .await?;

        let pk_indices_lookup = pk_indices.iter().copied().collect::<HashSet<_>>();

        let (pk_types, schema_types) = {
            // Query the Postgres catalog for the actual type of each column.
            let name_to_type = PostgresExternalTable::type_mapping(
                &config.user,
                &config.password,
                &config.host,
                config.port,
                &config.database,
                &config.schema,
                &config.table,
                &config.ssl_mode,
                &config.ssl_root_cert,
                is_append_only,
            )
            .await?;
            let mut schema_types = Vec::with_capacity(schema.fields.len());
            let mut pk_types = Vec::with_capacity(pk_indices.len());
            for (i, field) in schema.fields.iter().enumerate() {
                let field_name = &field.name;
                let actual_data_type = name_to_type.get(field_name).cloned().ok_or_else(|| {
                    SinkError::Config(anyhow!(
                        "Column `{}` not found in Postgres table `{}`",
                        field_name,
                        config.table
                    ))
                })?;
                if pk_indices_lookup.contains(&i) {
                    pk_types.push(actual_data_type.clone());
                }
                schema_types.push(actual_data_type);
            }
            (pk_types, schema_types)
        };

        let raw_insert_sql = create_insert_sql(&schema, &config.schema, &config.table);
        let raw_upsert_sql = create_upsert_sql(
            &schema,
            &config.schema,
            &config.table,
            &pk_indices,
            &pk_indices_lookup,
        );
        let raw_delete_sql = create_delete_sql(&schema, &config.schema, &config.table, &pk_indices);

        let insert_sql = client
            .prepare(&raw_insert_sql)
            .await
            .with_context(|| format!("failed to prepare insert statement: {}", raw_insert_sql))?;
        let upsert_sql = client
            .prepare(&raw_upsert_sql)
            .await
            .with_context(|| format!("failed to prepare upsert statement: {}", raw_upsert_sql))?;
        let delete_sql = client
            .prepare(&raw_delete_sql)
            .await
            .with_context(|| format!("failed to prepare delete statement: {}", raw_delete_sql))?;

        let writer = Self {
            is_append_only,
            client,
            pk_indices,
            pk_types,
            schema_types,
            raw_insert_sql: Arc::new(raw_insert_sql),
            raw_upsert_sql: Arc::new(raw_upsert_sql),
            raw_delete_sql: Arc::new(raw_delete_sql),
            insert_sql: Arc::new(insert_sql),
            delete_sql: Arc::new(delete_sql),
            upsert_sql: Arc::new(upsert_sql),
        };
        Ok(writer)
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        if self.is_append_only {
            self.write_batch_append_only(chunk).await
        } else {
            self.write_batch_non_append_only(chunk).await
        }
    }

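    // Statements for a chunk are executed concurrently within one transaction:
    // each row's future holds an `Arc` clone of the transaction and runs on a
    // `FuturesUnordered`; the transaction is committed only after every future
    // has completed. If any statement fails, the early return drops the
    // transaction, which rolls it back.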
    async fn write_batch_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        let transaction = Arc::new(self.client.transaction().await?);
        let mut insert_futures = FuturesUnordered::new();
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    let pg_row = convert_row_to_pg_row(row, &self.schema_types);
                    let insert_sql = self.insert_sql.clone();
                    let raw_insert_sql = self.raw_insert_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(insert_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute insert statement: {}, parameters: {:?}",
                                    raw_insert_sql, pg_row
                                )
                            })
                    };
                    insert_futures.push(future);
                }
                _ => {
                    tracing::error!(
                        "row ignored: an append-only sink should not receive update-insert, update-delete, or delete operations"
                    );
                }
            }
        }

        while let Some(result) = insert_futures.next().await {
            result?;
        }
        if let Some(transaction) = Arc::into_inner(transaction) {
            transaction.commit().await?;
        } else {
            tracing::error!("failed to commit transaction: it is still referenced elsewhere");
        }

        Ok(())
    }

    async fn write_batch_non_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        let transaction = Arc::new(self.client.transaction().await?);
        let mut delete_futures = FuturesUnordered::new();
        let mut upsert_futures = FuturesUnordered::new();
        for (op, row) in chunk.rows() {
            match op {
                Op::Delete | Op::UpdateDelete => {
                    let pg_row =
                        convert_row_to_pg_row(row.project(&self.pk_indices), &self.pk_types);
                    let delete_sql = self.delete_sql.clone();
                    let raw_delete_sql = self.raw_delete_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(delete_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute delete statement: {}, parameters: {:?}",
                                    raw_delete_sql, pg_row
                                )
                            })
                    };
                    delete_futures.push(future);
                }
                Op::Insert | Op::UpdateInsert => {
                    let pg_row = convert_row_to_pg_row(row, &self.schema_types);
                    let upsert_sql = self.upsert_sql.clone();
                    let raw_upsert_sql = self.raw_upsert_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(upsert_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute upsert statement: {}, parameters: {:?}",
                                    raw_upsert_sql, pg_row
                                )
                            })
                    };
                    upsert_futures.push(future);
                }
            }
        }
        while let Some(result) = delete_futures.next().await {
            result?;
        }
        while let Some(result) = upsert_futures.next().await {
            result?;
        }
        if let Some(transaction) = Arc::into_inner(transaction) {
            transaction.commit().await?;
        } else {
            tracing::error!("failed to commit transaction: it is still referenced elsewhere");
        }
        Ok(())
    }
}

#[async_trait]
impl LogSinker for PostgresSinkWriter {
    async fn consume_log_and_sink(mut self, mut log_reader: impl SinkLogReader) -> Result<!> {
        log_reader.start_from(None).await?;
        loop {
            let (epoch, item) = log_reader.next_item().await?;
            match item {
                LogStoreReadItem::StreamChunk { chunk, chunk_id } => {
                    self.write_batch(chunk).await?;
                    log_reader.truncate(TruncateOffset::Chunk { epoch, chunk_id })?;
                }
                LogStoreReadItem::Barrier { .. } => {
                    log_reader.truncate(TruncateOffset::Barrier { epoch })?;
                }
            }
        }
    }
}

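/// Builds the parameterized `INSERT` statement, e.g. (as pinned by the unit
/// test below): `INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2)`.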
fn create_insert_sql(schema: &Schema, schema_name: &str, table_name: &str) -> String {
    let normalized_table_name = format!(
        "{}.{}",
        quote_identifier(schema_name),
        quote_identifier(table_name)
    );
    let number_of_columns = schema.len();
    let columns: String = schema
        .fields()
        .iter()
        .map(|field| quote_identifier(&field.name))
        .join(", ");
    let column_parameters: String = (0..number_of_columns)
        .map(|i| format!("${}", i + 1))
        .join(", ");
    format!("INSERT INTO {normalized_table_name} ({columns}) VALUES ({column_parameters})")
}

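/// Builds the parameterized `DELETE` statement keyed on the primary-key
/// columns, e.g. (as pinned by the unit tests below):
/// `DELETE FROM "test_schema"."test_table" WHERE ("b") in (($1))`.
/// With an empty `pk_indices`, every column is treated as part of the key.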
fn create_delete_sql(
    schema: &Schema,
    schema_name: &str,
    table_name: &str,
    pk_indices: &[usize],
) -> String {
    let normalized_table_name = format!(
        "{}.{}",
        quote_identifier(schema_name),
        quote_identifier(table_name)
    );
    let pk_indices = if pk_indices.is_empty() {
        (0..schema.len()).collect_vec()
    } else {
        pk_indices.to_vec()
    };
    let pk = {
        let pk_symbols = pk_indices
            .iter()
            .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
            .join(", ");
        format!("({})", pk_symbols)
    };
    let parameters: String = (0..pk_indices.len())
        .map(|i| format!("${}", i + 1))
        .join(", ");
    format!("DELETE FROM {normalized_table_name} WHERE {pk} in (({parameters}))")
}

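/// Builds the `INSERT ... on conflict ... do update` upsert statement, e.g.
/// (as pinned by the unit test below):
/// `INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2)
/// on conflict ("b") do update set "a" = EXCLUDED."a"`.
/// With an empty `pk_indices` this degenerates to a plain `INSERT`.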
fn create_upsert_sql(
    schema: &Schema,
    schema_name: &str,
    table_name: &str,
    pk_indices: &[usize],
    pk_indices_lookup: &HashSet<usize>,
) -> String {
    let insert_sql = create_insert_sql(schema, schema_name, table_name);
    if pk_indices.is_empty() {
        return insert_sql;
    }
    let pk_columns = pk_indices
        .iter()
        .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
        .collect_vec()
        .join(", ");
    let update_parameters: String = (0..schema.len())
        .filter(|i| !pk_indices_lookup.contains(i))
        .map(|i| {
            let column = quote_identifier(&schema.fields()[i].name);
            format!("{column} = EXCLUDED.{column}")
        })
        .collect_vec()
        .join(", ");
    format!("{insert_sql} on conflict ({pk_columns}) do update set {update_parameters}")
}

fn quote_identifier(identifier: &str) -> String {
    // Escape embedded double quotes by doubling them, per SQL identifier quoting rules.
    format!("\"{}\"", identifier.replace("\"", "\"\""))
}

type PgDatum = Option<ScalarAdapter>;
type PgRow = Vec<PgDatum>;

fn convert_row_to_pg_row(row: impl Row, schema_types: &[PgType]) -> PgRow {
    let mut buffer = Vec::with_capacity(row.len());
    for (i, datum_ref) in row.iter().enumerate() {
        let pg_datum = datum_ref.and_then(|s| {
            match ScalarAdapter::from_scalar(s, &schema_types[i]) {
                Ok(scalar) => Some(scalar),
                Err(e) => {
                    tracing::error!(error=%e.as_report(), scalar=?s, "Failed to convert scalar to pg value");
                    None
                }
            }
        });
        buffer.push(pg_datum);
    }
    buffer
}

#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use expect_test::{Expect, expect};
    use risingwave_common::catalog::Field;
    use risingwave_common::types::DataType;

    use super::*;

    fn check(actual: impl Display, expect: Expect) {
        let actual = actual.to_string();
        expect.assert_eq(&actual);
    }

    #[test]
    fn test_create_insert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let sql = create_insert_sql(&schema, schema_name, table_name);
        check(
            sql,
            expect![[r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2)"#]],
        );
    }

    #[test]
    fn test_create_delete_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let sql = create_delete_sql(&schema, schema_name, table_name, &[1]);
        check(
            sql,
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("b") in (($1))"#]],
        );
        let sql = create_delete_sql(&schema, schema_name, table_name, &[0, 1]);
        check(
            sql,
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("a", "b") in (($1, $2))"#]],
        );
    }
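
    // With an empty `pk_indices`, `create_delete_sql` falls back to keying on
    // every column.
    #[test]
    fn test_create_delete_sql_without_pk() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let sql = create_delete_sql(&schema, "test_schema", "test_table", &[]);
        check(
            sql,
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("a", "b") in (($1, $2))"#]],
        );
    }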

    #[test]
    fn test_create_upsert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let pk_indices_lookup = HashSet::from_iter([1]);
        let sql = create_upsert_sql(&schema, schema_name, table_name, &[1], &pk_indices_lookup);
        check(
            sql,
            expect![[
                r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2) on conflict ("b") do update set "a" = EXCLUDED."a""#
            ]],
        );
    }
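
    // With an empty `pk_indices`, `create_upsert_sql` degenerates to a plain
    // insert, since there is no conflict target to upsert on.
    #[test]
    fn test_create_upsert_sql_without_pk() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let sql = create_upsert_sql(&schema, "test_schema", "test_table", &[], &HashSet::new());
        check(
            sql,
            expect![[r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2)"#]],
        );
    }

    // `quote_identifier` escapes embedded double quotes by doubling them.
    #[test]
    fn test_quote_identifier() {
        check(quote_identifier("plain"), expect![[r#""plain""#]]);
        check(quote_identifier(r#"we"ird"#), expect![[r#""we""ird""#]]);
    }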
}