use std::collections::{BTreeMap, HashSet};
use std::sync::Arc;

use anyhow::{Context, anyhow};
use async_trait::async_trait;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use itertools::Itertools;
use phf::phf_set;
use risingwave_common::array::{Op, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::row::{Row, RowExt};
use serde_derive::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use simd_json::prelude::ArrayTrait;
use thiserror_ext::AsReport;
use tokio_postgres::types::Type as PgType;

use super::{
    LogSinker, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkLogReader,
};
use crate::connector_common::{PostgresExternalTable, SslMode, create_pg_client};
use crate::enforce_secret::EnforceSecret;
use crate::parser::scalar_adapter::{ScalarAdapter, validate_pg_type_to_rw_type};
use crate::sink::log_store::{LogStoreReadItem, TruncateOffset};
use crate::sink::{DummySinkCommitCoordinator, Result, Sink, SinkParam, SinkWriterParam};

pub const POSTGRES_SINK: &str = "postgres";

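/// Configuration for the Postgres sink, deserialized from the user-supplied sink properties.
/// `schema` defaults to `public`, `max_batch_rows` to 1024, and `type` must be either
/// `append-only` or `upsert`.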
#[serde_as]
#[derive(Clone, Debug, Deserialize)]
pub struct PostgresConfig {
    pub host: String,
    #[serde_as(as = "DisplayFromStr")]
    pub port: u16,
    pub user: String,
    pub password: String,
    pub database: String,
    pub table: String,
    #[serde(default = "default_schema")]
    pub schema: String,
    #[serde(default = "Default::default")]
    pub ssl_mode: SslMode,
    #[serde(rename = "ssl.root.cert")]
    pub ssl_root_cert: Option<String>,
    #[serde(default = "default_max_batch_rows")]
    #[serde_as(as = "DisplayFromStr")]
    pub max_batch_rows: usize,
    pub r#type: String,
}

impl EnforceSecret for PostgresConfig {
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf_set! {
        "password", "ssl.root.cert"
    };
}

fn default_max_batch_rows() -> usize {
    1024
}

fn default_schema() -> String {
    "public".to_owned()
}

impl PostgresConfig {
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<PostgresConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {}, or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }
}

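/// Postgres sink: validates the downstream table against the sink schema and creates
/// [`PostgresSinkWriter`]s to do the actual writing.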
#[derive(Debug)]
pub struct PostgresSink {
    pub config: PostgresConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl PostgresSink {
    pub fn new(
        config: PostgresConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl EnforceSecret for PostgresSink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::error::ConnectorResult<()> {
        for prop in prop_iter {
            PostgresConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl TryFrom<SinkParam> for PostgresSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = PostgresConfig::from_btreemap(param.properties)?;
        PostgresSink::new(
            config,
            schema,
            param.downstream_pk,
            param.sink_type.is_append_only(),
        )
    }
}

impl Sink for PostgresSink {
    type Coordinator = DummySinkCommitCoordinator;
    type LogSinker = PostgresSinkWriter;

    const SINK_NAME: &'static str = POSTGRES_SINK;

    async fn validate(&self) -> Result<()> {
        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert Postgres sink (please define in `primary_key` field)"
            )));
        }

        {
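            // Connect to the downstream table and verify that the sink schema and primary key
            // are compatible with it.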
            let pg_table = PostgresExternalTable::connect(
                &self.config.user,
                &self.config.password,
                &self.config.host,
                self.config.port,
                &self.config.database,
                &self.config.schema,
                &self.config.table,
                &self.config.ssl_mode,
                &self.config.ssl_root_cert,
                self.is_append_only,
            )
            .await
            .context(format!(
                "failed to connect to database: {}, schema: {}, table: {}",
                &self.config.database, &self.config.schema, &self.config.table
            ))?;

            {
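                // Every sink column must exist in the Postgres table with a compatible type.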
                let pg_columns = pg_table.column_descs();
                let sink_columns = self.schema.fields();
                if pg_columns.len() < sink_columns.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Column count mismatch: Postgres table has {} columns, but sink schema has {} columns; the sink must not have more columns than the Postgres table",
                        pg_columns.len(),
                        sink_columns.len()
                    )));
                }

                let pg_columns_lookup = pg_columns
                    .iter()
                    .map(|c| (c.name.clone(), c.data_type.clone()))
                    .collect::<BTreeMap<_, _>>();
                for sink_column in sink_columns {
                    let pg_column = pg_columns_lookup.get(&sink_column.name);
                    match pg_column {
                        None => {
                            return Err(SinkError::Config(anyhow!(
                                "Column `{}` not found in Postgres table `{}`",
                                sink_column.name,
                                self.config.table
                            )));
                        }
                        Some(pg_column) => {
                            if !validate_pg_type_to_rw_type(pg_column, &sink_column.data_type()) {
                                return Err(SinkError::Config(anyhow!(
                                    "Column `{}` in Postgres table `{}` has type `{}`, but sink schema defines it as type `{}`",
                                    sink_column.name,
                                    self.config.table,
                                    pg_column,
                                    sink_column.data_type()
                                )));
                            }
                        }
                    }
                }
            }

            {
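                // The sink's primary key must match the Postgres table's primary key exactly.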
                let pg_pk_names = pg_table.pk_names();
                let sink_pk_names = self
                    .pk_indices
                    .iter()
                    .map(|i| &self.schema.fields()[*i].name)
                    .collect::<HashSet<_>>();
                if pg_pk_names.len() != sink_pk_names.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Primary key mismatch: Postgres table has primary key on columns {:?}, but sink schema defines primary key on columns {:?}",
                        pg_pk_names,
                        sink_pk_names
                    )));
                }
                for name in pg_pk_names {
                    if !sink_pk_names.contains(name) {
                        return Err(SinkError::Config(anyhow!(
                            "Primary key mismatch: Postgres table has primary key on column `{}`, but sink schema does not define it as a primary key",
                            name
                        )));
                    }
                }
            }
        }

        Ok(())
    }

    async fn new_log_sinker(&self, _writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        PostgresSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await
    }
}

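/// Writer that owns a Postgres connection plus prepared insert, upsert and delete statements,
/// and writes each chunk inside its own transaction.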
pub struct PostgresSinkWriter {
    is_append_only: bool,
    client: tokio_postgres::Client,
    pk_indices: Vec<usize>,
    pk_types: Vec<PgType>,
    schema_types: Vec<PgType>,
    raw_insert_sql: Arc<String>,
    raw_upsert_sql: Arc<String>,
    raw_delete_sql: Arc<String>,
    insert_sql: Arc<tokio_postgres::Statement>,
    delete_sql: Arc<tokio_postgres::Statement>,
    upsert_sql: Arc<tokio_postgres::Statement>,
}

impl PostgresSinkWriter {
    async fn new(
        config: PostgresConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let client = create_pg_client(
            &config.user,
            &config.password,
            &config.host,
            &config.port.to_string(),
            &config.database,
            &config.ssl_mode,
            &config.ssl_root_cert,
        )
        .await?;

        let pk_indices_lookup = pk_indices.iter().copied().collect::<HashSet<_>>();

        let (pk_types, schema_types) = {
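            // Look up the actual Postgres column types so that values are encoded with the
            // downstream types when binding parameters.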
            let name_to_type = PostgresExternalTable::type_mapping(
                &config.user,
                &config.password,
                &config.host,
                config.port,
                &config.database,
                &config.schema,
                &config.table,
                &config.ssl_mode,
                &config.ssl_root_cert,
                is_append_only,
            )
            .await?;
            let mut schema_types = Vec::with_capacity(schema.fields.len());
            let mut pk_types = Vec::with_capacity(pk_indices.len());
            for (i, field) in schema.fields.iter().enumerate() {
                let field_name = &field.name;
                let actual_data_type = name_to_type.get(field_name).cloned().ok_or_else(|| {
                    SinkError::Config(anyhow!(
                        "Column `{}` not found in Postgres table `{}`",
                        field_name,
                        config.table
                    ))
                })?;
                if pk_indices_lookup.contains(&i) {
                    pk_types.push(actual_data_type.clone())
                }
                schema_types.push(actual_data_type);
            }
            (pk_types, schema_types)
        };

        let raw_insert_sql = create_insert_sql(&schema, &config.schema, &config.table);
        let raw_upsert_sql = create_upsert_sql(
            &schema,
            &config.schema,
            &config.table,
            &pk_indices,
            &pk_indices_lookup,
        );
        let raw_delete_sql = create_delete_sql(&schema, &config.schema, &config.table, &pk_indices);

        let insert_sql = client
            .prepare(&raw_insert_sql)
            .await
            .with_context(|| format!("failed to prepare insert statement: {}", raw_insert_sql))?;
        let upsert_sql = client
            .prepare(&raw_upsert_sql)
            .await
            .with_context(|| format!("failed to prepare upsert statement: {}", raw_upsert_sql))?;
        let delete_sql = client
            .prepare(&raw_delete_sql)
            .await
            .with_context(|| format!("failed to prepare delete statement: {}", raw_delete_sql))?;

        let writer = Self {
            is_append_only,
            client,
            pk_indices,
            pk_types,
            schema_types,
            raw_insert_sql: Arc::new(raw_insert_sql),
            raw_upsert_sql: Arc::new(raw_upsert_sql),
            raw_delete_sql: Arc::new(raw_delete_sql),
            insert_sql: Arc::new(insert_sql),
            delete_sql: Arc::new(delete_sql),
            upsert_sql: Arc::new(upsert_sql),
        };
        Ok(writer)
    }

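    /// Writes one chunk, dispatching to the append-only or the upsert path.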
    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        if self.is_append_only {
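            // Append-only sinks only ever issue plain INSERTs.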
            self.write_batch_append_only(chunk).await
        } else {
            self.write_batch_non_append_only(chunk).await
        }
    }

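    /// Append-only path: every `Insert` row becomes an `INSERT` on the prepared statement.
    /// All inserts of a chunk are driven through a `FuturesUnordered` inside a single
    /// transaction, which is committed once every statement has succeeded.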
    async fn write_batch_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        let transaction = Arc::new(self.client.transaction().await?);
        let mut insert_futures = FuturesUnordered::new();
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    let pg_row = convert_row_to_pg_row(row, &self.schema_types);
                    let insert_sql = self.insert_sql.clone();
                    let raw_insert_sql = self.raw_insert_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(insert_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute insert statement: {}, parameters: {:?}",
                                    raw_insert_sql, pg_row
                                )
                            })
                    };
                    insert_futures.push(future);
                }
                _ => {
                    tracing::error!(
                        "row ignored: append-only sink should not receive update insert, update delete, or delete operations"
                    );
                }
            }
        }

        while let Some(result) = insert_futures.next().await {
            result?;
        }
        if let Some(transaction) = Arc::into_inner(transaction) {
            transaction.commit().await?;
        } else {
            tracing::error!("transaction lost!");
        }

        Ok(())
    }

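    /// Upsert path: `Delete`/`UpdateDelete` rows run against the prepared delete statement and
    /// `Insert`/`UpdateInsert` rows against the prepared upsert statement. All deletes are driven
    /// to completion before the upserts, and the whole chunk is committed in one transaction.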
    async fn write_batch_non_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        let transaction = Arc::new(self.client.transaction().await?);
        let mut delete_futures = FuturesUnordered::new();
        let mut upsert_futures = FuturesUnordered::new();
        for (op, row) in chunk.rows() {
            match op {
                Op::Delete | Op::UpdateDelete => {
                    let pg_row =
                        convert_row_to_pg_row(row.project(&self.pk_indices), &self.pk_types);
                    let delete_sql = self.delete_sql.clone();
                    let raw_delete_sql = self.raw_delete_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(delete_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute delete statement: {}, parameters: {:?}",
                                    raw_delete_sql, pg_row
                                )
                            })
                    };
                    delete_futures.push(future);
                }
                Op::Insert | Op::UpdateInsert => {
                    let pg_row = convert_row_to_pg_row(row, &self.schema_types);
                    let upsert_sql = self.upsert_sql.clone();
                    let raw_upsert_sql = self.raw_upsert_sql.clone();
                    let transaction = transaction.clone();
                    let future = async move {
                        transaction
                            .execute_raw(upsert_sql.as_ref(), &pg_row)
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to execute upsert statement: {}, parameters: {:?}",
                                    raw_upsert_sql, pg_row
                                )
                            })
                    };
                    upsert_futures.push(future);
                }
            }
        }
        while let Some(result) = delete_futures.next().await {
            result?;
        }
        while let Some(result) = upsert_futures.next().await {
            result?;
        }
        if let Some(transaction) = Arc::into_inner(transaction) {
            transaction.commit().await?;
        } else {
            tracing::error!("transaction lost!");
        }
        Ok(())
    }
}

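// Main sink loop: consume chunks from the log reader, write each one, and truncate the log on
// chunk and barrier boundaries. The loop only exits by returning an error.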
#[async_trait]
impl LogSinker for PostgresSinkWriter {
    async fn consume_log_and_sink(mut self, mut log_reader: impl SinkLogReader) -> Result<!> {
        log_reader.start_from(None).await?;
        loop {
            let (epoch, item) = log_reader.next_item().await?;
            match item {
                LogStoreReadItem::StreamChunk { chunk, chunk_id } => {
                    self.write_batch(chunk).await?;
                    log_reader.truncate(TruncateOffset::Chunk { epoch, chunk_id })?;
                }
                LogStoreReadItem::Barrier { .. } => {
                    log_reader.truncate(TruncateOffset::Barrier { epoch })?;
                }
            }
        }
    }
}

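/// Builds the insert statement, e.g.
/// `INSERT INTO "schema"."table" ("a", "b") VALUES ($1, $2)`.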
fn create_insert_sql(schema: &Schema, schema_name: &str, table_name: &str) -> String {
    let normalized_table_name = format!(
        "{}.{}",
        quote_identifier(schema_name),
        quote_identifier(table_name)
    );
    let number_of_columns = schema.len();
    let columns: String = schema
        .fields()
        .iter()
        .map(|field| quote_identifier(&field.name))
        .join(", ");
    let column_parameters: String = (0..number_of_columns)
        .map(|i| format!("${}", i + 1))
        .join(", ");
    format!("INSERT INTO {normalized_table_name} ({columns}) VALUES ({column_parameters})")
}

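/// Builds the delete statement keyed on the primary-key columns, e.g.
/// `DELETE FROM "schema"."table" WHERE ("a", "b") in (($1, $2))`.
/// When no primary key is given, all columns are used as the key.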
fn create_delete_sql(
    schema: &Schema,
    schema_name: &str,
    table_name: &str,
    pk_indices: &[usize],
) -> String {
    let normalized_table_name = format!(
        "{}.{}",
        quote_identifier(schema_name),
        quote_identifier(table_name)
    );
    let pk_indices = if pk_indices.is_empty() {
        (0..schema.len()).collect_vec()
    } else {
        pk_indices.to_vec()
    };
    let pk = {
        let pk_symbols = pk_indices
            .iter()
            .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
            .join(", ");
        format!("({})", pk_symbols)
    };
    let parameters: String = (0..pk_indices.len())
        .map(|i| format!("${}", i + 1))
        .join(", ");
    format!("DELETE FROM {normalized_table_name} WHERE {pk} in (({parameters}))")
}

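/// Builds the upsert statement: the insert statement plus an `on conflict (<pk>) do update set`
/// clause over the non-primary-key columns. Falls back to a plain insert when no primary key is
/// given.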
fn create_upsert_sql(
    schema: &Schema,
    schema_name: &str,
    table_name: &str,
    pk_indices: &[usize],
    pk_indices_lookup: &HashSet<usize>,
) -> String {
    let insert_sql = create_insert_sql(schema, schema_name, table_name);
    if pk_indices.is_empty() {
        return insert_sql;
    }
    let pk_columns = pk_indices
        .iter()
        .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
        .collect_vec()
        .join(", ");
    let update_parameters: String = (0..schema.len())
        .filter(|i| !pk_indices_lookup.contains(i))
        .map(|i| {
            let column = quote_identifier(&schema.fields()[i].name);
            format!("{column} = EXCLUDED.{column}")
        })
        .collect_vec()
        .join(", ");
    format!("{insert_sql} on conflict ({pk_columns}) do update set {update_parameters}")
}

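/// Double-quotes a Postgres identifier, escaping embedded double quotes by doubling them.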
fn quote_identifier(identifier: &str) -> String {
    format!("\"{}\"", identifier.replace("\"", "\"\""))
}

type PgDatum = Option<ScalarAdapter>;
type PgRow = Vec<PgDatum>;

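/// Converts a RisingWave row into a vector of Postgres parameters, using `schema_types` to pick
/// the target type of each column. Values that fail to convert are logged and bound as NULL.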
fn convert_row_to_pg_row(row: impl Row, schema_types: &[PgType]) -> PgRow {
    let mut buffer = Vec::with_capacity(row.len());
    for (i, datum_ref) in row.iter().enumerate() {
        let pg_datum = datum_ref.map(|s| {
            match ScalarAdapter::from_scalar(s, &schema_types[i]) {
                Ok(scalar) => Some(scalar),
                Err(e) => {
                    tracing::error!(error=%e.as_report(), scalar=?s, "Failed to convert scalar to pg value");
                    None
                }
            }
        });
        buffer.push(pg_datum.flatten());
    }
    buffer
}

#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use expect_test::{Expect, expect};
    use risingwave_common::catalog::Field;
    use risingwave_common::types::DataType;

    use super::*;

    fn check(actual: impl Display, expect: Expect) {
        let actual = actual.to_string();
        expect.assert_eq(&actual);
    }

    #[test]
    fn test_create_insert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let sql = create_insert_sql(&schema, schema_name, table_name);
        check(
            sql,
            expect![[r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2)"#]],
        );
    }

    #[test]
    fn test_create_delete_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let sql = create_delete_sql(&schema, schema_name, table_name, &[1]);
        check(
            sql,
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("b") in (($1))"#]],
        );
        let table_name = "test_table";
        let sql = create_delete_sql(&schema, schema_name, table_name, &[0, 1]);
        check(
            sql,
            expect![[r#"DELETE FROM "test_schema"."test_table" WHERE ("a", "b") in (($1, $2))"#]],
        );
    }

    #[test]
    fn test_create_upsert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let pk_indices_lookup = HashSet::from_iter([1]);
        let sql = create_upsert_sql(&schema, schema_name, table_name, &[1], &pk_indices_lookup);
        check(
            sql,
            expect![[
                r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2) on conflict ("b") do update set "a" = EXCLUDED."a""#
            ]],
        );
    }
}