1use std::collections::{BTreeMap, HashSet};
16
17use anyhow::anyhow;
18use async_trait::async_trait;
19use itertools::Itertools;
20use risingwave_common::array::{Op, StreamChunk};
21use risingwave_common::bail;
22use risingwave_common::catalog::Schema;
23use risingwave_common::row::{Row, RowExt};
24use serde_derive::Deserialize;
25use serde_with::{DisplayFromStr, serde_as};
26use simd_json::prelude::ArrayTrait;
27use thiserror_ext::AsReport;
28use tokio_postgres::types::Type as PgType;
29
30use super::{
31 LogSinker, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkLogReader,
32};
33use crate::connector_common::{PostgresExternalTable, SslMode, create_pg_client};
34use crate::parser::scalar_adapter::{ScalarAdapter, validate_pg_type_to_rw_type};
35use crate::sink::log_store::{LogStoreReadItem, TruncateOffset};
36use crate::sink::{DummySinkCommitCoordinator, Result, Sink, SinkParam, SinkWriterParam};
37
/// Connector name for the Postgres sink (the value of the `connector` property).
pub const POSTGRES_SINK: &str = "postgres";
39
/// User-facing configuration for the Postgres sink, deserialized from the
/// `WITH (...)` properties (see [`PostgresConfig::from_btreemap`]).
#[serde_as]
#[derive(Clone, Debug, Deserialize)]
pub struct PostgresConfig {
    pub host: String,
    // Properties arrive as strings, so numeric fields parse via `DisplayFromStr`.
    #[serde_as(as = "DisplayFromStr")]
    pub port: u16,
    pub user: String,
    pub password: String,
    pub database: String,
    // Target table to write into.
    pub table: String,
    // Postgres schema of the target table; defaults to "public".
    #[serde(default = "default_schema")]
    pub schema: String,
    #[serde(default = "Default::default")]
    pub ssl_mode: SslMode,
    #[serde(rename = "ssl.root.cert")]
    pub ssl_root_cert: Option<String>,
    // Batch-size knob; defaults to 1024. NOTE(review): not referenced in this
    // file — presumably consumed by the shared sink machinery; confirm.
    #[serde(default = "default_max_batch_rows")]
    #[serde_as(as = "DisplayFromStr")]
    pub max_batch_rows: usize,
    // Sink type: must be append-only or upsert (validated in `from_btreemap`).
    pub r#type: String, }
61
/// Serde default for `PostgresConfig::max_batch_rows`.
fn default_max_batch_rows() -> usize {
    const DEFAULT_MAX_BATCH_ROWS: usize = 1024;
    DEFAULT_MAX_BATCH_ROWS
}
65
/// Serde default for `PostgresConfig::schema`: Postgres's default schema.
fn default_schema() -> String {
    String::from("public")
}
69
70impl PostgresConfig {
71 pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
72 let config =
73 serde_json::from_value::<PostgresConfig>(serde_json::to_value(properties).unwrap())
74 .map_err(|e| SinkError::Config(anyhow!(e)))?;
75 if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
76 return Err(SinkError::Config(anyhow!(
77 "`{}` must be {}, or {}",
78 SINK_TYPE_OPTION,
79 SINK_TYPE_APPEND_ONLY,
80 SINK_TYPE_UPSERT
81 )));
82 }
83 Ok(config)
84 }
85}
86
/// Postgres sink: writes RisingWave stream chunks into a Postgres table.
#[derive(Debug)]
pub struct PostgresSink {
    pub config: PostgresConfig,
    schema: Schema,
    // Indices of the downstream primary-key columns within `schema`.
    pk_indices: Vec<usize>,
    // When true, only INSERTs are issued; otherwise DELETEs then INSERTs.
    is_append_only: bool,
}
94
95impl PostgresSink {
96 pub fn new(
97 config: PostgresConfig,
98 schema: Schema,
99 pk_indices: Vec<usize>,
100 is_append_only: bool,
101 ) -> Result<Self> {
102 Ok(Self {
103 config,
104 schema,
105 pk_indices,
106 is_append_only,
107 })
108 }
109}
110
111impl TryFrom<SinkParam> for PostgresSink {
112 type Error = SinkError;
113
114 fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
115 let schema = param.schema();
116 let config = PostgresConfig::from_btreemap(param.properties)?;
117 PostgresSink::new(
118 config,
119 schema,
120 param.downstream_pk,
121 param.sink_type.is_append_only(),
122 )
123 }
124}
125
impl Sink for PostgresSink {
    type Coordinator = DummySinkCommitCoordinator;
    type LogSinker = PostgresSinkWriter;

    const SINK_NAME: &'static str = POSTGRES_SINK;

    /// Validates the sink against the live Postgres table: an upsert sink must
    /// have a primary key, every sink column must exist in the table with a
    /// compatible type, and the two primary-key sets must be equal.
    async fn validate(&self) -> Result<()> {
        // Upsert emits DELETEs keyed by primary key, so one must be defined.
        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert Postgres sink (please define in `primary_key` field)"
            )));
        }

        {
            // Introspect the downstream table's columns and primary key.
            let pg_table = PostgresExternalTable::connect(
                &self.config.user,
                &self.config.password,
                &self.config.host,
                self.config.port,
                &self.config.database,
                &self.config.schema,
                &self.config.table,
                &self.config.ssl_mode,
                &self.config.ssl_root_cert,
                self.is_append_only,
            )
            .await?;

            // Column check: every sink column must exist downstream with a
            // type the sink can convert to. Extra Postgres columns are allowed.
            {
                let pg_columns = pg_table.column_descs();
                let sink_columns = self.schema.fields();
                if pg_columns.len() < sink_columns.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Column count mismatch: Postgres table has {} columns, but sink schema has {} columns, sink should have less or equal columns to the Postgres table",
                        pg_columns.len(),
                        sink_columns.len()
                    )));
                }

                // Name -> Postgres type lookup for the downstream columns.
                let pg_columns_lookup = pg_columns
                    .iter()
                    .map(|c| (c.name.clone(), c.data_type.clone()))
                    .collect::<BTreeMap<_, _>>();
                for sink_column in sink_columns {
                    let pg_column = pg_columns_lookup.get(&sink_column.name);
                    match pg_column {
                        None => {
                            return Err(SinkError::Config(anyhow!(
                                "Column `{}` not found in Postgres table `{}`",
                                sink_column.name,
                                self.config.table
                            )));
                        }
                        Some(pg_column) => {
                            if !validate_pg_type_to_rw_type(pg_column, &sink_column.data_type()) {
                                return Err(SinkError::Config(anyhow!(
                                    "Column `{}` in Postgres table `{}` has type `{}`, but sink schema defines it as type `{}`",
                                    sink_column.name,
                                    self.config.table,
                                    pg_column,
                                    sink_column.data_type()
                                )));
                            }
                        }
                    }
                }
            }

            // Primary-key check: same cardinality and same member names implies
            // the key sets are identical.
            {
                let pg_pk_names = pg_table.pk_names();
                let sink_pk_names = self
                    .pk_indices
                    .iter()
                    .map(|i| &self.schema.fields()[*i].name)
                    .collect::<HashSet<_>>();
                if pg_pk_names.len() != sink_pk_names.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Primary key mismatch: Postgres table has primary key on columns {:?}, but sink schema defines primary key on columns {:?}",
                        pg_pk_names,
                        sink_pk_names
                    )));
                }
                for name in pg_pk_names {
                    if !sink_pk_names.contains(name) {
                        return Err(SinkError::Config(anyhow!(
                            "Primary key mismatch: Postgres table has primary key on column `{}`, but sink schema does not define it as a primary key",
                            name
                        )));
                    }
                }
            }
        }

        Ok(())
    }

    /// Creates the writer that actually consumes the log and writes batches.
    async fn new_log_sinker(&self, _writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        PostgresSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await
    }
}
235
/// Accumulates flattened bind parameters (row-major) for multi-row statements,
/// sealing buffers so no single statement exceeds `MAX_PARAMETERS` parameters.
struct ParameterBuffer<'a> {
    // Sealed, full buffers; each maps to one prepared-statement execution.
    parameters: Vec<Vec<Option<ScalarAdapter>>>,
    // Number of parameters contributed by a single row.
    column_length: usize,
    // Postgres type for each column position, parallel to a row's datums.
    schema_types: &'a [PgType],
    // Capacity hint used when allocating each new buffer.
    estimated_parameter_size: usize,
    // Buffer currently being filled; returned as the "remaining" part.
    current_parameter_buffer: Vec<Option<ScalarAdapter>>,
}
249
250impl<'a> ParameterBuffer<'a> {
251 const MAX_PARAMETERS: usize = 32768;
255
256 fn new(schema_types: &'a [PgType], flattened_chunk_size: usize) -> Self {
258 let estimated_parameter_size = usize::min(Self::MAX_PARAMETERS, flattened_chunk_size);
259 Self {
260 parameters: vec![],
261 column_length: schema_types.len(),
262 schema_types,
263 estimated_parameter_size,
264 current_parameter_buffer: Vec::with_capacity(estimated_parameter_size),
265 }
266 }
267
268 fn add_row(&mut self, row: impl Row) {
269 if self.current_parameter_buffer.len() + self.column_length >= Self::MAX_PARAMETERS {
270 self.new_buffer();
271 }
272 for (i, datum_ref) in row.iter().enumerate() {
273 let pg_datum = datum_ref.map(|s| {
274 let ty = &self.schema_types[i];
275 match ScalarAdapter::from_scalar(s, ty) {
276 Ok(scalar) => Some(scalar),
277 Err(e) => {
278 tracing::error!(error=%e.as_report(), scalar=?s, "Failed to convert scalar to pg value");
279 None
280 }
281 }
282 });
283 self.current_parameter_buffer.push(pg_datum.flatten());
284 }
285 }
286
287 fn new_buffer(&mut self) {
288 let filled_buffer = std::mem::replace(
289 &mut self.current_parameter_buffer,
290 Vec::with_capacity(self.estimated_parameter_size),
291 );
292 self.parameters.push(filled_buffer);
293 }
294
295 fn into_parts(self) -> (Vec<Vec<Option<ScalarAdapter>>>, Vec<Option<ScalarAdapter>>) {
296 (self.parameters, self.current_parameter_buffer)
297 }
298}
299
/// Log sinker that writes each stream chunk to Postgres in its own transaction.
pub struct PostgresSinkWriter {
    config: PostgresConfig,
    // Indices of primary-key columns within `schema`; used to build DELETEs.
    pk_indices: Vec<usize>,
    is_append_only: bool,
    client: tokio_postgres::Client,
    // Postgres type of each sink column, resolved once at construction.
    schema_types: Vec<PgType>,
    schema: Schema,
}
308
309impl PostgresSinkWriter {
310 async fn new(
311 config: PostgresConfig,
312 mut schema: Schema,
313 pk_indices: Vec<usize>,
314 is_append_only: bool,
315 ) -> Result<Self> {
316 let client = create_pg_client(
317 &config.user,
318 &config.password,
319 &config.host,
320 &config.port.to_string(),
321 &config.database,
322 &config.ssl_mode,
323 &config.ssl_root_cert,
324 )
325 .await?;
326
327 let schema_types = {
329 let name_to_type = PostgresExternalTable::type_mapping(
330 &config.user,
331 &config.password,
332 &config.host,
333 config.port,
334 &config.database,
335 &config.schema,
336 &config.table,
337 &config.ssl_mode,
338 &config.ssl_root_cert,
339 is_append_only,
340 )
341 .await?;
342 let mut schema_types = Vec::with_capacity(schema.fields.len());
343 for field in &mut schema.fields[..] {
344 let field_name = &field.name;
345 let actual_data_type = name_to_type.get(field_name).map(|t| (*t).clone());
346 let actual_data_type = actual_data_type
347 .ok_or_else(|| {
348 SinkError::Config(anyhow!(
349 "Column `{}` not found in sink schema",
350 field_name
351 ))
352 })?
353 .clone();
354 schema_types.push(actual_data_type);
355 }
356 schema_types
357 };
358
359 let writer = Self {
360 config,
361 pk_indices,
362 is_append_only,
363 client,
364 schema_types,
365 schema,
366 };
367 Ok(writer)
368 }
369
370 async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
371 if self.is_append_only {
374 self.write_batch_append_only(chunk).await
375 } else {
376 self.write_batch_non_append_only(chunk).await
377 }
378 }
379
380 async fn write_batch_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
381 let mut transaction = self.client.transaction().await?;
382 let mut parameter_buffer = ParameterBuffer::new(
384 &self.schema_types,
385 chunk.cardinality() * chunk.data_types().len(),
386 );
387 for (op, row) in chunk.rows() {
388 match op {
389 Op::Insert => {
390 parameter_buffer.add_row(row);
391 }
392 Op::UpdateInsert | Op::Delete | Op::UpdateDelete => {
393 bail!(
394 "append-only sink should not receive update insert, update delete and delete operations"
395 )
396 }
397 }
398 }
399 let (parameters, remaining) = parameter_buffer.into_parts();
400 Self::execute_parameter(
401 Op::Insert,
402 &mut transaction,
403 &self.schema,
404 &self.config.table,
405 &self.pk_indices,
406 parameters,
407 remaining,
408 )
409 .await?;
410 transaction.commit().await?;
411
412 Ok(())
413 }
414
415 async fn write_batch_non_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
416 let mut transaction = self.client.transaction().await?;
417 let mut insert_parameter_buffer = ParameterBuffer::new(
419 &self.schema_types,
420 chunk.cardinality() * chunk.data_types().len(),
421 );
422 let mut delete_parameter_buffer = ParameterBuffer::new(
423 &self.schema_types,
424 chunk.cardinality() * self.pk_indices.len(),
425 );
426 for (op, row) in chunk.rows() {
428 match op {
429 Op::UpdateInsert | Op::Insert => {
430 insert_parameter_buffer.add_row(row);
431 }
432 Op::UpdateDelete | Op::Delete => {
433 delete_parameter_buffer.add_row(row.project(&self.pk_indices));
434 }
435 }
436 }
437
438 let (delete_parameters, delete_remaining_parameter) = delete_parameter_buffer.into_parts();
439 Self::execute_parameter(
440 Op::Delete,
441 &mut transaction,
442 &self.schema,
443 &self.config.table,
444 &self.pk_indices,
445 delete_parameters,
446 delete_remaining_parameter,
447 )
448 .await?;
449 let (insert_parameters, insert_remaining_parameter) = insert_parameter_buffer.into_parts();
450 Self::execute_parameter(
451 Op::Insert,
452 &mut transaction,
453 &self.schema,
454 &self.config.table,
455 &self.pk_indices,
456 insert_parameters,
457 insert_remaining_parameter,
458 )
459 .await?;
460 transaction.commit().await?;
461
462 Ok(())
463 }
464
465 async fn execute_parameter(
466 op: Op,
467 transaction: &mut tokio_postgres::Transaction<'_>,
468 schema: &Schema,
469 table_name: &str,
470 pk_indices: &[usize],
471 parameters: Vec<Vec<Option<ScalarAdapter>>>,
472 remaining_parameter: Vec<Option<ScalarAdapter>>,
473 ) -> Result<()> {
474 let column_length = match op {
475 Op::Insert => schema.len(),
476 Op::Delete => pk_indices.len(),
477 _ => unreachable!(),
478 };
479 if !parameters.is_empty() {
480 let parameter_length = parameters[0].len();
481 let rows_length = parameter_length / column_length;
482 assert_eq!(
483 parameter_length % column_length,
484 0,
485 "flattened parameters are unaligned, parameters={:#?} columns={:#?}",
486 parameters,
487 schema.fields(),
488 );
489 let statement = match op {
490 Op::Insert => create_insert_sql(schema, table_name, rows_length),
491 Op::Delete => create_delete_sql(schema, table_name, pk_indices, rows_length),
492 _ => unreachable!(),
493 };
494 let statement = transaction.prepare(&statement).await?;
495 for parameter in parameters {
496 transaction.execute_raw(&statement, parameter).await?;
497 }
498 }
499 if !remaining_parameter.is_empty() {
500 let rows_length = remaining_parameter.len() / column_length;
501 assert_eq!(
502 remaining_parameter.len() % column_length,
503 0,
504 "flattened parameters are unaligned"
505 );
506 let statement = match op {
507 Op::Insert => create_insert_sql(schema, table_name, rows_length),
508 Op::Delete => create_delete_sql(schema, table_name, pk_indices, rows_length),
509 _ => unreachable!(),
510 };
511 tracing::trace!("binding statement: {:?}", statement);
512 let statement = transaction.prepare(&statement).await?;
513 tracing::trace!("binding parameters: {:?}", remaining_parameter);
514 transaction
515 .execute_raw(&statement, remaining_parameter)
516 .await?;
517 }
518 Ok(())
519 }
520}
521
#[async_trait]
impl LogSinker for PostgresSinkWriter {
    /// Drives the sink forever: writes each chunk from the log reader and
    /// truncates the log as progress is made. Returns `Result<!>` — it never
    /// completes normally and exits only via an error.
    async fn consume_log_and_sink(mut self, mut log_reader: impl SinkLogReader) -> Result<!> {
        log_reader.start_from(None).await?;
        loop {
            let (epoch, item) = log_reader.next_item().await?;
            match item {
                LogStoreReadItem::StreamChunk { chunk, chunk_id } => {
                    self.write_batch(chunk).await?;
                    // The chunk is committed downstream; drop it from the log store.
                    log_reader.truncate(TruncateOffset::Chunk { epoch, chunk_id })?;
                }
                LogStoreReadItem::Barrier { .. } => {
                    log_reader.truncate(TruncateOffset::Barrier { epoch })?;
                }
            }
        }
    }
}
540
541fn create_insert_sql(schema: &Schema, table_name: &str, number_of_rows: usize) -> String {
542 let number_of_columns = schema.len();
543 let columns: String = schema
544 .fields()
545 .iter()
546 .map(|field| field.name.clone())
547 .join(", ");
548 let parameters: String = (0..number_of_rows)
549 .map(|i| {
550 let row_parameters = (0..number_of_columns)
551 .map(|j| format!("${}", i * number_of_columns + j + 1))
552 .join(", ");
553 format!("({row_parameters})")
554 })
555 .collect_vec()
556 .join(", ");
557 format!("INSERT INTO {table_name} ({columns}) VALUES {parameters}")
558}
559
560fn create_delete_sql(
561 schema: &Schema,
562 table_name: &str,
563 pk_indices: &[usize],
564 number_of_rows: usize,
565) -> String {
566 let number_of_pk = pk_indices.len();
567 let pk = {
568 let pk_symbols = pk_indices
569 .iter()
570 .map(|pk_index| &schema.fields()[*pk_index].name)
571 .join(", ");
572 format!("({})", pk_symbols)
573 };
574 let parameters: String = (0..number_of_rows)
575 .map(|i| {
576 let row_parameters: String = (0..pk_indices.len())
577 .map(|j| format!("${}", i * number_of_pk + j + 1))
578 .join(", ");
579 format!("({row_parameters})")
580 })
581 .collect_vec()
582 .join(", ");
583 format!("DELETE FROM {table_name} WHERE {pk} in ({parameters})")
584}
585
// Snapshot tests for the SQL builders; `expect_test` pins the exact output, so
// the inline expectation strings must track any change to the builders.
#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use expect_test::{Expect, expect};
    use risingwave_common::catalog::Field;
    use risingwave_common::types::DataType;

    use super::*;

    // Renders `actual` via Display and compares with the inline expectation.
    fn check(actual: impl Display, expect: Expect) {
        let actual = actual.to_string();
        expect.assert_eq(&actual);
    }

    #[test]
    fn test_create_insert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let table_name = "test_table";
        let sql = create_insert_sql(&schema, table_name, 3);
        check(
            sql,
            expect!["INSERT INTO test_table (a, b) VALUES ($1, $2), ($3, $4), ($5, $6)"],
        );
    }

    #[test]
    fn test_create_delete_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let table_name = "test_table";
        // Single-column primary key (column `b`).
        let sql = create_delete_sql(&schema, table_name, &[1], 3);
        check(
            sql,
            expect!["DELETE FROM test_table WHERE (b) in (($1), ($2), ($3))"],
        );
        let table_name = "test_table";
        // Composite primary key (columns `a`, `b`).
        let sql = create_delete_sql(&schema, table_name, &[0, 1], 3);
        check(
            sql,
            expect!["DELETE FROM test_table WHERE (a, b) in (($1, $2), ($3, $4), ($5, $6))"],
        );
    }
}