use std::collections::{BTreeMap, HashSet};
use std::sync::Arc;

use anyhow::{Context, anyhow};
use async_trait::async_trait;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use itertools::Itertools;
use phf::phf_set;
use risingwave_common::array::{Op, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::row::{Row, RowExt};
use serde_derive::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use simd_json::prelude::ArrayTrait;
use thiserror_ext::AsReport;
use tokio_postgres::types::Type as PgType;

use super::{
    LogSinker, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkLogReader,
};
use crate::connector_common::{PostgresExternalTable, SslMode, create_pg_client};
use crate::enforce_secret::EnforceSecret;
use crate::parser::scalar_adapter::{ScalarAdapter, validate_pg_type_to_rw_type};
use crate::sink::log_store::{LogStoreReadItem, TruncateOffset};
use crate::sink::{Result, Sink, SinkParam, SinkWriterParam};

pub const POSTGRES_SINK: &str = "postgres";

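/// Connection and behavior options for the Postgres sink, deserialized from the
/// user-provided sink properties (host, credentials, target schema/table, SSL
/// settings, batching, and the sink `type`).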
#[serde_as]
#[derive(Clone, Debug, Deserialize)]
pub struct PostgresConfig {
    pub host: String,
    #[serde_as(as = "DisplayFromStr")]
    pub port: u16,
    pub user: String,
    pub password: String,
    pub database: String,
    pub table: String,
    #[serde(default = "default_schema")]
    pub schema: String,
    #[serde(default = "Default::default")]
    pub ssl_mode: SslMode,
    #[serde(rename = "ssl.root.cert")]
    pub ssl_root_cert: Option<String>,
    #[serde(default = "default_max_batch_rows")]
    #[serde_as(as = "DisplayFromStr")]
    pub max_batch_rows: usize,
    pub r#type: String,
}

impl EnforceSecret for PostgresConfig {
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf_set! {
        "password", "ssl.root.cert"
    };
}

fn default_max_batch_rows() -> usize {
    1024
}

fn default_schema() -> String {
    "public".to_owned()
}

impl PostgresConfig {
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<PostgresConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }
}

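/// The Postgres sink: validates the downstream table against the sink schema
/// and creates [`PostgresSinkWriter`]s to write stream chunks into it.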
#[derive(Debug)]
pub struct PostgresSink {
    pub config: PostgresConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl PostgresSink {
    pub fn new(
        config: PostgresConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl EnforceSecret for PostgresSink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::error::ConnectorResult<()> {
        for prop in prop_iter {
            PostgresConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl TryFrom<SinkParam> for PostgresSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = PostgresConfig::from_btreemap(param.properties)?;
        PostgresSink::new(
            config,
            schema,
            param.downstream_pk,
            param.sink_type.is_append_only(),
        )
    }
}

impl Sink for PostgresSink {
    type LogSinker = PostgresSinkWriter;

    const SINK_NAME: &'static str = POSTGRES_SINK;

    async fn validate(&self) -> Result<()> {
        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert Postgres sink (please define in `primary_key` field)"
            )));
        }

        {
            // Verify that the sink schema is compatible with the downstream Postgres table.
            let pg_table = PostgresExternalTable::connect(
                &self.config.user,
                &self.config.password,
                &self.config.host,
                self.config.port,
                &self.config.database,
                &self.config.schema,
                &self.config.table,
                &self.config.ssl_mode,
                &self.config.ssl_root_cert,
                self.is_append_only,
            )
            .await
            .context(format!(
                "failed to connect to database: {}, schema: {}, table: {}",
                &self.config.database, &self.config.schema, &self.config.table
            ))?;

            {
                // Check that every sink column exists in the Postgres table with a compatible type.
                let pg_columns = pg_table.column_descs();
                let sink_columns = self.schema.fields();
                if pg_columns.len() < sink_columns.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Column count mismatch: Postgres table has {} columns, but the sink schema has {} columns; the sink must not have more columns than the Postgres table",
                        pg_columns.len(),
                        sink_columns.len()
                    )));
                }

                let pg_columns_lookup = pg_columns
                    .iter()
                    .map(|c| (c.name.clone(), c.data_type.clone()))
                    .collect::<BTreeMap<_, _>>();
                for sink_column in sink_columns {
                    let pg_column = pg_columns_lookup.get(&sink_column.name);
                    match pg_column {
                        None => {
                            return Err(SinkError::Config(anyhow!(
                                "Column `{}` not found in Postgres table `{}`",
                                sink_column.name,
                                self.config.table
                            )));
                        }
                        Some(pg_column) => {
                            if !validate_pg_type_to_rw_type(pg_column, &sink_column.data_type()) {
                                return Err(SinkError::Config(anyhow!(
                                    "Column `{}` in Postgres table `{}` has type `{}`, but sink schema defines it as type `{}`",
                                    sink_column.name,
                                    self.config.table,
                                    pg_column,
                                    sink_column.data_type()
                                )));
                            }
                        }
                    }
                }
            }

            {
                // Check that the primary key columns of the Postgres table and the sink match exactly.
                let pg_pk_names = pg_table.pk_names();
                let sink_pk_names = self
                    .pk_indices
                    .iter()
                    .map(|i| &self.schema.fields()[*i].name)
                    .collect::<HashSet<_>>();
                if pg_pk_names.len() != sink_pk_names.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Primary key mismatch: Postgres table has primary key on columns {:?}, but sink schema defines primary key on columns {:?}",
                        pg_pk_names,
                        sink_pk_names
                    )));
                }
                for name in pg_pk_names {
                    if !sink_pk_names.contains(name) {
                        return Err(SinkError::Config(anyhow!(
                            "Primary key mismatch: Postgres table has primary key on column `{}`, but sink schema does not define it as a primary key",
                            name
                        )));
                    }
                }
            }
        }

        Ok(())
    }

    async fn new_log_sinker(&self, _writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        PostgresSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await
    }
}

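/// Writer that flushes stream chunks into Postgres using prepared
/// insert/upsert/delete statements, one transaction per chunk.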
pub struct PostgresSinkWriter {
    is_append_only: bool,
    client: tokio_postgres::Client,
    pk_indices: Vec<usize>,
    pk_types: Vec<PgType>,
    schema_types: Vec<PgType>,
    raw_insert_sql: Arc<String>,
    raw_upsert_sql: Arc<String>,
    raw_delete_sql: Arc<String>,
    insert_sql: Arc<tokio_postgres::Statement>,
    delete_sql: Arc<tokio_postgres::Statement>,
    upsert_sql: Arc<tokio_postgres::Statement>,
}

impl PostgresSinkWriter {
    /// Connects to Postgres, resolves the column types of the downstream table,
    /// and prepares the insert/upsert/delete statements.
    async fn new(
        config: PostgresConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let client = create_pg_client(
            &config.user,
            &config.password,
            &config.host,
            &config.port.to_string(),
            &config.database,
            &config.ssl_mode,
            &config.ssl_root_cert,
        )
        .await?;

        let pk_indices_lookup = pk_indices.iter().copied().collect::<HashSet<_>>();

        // Resolve the Postgres type of every sink column, and of the primary key
        // columns in particular, from the downstream table's catalog.
        let (pk_types, schema_types) = {
            let name_to_type = PostgresExternalTable::type_mapping(
                &config.user,
                &config.password,
                &config.host,
                config.port,
                &config.database,
                &config.schema,
                &config.table,
                &config.ssl_mode,
                &config.ssl_root_cert,
                is_append_only,
            )
            .await?;
            let mut schema_types = Vec::with_capacity(schema.fields.len());
            let mut pk_types = Vec::with_capacity(pk_indices.len());
            for (i, field) in schema.fields.iter().enumerate() {
                let field_name = &field.name;
                let actual_data_type =
                    name_to_type.get(field_name).cloned().ok_or_else(|| {
                        SinkError::Config(anyhow!(
                            "Column `{}` not found in the downstream Postgres table",
                            field_name
                        ))
                    })?;
                if pk_indices_lookup.contains(&i) {
                    pk_types.push(actual_data_type.clone())
                }
                schema_types.push(actual_data_type);
            }
            (pk_types, schema_types)
        };

        // Prepare the statements once, so per-row writes only bind parameters.
        let raw_insert_sql = create_insert_sql(&schema, &config.schema, &config.table);
        let raw_upsert_sql = create_upsert_sql(
            &schema,
            &config.schema,
            &config.table,
            &pk_indices,
            &pk_indices_lookup,
        );
        let raw_delete_sql = create_delete_sql(&schema, &config.schema, &config.table, &pk_indices);

        let insert_sql = client
            .prepare(&raw_insert_sql)
            .await
            .with_context(|| format!("failed to prepare insert statement: {}", raw_insert_sql))?;
        let upsert_sql = client
            .prepare(&raw_upsert_sql)
            .await
            .with_context(|| format!("failed to prepare upsert statement: {}", raw_upsert_sql))?;
        let delete_sql = client
            .prepare(&raw_delete_sql)
            .await
            .with_context(|| format!("failed to prepare delete statement: {}", raw_delete_sql))?;

        let writer = Self {
            is_append_only,
            client,
            pk_indices,
            pk_types,
            schema_types,
            raw_insert_sql: Arc::new(raw_insert_sql),
            raw_upsert_sql: Arc::new(raw_upsert_sql),
            raw_delete_sql: Arc::new(raw_delete_sql),
            insert_sql: Arc::new(insert_sql),
            delete_sql: Arc::new(delete_sql),
            upsert_sql: Arc::new(upsert_sql),
        };
        Ok(writer)
    }

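    /// Routes a chunk to the append-only or the upsert write path.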
    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        if self.is_append_only {
            self.write_batch_append_only(chunk).await
        } else {
            self.write_batch_non_append_only(chunk).await
        }
    }

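    /// Inserts all rows of the chunk inside a single transaction, executing the
    /// prepared insert statement concurrently for each row.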
    async fn write_batch_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        let transaction = Arc::new(self.client.transaction().await?);
        let mut insert_futures = FuturesUnordered::new();
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    let pg_row = convert_row_to_pg_row(row, &self.schema_types);
                    let insert_sql = self.insert_sql.clone();
                    let raw_insert_sql = self.raw_insert_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(insert_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute insert statement: {}, parameters: {:?}",
                                    raw_insert_sql, pg_row
                                )
                            })
                    };
                    insert_futures.push(future);
                }
                _ => {
                    tracing::error!(
                        "row ignored, append-only sink should not receive update insert, update delete or delete operations"
                    );
                }
            }
        }

        while let Some(result) = insert_futures.next().await {
            result?;
        }
        // Every row future held a clone of the `Arc`; once they are all drained we
        // can reclaim sole ownership of the transaction and commit it.
        if let Some(transaction) = Arc::into_inner(transaction) {
            transaction.commit().await?;
        } else {
            tracing::error!("transaction lost!");
        }

        Ok(())
    }

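    /// Applies an upsert chunk inside a single transaction: deletes keyed on the
    /// primary key are executed first, then `Insert`/`UpdateInsert` rows are
    /// written with the prepared upsert statement.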
    async fn write_batch_non_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        let transaction = Arc::new(self.client.transaction().await?);
        let mut delete_futures = FuturesUnordered::new();
        let mut upsert_futures = FuturesUnordered::new();
        for (op, row) in chunk.rows() {
            match op {
                Op::Delete | Op::UpdateDelete => {
                    let pg_row =
                        convert_row_to_pg_row(row.project(&self.pk_indices), &self.pk_types);
                    let delete_sql = self.delete_sql.clone();
                    let raw_delete_sql = self.raw_delete_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(delete_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute delete statement: {}, parameters: {:?}",
                                    raw_delete_sql, pg_row
                                )
                            })
                    };
                    delete_futures.push(future);
                }
                Op::Insert | Op::UpdateInsert => {
                    let pg_row = convert_row_to_pg_row(row, &self.schema_types);
                    let upsert_sql = self.upsert_sql.clone();
                    let raw_upsert_sql = self.raw_upsert_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(upsert_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute upsert statement: {}, parameters: {:?}",
                                    raw_upsert_sql, pg_row
                                )
                            })
                    };
                    upsert_futures.push(future);
                }
            }
        }
        while let Some(result) = delete_futures.next().await {
            result?;
        }
        while let Some(result) = upsert_futures.next().await {
            result?;
        }
        if let Some(transaction) = Arc::into_inner(transaction) {
            transaction.commit().await?;
        } else {
            tracing::error!("transaction lost!");
        }
        Ok(())
    }
}

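/// Consumes the log indefinitely: each chunk is written and its offset
/// truncated; barriers only advance the truncate offset.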
#[async_trait]
impl LogSinker for PostgresSinkWriter {
    async fn consume_log_and_sink(mut self, mut log_reader: impl SinkLogReader) -> Result<!> {
        log_reader.start_from(None).await?;
        loop {
            let (epoch, item) = log_reader.next_item().await?;
            match item {
                LogStoreReadItem::StreamChunk { chunk, chunk_id } => {
                    self.write_batch(chunk).await?;
                    log_reader.truncate(TruncateOffset::Chunk { epoch, chunk_id })?;
                }
                LogStoreReadItem::Barrier { .. } => {
                    log_reader.truncate(TruncateOffset::Barrier { epoch })?;
                }
            }
        }
    }
}

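/// Builds the parameterized insert statement, e.g.
/// `INSERT INTO "s"."t" ("a", "b") VALUES ($1, $2)`.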
fn create_insert_sql(schema: &Schema, schema_name: &str, table_name: &str) -> String {
    let normalized_table_name = format!(
        "{}.{}",
        quote_identifier(schema_name),
        quote_identifier(table_name)
    );
    let number_of_columns = schema.len();
    let columns: String = schema
        .fields()
        .iter()
        .map(|field| quote_identifier(&field.name))
        .join(", ");
    let column_parameters: String = (0..number_of_columns)
        .map(|i| format!("${}", i + 1))
        .join(", ");
    format!("INSERT INTO {normalized_table_name} ({columns}) VALUES ({column_parameters})")
}

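/// Builds the parameterized delete statement keyed on the primary key columns,
/// e.g. `DELETE FROM "s"."t" WHERE ("a", "b") in (($1, $2))`. If no primary key
/// is defined, all columns are used as the key.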
fn create_delete_sql(
    schema: &Schema,
    schema_name: &str,
    table_name: &str,
    pk_indices: &[usize],
) -> String {
    let normalized_table_name = format!(
        "{}.{}",
        quote_identifier(schema_name),
        quote_identifier(table_name)
    );
    let pk_indices = if pk_indices.is_empty() {
        (0..schema.len()).collect_vec()
    } else {
        pk_indices.to_vec()
    };
    let pk = {
        let pk_symbols = pk_indices
            .iter()
            .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
            .join(", ");
        format!("({})", pk_symbols)
    };
    let parameters: String = (0..pk_indices.len())
        .map(|i| format!("${}", i + 1))
        .join(", ");
    format!("DELETE FROM {normalized_table_name} WHERE {pk} in (({parameters}))")
}

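/// Builds the upsert statement: the insert statement extended with an
/// `on conflict (<pk columns>) do update set ...` clause over all non-key
/// columns. Falls back to a plain insert when no primary key is defined.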
fn create_upsert_sql(
    schema: &Schema,
    schema_name: &str,
    table_name: &str,
    pk_indices: &[usize],
    pk_indices_lookup: &HashSet<usize>,
) -> String {
    let insert_sql = create_insert_sql(schema, schema_name, table_name);
    if pk_indices.is_empty() {
        return insert_sql;
    }
    let pk_columns = pk_indices
        .iter()
        .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
        .collect_vec()
        .join(", ");
    let update_parameters: String = (0..schema.len())
        .filter(|i| !pk_indices_lookup.contains(i))
        .map(|i| {
            let column = quote_identifier(&schema.fields()[i].name);
            format!("{column} = EXCLUDED.{column}")
        })
        .collect_vec()
        .join(", ");
    format!("{insert_sql} on conflict ({pk_columns}) do update set {update_parameters}")
}

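/// Quotes a Postgres identifier, doubling any embedded double quotes so that
/// arbitrary column/table names are safe to interpolate into SQL.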
fn quote_identifier(identifier: &str) -> String {
    format!("\"{}\"", identifier.replace("\"", "\"\""))
}

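/// A single Postgres bind parameter; `None` is bound as SQL `NULL`.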
type PgDatum = Option<ScalarAdapter>;
type PgRow = Vec<PgDatum>;

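/// Converts a RisingWave row into Postgres bind parameters using the resolved
/// column types. Values that fail to convert are logged and written as `NULL`.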
fn convert_row_to_pg_row(row: impl Row, schema_types: &[PgType]) -> PgRow {
    let mut buffer = Vec::with_capacity(row.len());
    for (i, datum_ref) in row.iter().enumerate() {
        let pg_datum = datum_ref.map(|s| {
            match ScalarAdapter::from_scalar(s, &schema_types[i]) {
                Ok(scalar) => Some(scalar),
                Err(e) => {
                    tracing::error!(error=%e.as_report(), scalar=?s, "Failed to convert scalar to pg value");
                    None
                }
            }
        });
        buffer.push(pg_datum.flatten());
    }
    buffer
}

#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use expect_test::{Expect, expect};
    use risingwave_common::catalog::Field;
    use risingwave_common::types::DataType;

    use super::*;

    fn check(actual: impl Display, expect: Expect) {
        let actual = actual.to_string();
        expect.assert_eq(&actual);
    }

    #[test]
    fn test_create_insert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let sql = create_insert_sql(&schema, schema_name, table_name);
        check(
            sql,
            expect![[r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2)"#]],
        );
    }

    #[test]
    fn test_create_delete_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let sql = create_delete_sql(&schema, schema_name, table_name, &[1]);
        check(
            sql,
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("b") in (($1))"#]],
        );
        let table_name = "test_table";
        let sql = create_delete_sql(&schema, schema_name, table_name, &[0, 1]);
        check(
            sql,
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("a", "b") in (($1, $2))"#]],
        );
    }

    #[test]
    fn test_create_upsert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let pk_indices_lookup = HashSet::from_iter([1]);
        let sql = create_upsert_sql(&schema, schema_name, table_name, &[1], &pk_indices_lookup);
        check(
            sql,
            expect![[
                r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2) on conflict ("b") do update set "a" = EXCLUDED."a""#
            ]],
        );
    }
}
686}