risingwave_connector/sink/postgres.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashSet};

use anyhow::{Context, anyhow};
use async_trait::async_trait;
use itertools::Itertools;
use phf::phf_set;
use risingwave_common::array::{Op, StreamChunk};
use risingwave_common::bail;
use risingwave_common::catalog::Schema;
use risingwave_common::row::{Row, RowExt};
use serde_derive::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use simd_json::prelude::ArrayTrait;
use thiserror_ext::AsReport;
use tokio_postgres::types::Type as PgType;

use super::{
    LogSinker, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkLogReader,
};
use crate::connector_common::{PostgresExternalTable, SslMode, create_pg_client};
use crate::enforce_secret::EnforceSecret;
use crate::parser::scalar_adapter::{ScalarAdapter, validate_pg_type_to_rw_type};
use crate::sink::log_store::{LogStoreReadItem, TruncateOffset};
use crate::sink::{DummySinkCommitCoordinator, Result, Sink, SinkParam, SinkWriterParam};

pub const POSTGRES_SINK: &str = "postgres";

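/// Configuration of the Postgres sink, deserialized from the user-provided sink properties
/// (e.g. the `WITH` options of `CREATE SINK`).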
#[serde_as]
#[derive(Clone, Debug, Deserialize)]
pub struct PostgresConfig {
    pub host: String,
    #[serde_as(as = "DisplayFromStr")]
    pub port: u16,
    pub user: String,
    pub password: String,
    pub database: String,
    pub table: String,
    #[serde(default = "default_schema")]
    pub schema: String,
    #[serde(default = "Default::default")]
    pub ssl_mode: SslMode,
    #[serde(rename = "ssl.root.cert")]
    pub ssl_root_cert: Option<String>,
    #[serde(default = "default_max_batch_rows")]
    #[serde_as(as = "DisplayFromStr")]
    pub max_batch_rows: usize,
    pub r#type: String, // accept "append-only" or "upsert"
}

impl EnforceSecret for PostgresConfig {
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf_set! {
        "password", "ssl.root.cert"
    };
}

fn default_max_batch_rows() -> usize {
    1024
}

fn default_schema() -> String {
    "public".to_owned()
}

impl PostgresConfig {
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<PostgresConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be either {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }
}

#[derive(Debug)]
pub struct PostgresSink {
    pub config: PostgresConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl PostgresSink {
    pub fn new(
        config: PostgresConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl EnforceSecret for PostgresSink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::error::ConnectorResult<()> {
        for prop in prop_iter {
            PostgresConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl TryFrom<SinkParam> for PostgresSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = PostgresConfig::from_btreemap(param.properties)?;
        PostgresSink::new(
            config,
            schema,
            param.downstream_pk,
            param.sink_type.is_append_only(),
        )
    }
}

impl Sink for PostgresSink {
    type Coordinator = DummySinkCommitCoordinator;
    type LogSinker = PostgresSinkWriter;

    const SINK_NAME: &'static str = POSTGRES_SINK;

    async fn validate(&self) -> Result<()> {
        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert Postgres sink (please define in `primary_key` field)"
            )));
        }

        // Verify our sink schema is compatible with Postgres
        {
            let pg_table = PostgresExternalTable::connect(
                &self.config.user,
                &self.config.password,
                &self.config.host,
                self.config.port,
                &self.config.database,
                &self.config.schema,
                &self.config.table,
                &self.config.ssl_mode,
                &self.config.ssl_root_cert,
                self.is_append_only,
            )
            .await
            .context(format!(
                "failed to connect to database: {}, schema: {}, table: {}",
                &self.config.database, &self.config.schema, &self.config.table
            ))?;

            // Check that names and types match, order of columns doesn't matter.
            {
                let pg_columns = pg_table.column_descs();
                let sink_columns = self.schema.fields();
                if pg_columns.len() < sink_columns.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Column count mismatch: the Postgres table has {} columns, but the sink schema has {} columns; the sink schema must not have more columns than the Postgres table",
                        pg_columns.len(),
                        sink_columns.len()
                    )));
                }

                let pg_columns_lookup = pg_columns
                    .iter()
                    .map(|c| (c.name.clone(), c.data_type.clone()))
                    .collect::<BTreeMap<_, _>>();
                for sink_column in sink_columns {
                    let pg_column = pg_columns_lookup.get(&sink_column.name);
                    match pg_column {
                        None => {
                            return Err(SinkError::Config(anyhow!(
                                "Column `{}` not found in Postgres table `{}`",
                                sink_column.name,
                                self.config.table
                            )));
                        }
                        Some(pg_column) => {
                            if !validate_pg_type_to_rw_type(pg_column, &sink_column.data_type()) {
                                return Err(SinkError::Config(anyhow!(
                                    "Column `{}` in Postgres table `{}` has type `{}`, but sink schema defines it as type `{}`",
                                    sink_column.name,
                                    self.config.table,
                                    pg_column,
                                    sink_column.data_type()
                                )));
                            }
                        }
                    }
                }
            }

            // check that pk matches
            {
                let pg_pk_names = pg_table.pk_names();
                let sink_pk_names = self
                    .pk_indices
                    .iter()
                    .map(|i| &self.schema.fields()[*i].name)
                    .collect::<HashSet<_>>();
                if pg_pk_names.len() != sink_pk_names.len() {
                    return Err(SinkError::Config(anyhow!(
                        "Primary key mismatch: Postgres table has primary key on columns {:?}, but sink schema defines primary key on columns {:?}",
                        pg_pk_names,
                        sink_pk_names
                    )));
                }
                for name in pg_pk_names {
                    if !sink_pk_names.contains(name) {
                        return Err(SinkError::Config(anyhow!(
                            "Primary key mismatch: Postgres table has primary key on column `{}`, but sink schema does not define it as a primary key",
                            name
                        )));
                    }
                }
            }
        }

        Ok(())
    }

    async fn new_log_sinker(&self, _writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        PostgresSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await
    }
}

struct ParameterBuffer<'a> {
    /// A set of parameters to be inserted/deleted.
    /// Each set is a flattened 2d-array.
    parameters: Vec<Vec<Option<ScalarAdapter>>>,
    /// the column dimension (fixed).
    column_length: usize,
    /// schema types to serialize into `ScalarAdapter`
    schema_types: &'a [PgType],
    /// estimated number of parameters that can be sent in a single query.
    estimated_parameter_size: usize,
    /// current parameter buffer to be filled.
    current_parameter_buffer: Vec<Option<ScalarAdapter>>,
    /// Parameter upper bound
    parameter_upper_bound: usize,
}

impl<'a> ParameterBuffer<'a> {
    /// The maximum number of parameters that can be sent in a single query.
    /// See: <https://www.postgresql.org/docs/current/limits.html>
    /// and <https://github.com/sfackler/rust-postgres/issues/356>
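    /// For example, with a hypothetical 8-column table, at most `32768 / 8 = 4096` rows fit
    /// into one batched statement before the parameter buffer is split.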
    const MAX_PARAMETERS: usize = 32768;

    /// `parameter_upper_bound` is the maximum number of flattened parameters buffered per
    /// statement; callers typically pass the number of datums in a single chunk.
    fn new(schema_types: &'a [PgType], parameter_upper_bound: usize) -> Self {
        let estimated_parameter_size = usize::min(Self::MAX_PARAMETERS, parameter_upper_bound);
        Self {
            parameters: vec![],
            column_length: schema_types.len(),
            schema_types,
            estimated_parameter_size,
            current_parameter_buffer: Vec::with_capacity(estimated_parameter_size),
            parameter_upper_bound,
        }
    }

    fn add_row(&mut self, row: impl Row) {
        assert_eq!(row.len(), self.column_length);
        if self.current_parameter_buffer.len() + self.column_length > self.parameter_upper_bound {
            self.new_buffer();
        }
        for (i, datum_ref) in row.iter().enumerate() {
            let pg_datum = datum_ref.map(|s| {
                let ty = &self.schema_types[i];
                match ScalarAdapter::from_scalar(s, ty) {
                    Ok(scalar) => Some(scalar),
                    Err(e) => {
                        tracing::error!(error=%e.as_report(), scalar=?s, "Failed to convert scalar to pg value");
                        None
                    }
                }
            });
            self.current_parameter_buffer.push(pg_datum.flatten());
        }
    }

    fn new_buffer(&mut self) {
        let filled_buffer = std::mem::replace(
            &mut self.current_parameter_buffer,
            Vec::with_capacity(self.estimated_parameter_size),
        );
        self.parameters.push(filled_buffer);
    }

    fn into_parts(self) -> (Vec<Vec<Option<ScalarAdapter>>>, Vec<Option<ScalarAdapter>>) {
        (self.parameters, self.current_parameter_buffer)
    }
}

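/// Writer that converts each `StreamChunk` into batched `INSERT`/`DELETE`/upsert statements and
/// flushes them to Postgres within one transaction per chunk.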
pub struct PostgresSinkWriter {
    config: PostgresConfig,
    pk_indices: Vec<usize>,
    pk_indices_lookup: HashSet<usize>,
    is_append_only: bool,
    client: tokio_postgres::Client,
    pk_types: Vec<PgType>,
    schema_types: Vec<PgType>,
    schema: Schema,
}

impl PostgresSinkWriter {
    async fn new(
        config: PostgresConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let client = create_pg_client(
            &config.user,
            &config.password,
            &config.host,
            &config.port.to_string(),
            &config.database,
            &config.ssl_mode,
            &config.ssl_root_cert,
        )
        .await?;

        let pk_indices_lookup = pk_indices.iter().copied().collect::<HashSet<_>>();

        // Rewrite schema types for serialization
        let (pk_types, schema_types) = {
            let name_to_type = PostgresExternalTable::type_mapping(
                &config.user,
                &config.password,
                &config.host,
                config.port,
                &config.database,
                &config.schema,
                &config.table,
                &config.ssl_mode,
                &config.ssl_root_cert,
                is_append_only,
            )
            .await?;
            let mut schema_types = Vec::with_capacity(schema.fields.len());
            let mut pk_types = Vec::with_capacity(pk_indices.len());
            for (i, field) in schema.fields.iter().enumerate() {
                let field_name = &field.name;
                let actual_data_type = name_to_type.get(field_name).map(|t| (*t).clone());
                let actual_data_type = actual_data_type.ok_or_else(|| {
                    SinkError::Config(anyhow!(
                        "Column `{}` not found in the downstream Postgres table",
                        field_name
                    ))
                })?;
                if pk_indices_lookup.contains(&i) {
                    pk_types.push(actual_data_type.clone())
                }
                schema_types.push(actual_data_type);
            }
            (pk_types, schema_types)
        };

        let writer = Self {
            config,
            pk_indices,
            pk_indices_lookup,
            is_append_only,
            client,
            pk_types,
            schema_types,
            schema,
        };
        Ok(writer)
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        // https://www.postgresql.org/docs/current/limits.html
        // We have a limit of 65,535 parameters in a single query, as restricted by the PostgreSQL protocol.
        if self.is_append_only {
            self.write_batch_append_only(chunk).await
        } else {
            self.write_batch_non_append_only(chunk).await
        }
    }

    async fn write_batch_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        // 1d flattened array of parameters to be inserted.
        let mut parameter_buffer = ParameterBuffer::new(
            &self.schema_types,
            chunk.cardinality() * chunk.data_types().len(),
        );
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    parameter_buffer.add_row(row);
                }
                Op::UpdateInsert | Op::Delete | Op::UpdateDelete => {
                    bail!(
                        "append-only sink should not receive update insert, update delete and delete operations"
                    )
                }
            }
        }
        let (parameters, remaining) = parameter_buffer.into_parts();

        let mut transaction = self.client.transaction().await?;
        Self::execute_parameter(
            Op::Insert,
            &mut transaction,
            &self.schema,
            &self.config.schema,
            &self.config.table,
            &self.pk_indices,
            &self.pk_indices_lookup,
            parameters,
            remaining,
            true,
        )
        .await?;
        transaction.commit().await?;

        Ok(())
    }

    async fn write_batch_non_append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        // 1d flattened array of parameters to be inserted.
        let mut insert_parameter_buffer = ParameterBuffer::new(
            &self.schema_types,
            // NOTE(kwannoel):
            // An `INSERT ... ON CONFLICT DO UPDATE` statement may receive multiple rows with
            // the same PK. In that case Postgres raises:
            // ERROR: ON CONFLICT DO UPDATE command cannot affect row a second time
            // HINT: Ensure that no rows proposed for insertion within the same command have duplicate constrained values
            // So we bind only one row's parameters per statement here. Given that the JDBC sink
            // does not batch its insert-on-conflict-do-update statements either, this keeps the
            // behaviour consistent.
            //
            // We may opt for an optimization flag to toggle this behaviour in the future.
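            // For example (hypothetical data), a chunk carrying two `Insert` ops for the same
            // key would otherwise be folded into one statement like
            // `INSERT INTO t VALUES (1, 'a'), (1, 'b') ON CONFLICT (id) DO UPDATE SET ...`,
            // which Postgres rejects with the error above.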
            chunk.data_types().len(),
        );
        // 1d flattened array of parameters to be deleted.
        let mut delete_parameter_buffer =
            ParameterBuffer::new(&self.pk_types, chunk.cardinality() * self.pk_indices.len());
        for (op, row) in chunk.rows() {
            match op {
                Op::UpdateInsert | Op::Insert => {
                    insert_parameter_buffer.add_row(row);
                }
                Op::UpdateDelete | Op::Delete => {
                    delete_parameter_buffer.add_row(row.project(&self.pk_indices));
                }
            }
        }

        let (delete_parameters, delete_remaining_parameter) = delete_parameter_buffer.into_parts();
        let mut transaction = self.client.transaction().await?;
        Self::execute_parameter(
            Op::Delete,
            &mut transaction,
            &self.schema,
            &self.config.schema,
            &self.config.table,
            &self.pk_indices,
            &self.pk_indices_lookup,
            delete_parameters,
            delete_remaining_parameter,
            false,
        )
        .await?;
        let (insert_parameters, insert_remaining_parameter) = insert_parameter_buffer.into_parts();
        Self::execute_parameter(
            Op::Insert,
            &mut transaction,
            &self.schema,
            &self.config.schema,
            &self.config.table,
            &self.pk_indices,
            &self.pk_indices_lookup,
            insert_parameters,
            insert_remaining_parameter,
            false,
        )
        .await?;
        transaction.commit().await?;

        Ok(())
    }

    async fn execute_parameter(
        op: Op,
        transaction: &mut tokio_postgres::Transaction<'_>,
        schema: &Schema,
        schema_name: &str,
        table_name: &str,
        pk_indices: &[usize],
        pk_indices_lookup: &HashSet<usize>,
        parameters: Vec<Vec<Option<ScalarAdapter>>>,
        remaining_parameter: Vec<Option<ScalarAdapter>>,
        append_only: bool,
    ) -> Result<()> {
        async fn prepare_statement(
            transaction: &mut tokio_postgres::Transaction<'_>,
            op: Op,
            schema: &Schema,
            schema_name: &str,
            table_name: &str,
            pk_indices: &[usize],
            pk_indices_lookup: &HashSet<usize>,
            rows_length: usize,
            append_only: bool,
        ) -> Result<(String, tokio_postgres::Statement)> {
            assert!(rows_length > 0, "parameters are empty");
            let statement_str = match op {
                Op::Insert => {
                    if append_only {
                        create_insert_sql(schema, schema_name, table_name, rows_length)
                    } else {
                        create_upsert_sql(
                            schema,
                            schema_name,
                            table_name,
                            pk_indices,
                            pk_indices_lookup,
                            rows_length,
                        )
                    }
                }
                Op::Delete => {
                    create_delete_sql(schema, schema_name, table_name, pk_indices, rows_length)
                }
                _ => unreachable!(),
            };
            let statement = transaction
                .prepare(&statement_str)
                .await
                .with_context(|| format!("failed to prepare statement: {}", statement_str))?;
            Ok((statement_str, statement))
        }

        let column_length = match op {
            Op::Insert => schema.len(),
            Op::Delete => pk_indices.len(),
            _ => unreachable!(),
        };

        if !parameters.is_empty() {
            let parameter_length = parameters[0].len();
            assert_eq!(
                parameter_length % column_length,
                0,
                "flattened parameters are unaligned, parameter_length={} column_length={}",
                parameter_length,
                column_length,
            );
            let rows_length = parameter_length / column_length;
            let (statement_str, statement) = prepare_statement(
                transaction,
                op,
                schema,
                schema_name,
                table_name,
                pk_indices,
                pk_indices_lookup,
                rows_length,
                append_only,
            )
            .await?;
            for parameter in parameters {
                transaction
                    .execute_raw(&statement, parameter)
                    .await
                    .with_context(|| format!("failed to execute statement: {}", statement_str))?;
            }
        }
        if !remaining_parameter.is_empty() {
            let parameter_length = remaining_parameter.len();
            assert_eq!(
                parameter_length % column_length,
                0,
                "flattened parameters are unaligned"
            );
            let rows_length = remaining_parameter.len() / column_length;
            let (statement_str, statement) = prepare_statement(
                transaction,
                op,
                schema,
                schema_name,
                table_name,
                pk_indices,
                pk_indices_lookup,
                rows_length,
                append_only,
            )
            .await?;
            tracing::trace!("binding parameters: {:?}", remaining_parameter);
            transaction
                .execute_raw(&statement, remaining_parameter)
                .await
                .with_context(|| format!("failed to execute statement: {}", statement_str))?;
        }
        Ok(())
    }
}

#[async_trait]
impl LogSinker for PostgresSinkWriter {
    async fn consume_log_and_sink(mut self, mut log_reader: impl SinkLogReader) -> Result<!> {
        log_reader.start_from(None).await?;
        loop {
            let (epoch, item) = log_reader.next_item().await?;
            match item {
                LogStoreReadItem::StreamChunk { chunk, chunk_id } => {
                    self.write_batch(chunk).await?;
                    log_reader.truncate(TruncateOffset::Chunk { epoch, chunk_id })?;
                }
                LogStoreReadItem::Barrier { .. } => {
                    log_reader.truncate(TruncateOffset::Barrier { epoch })?;
                }
            }
        }
    }
}

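/// Build a multi-row `INSERT` statement with positional parameters.
///
/// For a two-column schema and three rows this produces (see `test_create_insert_sql` below):
/// `INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2), ($3, $4), ($5, $6)`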
fn create_insert_sql(
    schema: &Schema,
    schema_name: &str,
    table_name: &str,
    number_of_rows: usize,
) -> String {
    assert!(
        number_of_rows > 0,
        "number of rows must be greater than 0"
    );
    let normalized_table_name = format!(
        "{}.{}",
        quote_identifier(schema_name),
        quote_identifier(table_name)
    );
    let number_of_columns = schema.len();
    let columns: String = schema
        .fields()
        .iter()
        .map(|field| quote_identifier(&field.name))
        .join(", ");
    let parameters: String = (0..number_of_rows)
        .map(|i| {
            let row_parameters = (0..number_of_columns)
                .map(|j| format!("${}", i * number_of_columns + j + 1))
                .join(", ");
            format!("({row_parameters})")
        })
        .collect_vec()
        .join(", ");
    format!("INSERT INTO {normalized_table_name} ({columns}) VALUES {parameters}")
}

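/// Build a multi-row `DELETE` statement that matches primary keys via a row-value `IN` list.
///
/// For a composite key `("a", "b")` and three rows this produces (see `test_create_delete_sql`
/// below):
/// `DELETE FROM "test_schema"."test_table" WHERE ("a", "b") in (($1, $2), ($3, $4), ($5, $6))`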
fn create_delete_sql(
    schema: &Schema,
    schema_name: &str,
    table_name: &str,
    pk_indices: &[usize],
    number_of_rows: usize,
) -> String {
    assert!(
        number_of_rows > 0,
        "number of rows must be greater than 0"
    );
    let normalized_table_name = format!(
        "{}.{}",
        quote_identifier(schema_name),
        quote_identifier(table_name)
    );
    let number_of_pk = pk_indices.len();
    let pk = {
        let pk_symbols = pk_indices
            .iter()
            .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
            .join(", ");
        format!("({})", pk_symbols)
    };
    let parameters: String = (0..number_of_rows)
        .map(|i| {
            let row_parameters: String = (0..pk_indices.len())
                .map(|j| format!("${}", i * number_of_pk + j + 1))
                .join(", ");
            format!("({row_parameters})")
        })
        .collect_vec()
        .join(", ");
    format!("DELETE FROM {normalized_table_name} WHERE {pk} in ({parameters})")
}

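/// Build a multi-row upsert: the insert statement from [`create_insert_sql`] followed by
/// `on conflict (<pk columns>) do update set` assignments that take every non-primary-key
/// column from `EXCLUDED`.
///
/// With primary key `"b"` this produces (see `test_create_upsert_sql` below):
/// `INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2), ($3, $4), ($5, $6) on conflict ("b") do update set "a" = EXCLUDED."a"`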
fn create_upsert_sql(
    schema: &Schema,
    schema_name: &str,
    table_name: &str,
    pk_indices: &[usize],
    pk_indices_lookup: &HashSet<usize>,
    number_of_rows: usize,
) -> String {
    let number_of_columns = schema.len();
    let insert_sql = create_insert_sql(schema, schema_name, table_name, number_of_rows);
    let pk_columns = pk_indices
        .iter()
        .map(|pk_index| quote_identifier(&schema.fields()[*pk_index].name))
        .collect_vec()
        .join(", ");
    let update_parameters: String = (0..number_of_columns)
        .filter(|i| !pk_indices_lookup.contains(i))
        .map(|i| {
            let column = quote_identifier(&schema.fields()[i].name);
            format!("{column} = EXCLUDED.{column}")
        })
        .collect_vec()
        .join(", ");
    format!("{insert_sql} on conflict ({pk_columns}) do update set {update_parameters}")
}

/// Quote an identifier for PostgreSQL.
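///
/// Embedded double quotes are escaped by doubling them, e.g. a (hypothetical) identifier
/// `weird"name` becomes `"weird""name"`.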
fn quote_identifier(identifier: &str) -> String {
    format!("\"{}\"", identifier.replace("\"", "\"\""))
}

#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use expect_test::{Expect, expect};
    use risingwave_common::catalog::Field;
    use risingwave_common::types::DataType;

    use super::*;

    fn check(actual: impl Display, expect: Expect) {
        let actual = actual.to_string();
        expect.assert_eq(&actual);
    }

    #[test]
    fn test_create_insert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let sql = create_insert_sql(&schema, schema_name, table_name, 3);
        check(
            sql,
            expect![[
                r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2), ($3, $4), ($5, $6)"#
            ]],
        );
    }

    #[test]
    fn test_create_delete_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let sql = create_delete_sql(&schema, schema_name, table_name, &[1], 3);
        check(
            sql,
            expect![[
                r#"DELETE FROM "test_schema"."test_table" WHERE ("b") in (($1), ($2), ($3))"#
            ]],
        );
        let table_name = "test_table";
        let sql = create_delete_sql(&schema, schema_name, table_name, &[0, 1], 3);
        check(
            sql,
            expect![[
                r#"DELETE FROM "test_schema"."test_table" WHERE ("a", "b") in (($1, $2), ($3, $4), ($5, $6))"#
            ]],
        );
    }

    #[test]
    fn test_create_upsert_sql() {
        let schema = Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "a".to_owned(),
            },
            Field {
                data_type: DataType::Int32,
                name: "b".to_owned(),
            },
        ]);
        let schema_name = "test_schema";
        let table_name = "test_table";
        let pk_indices_lookup = HashSet::from_iter([1]);
        let sql = create_upsert_sql(
            &schema,
            schema_name,
            table_name,
            &[1],
            &pk_indices_lookup,
            3,
        );
        check(
            sql,
            expect![[
                r#"INSERT INTO "test_schema"."test_table" ("a", "b") VALUES ($1, $2), ($3, $4), ($5, $6) on conflict ("b") do update set "a" = EXCLUDED."a""#
            ]],
        );
    }
}