risingwave_connector/sink/snowflake_redshift/snowflake.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use core::num::NonZeroU64;
use std::collections::BTreeMap;
use std::time::Duration;

use anyhow::anyhow;
use itertools::Itertools;
use phf::{Set, phf_set};
use risingwave_common::array::StreamChunk;
use risingwave_common::catalog::Schema;
use risingwave_common::types::DataType;
use risingwave_pb::connector_service::{SinkMetadata, sink_metadata};
use risingwave_pb::stream_plan::PbSinkSchemaChange;
use serde::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use thiserror_ext::AsReport;
use tokio::sync::mpsc::{UnboundedSender, unbounded_channel};
use tokio::time::{MissedTickBehavior, interval};
use tonic::async_trait;
use with_options::WithOptions;

use crate::connector_common::IcebergSinkCompactionUpdate;
use crate::enforce_secret::EnforceSecret;
use crate::sink::catalog::SinkId;
use crate::sink::coordinate::CoordinatedLogSinker;
use crate::sink::decouple_checkpoint_log_sink::default_commit_checkpoint_interval;
use crate::sink::file_sink::s3::S3Common;
use crate::sink::jdbc_jni_client::{self, JdbcJniClient};
use crate::sink::snowflake_redshift::{
    __OP, __ROW_ID, SnowflakeRedshiftSinkJdbcWriter, SnowflakeRedshiftSinkS3Writer,
};
use crate::sink::writer::SinkWriter;
use crate::sink::{
    Result, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT,
    SinglePhaseCommitCoordinator, Sink, SinkCommitCoordinator, SinkError, SinkParam,
    SinkWriterParam,
};

pub const SNOWFLAKE_SINK_V2: &str = "snowflake_v2";

const AUTH_METHOD_PASSWORD: &str = "password";
const AUTH_METHOD_KEY_PAIR_FILE: &str = "key_pair_file";
const AUTH_METHOD_KEY_PAIR_OBJECT: &str = "key_pair_object";
const PROP_AUTH_METHOD: &str = "auth.method";

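/// Quotes each identifier and joins them, e.g. `build_full_table_name("MY_DB", "PUBLIC", "T")`
/// yields `"MY_DB"."PUBLIC"."T"`.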
pub fn build_full_table_name(database: &str, schema_name: &str, table_name: &str) -> String {
    format!(r#""{}"."{}"."{}""#, database, schema_name, table_name)
}

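// Illustrative `CREATE SINK` properties for this connector (a sketch with hypothetical
// object names; the flattened S3 staging options from `S3Common` are omitted):
//
//   CREATE SINK my_snowflake_sink FROM my_mv WITH (
//       connector = 'snowflake_v2',
//       type = 'upsert',
//       jdbc.url = 'jdbc:snowflake://<account>.snowflakecomputing.com',
//       username = 'RW_USER',
//       password = 'xxx',                    -- or the key-pair options below, see `auth.method`
//       database = 'MY_DB',
//       schema = 'PUBLIC',
//       table.name = 'TARGET_TABLE',
//       intermediate.table.name = 'TARGET_TABLE_CDC',
//       warehouse = 'MY_WAREHOUSE',
//       stage = 'MY_STAGE',
//       primary_key = 'id'
//   );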
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct SnowflakeV2Config {
    #[serde(rename = "type")]
    pub r#type: String,

    #[serde(rename = "intermediate.table.name")]
    pub snowflake_cdc_table_name: Option<String>,

    #[serde(rename = "table.name")]
    pub snowflake_target_table_name: Option<String>,

    #[serde(rename = "database")]
    pub snowflake_database: Option<String>,

    #[serde(rename = "schema")]
    pub snowflake_schema: Option<String>,

    #[serde(default = "default_target_interval_schedule")]
    #[serde(rename = "write.target.interval.seconds")]
    #[serde_as(as = "DisplayFromStr")]
    pub writer_target_interval_seconds: u64,

    #[serde(default = "default_intermediate_interval_schedule")]
    #[serde(rename = "write.intermediate.interval.seconds")]
    #[serde_as(as = "DisplayFromStr")]
    pub write_intermediate_interval_seconds: u64,

    #[serde(rename = "warehouse")]
    pub snowflake_warehouse: Option<String>,

    #[serde(rename = "jdbc.url")]
    pub jdbc_url: Option<String>,

    #[serde(rename = "username")]
    pub username: Option<String>,

    #[serde(rename = "password")]
    pub password: Option<String>,

    // Authentication method control (password | key_pair_file | key_pair_object)
    #[serde(rename = "auth.method")]
    pub auth_method: Option<String>,

    // Key-pair authentication via connection Properties (Option 2: file-based)
    #[serde(rename = "private_key_file")]
    pub private_key_file: Option<String>,

    #[serde(rename = "private_key_file_pwd")]
    pub private_key_file_pwd: Option<String>,

    // Key-pair authentication via connection Properties (Option 1: object-based, PEM content)
    #[serde(rename = "private_key_pem")]
    pub private_key_pem: Option<String>,

    /// Commit every n(>0) checkpoints, default is 10.
    #[serde(default = "default_commit_checkpoint_interval")]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub commit_checkpoint_interval: u64,

    /// Enable auto schema change for upsert sink.
    /// If enabled, the sink will automatically alter the target table to add new columns.
    #[serde(default)]
    #[serde(rename = "auto.schema.change")]
    #[serde_as(as = "DisplayFromStr")]
    pub auto_schema_change: bool,

    #[serde(default)]
    #[serde(rename = "create_table_if_not_exists")]
    #[serde_as(as = "DisplayFromStr")]
    pub create_table_if_not_exists: bool,

    #[serde(default = "default_with_s3")]
    #[serde(rename = "with_s3")]
    #[serde_as(as = "DisplayFromStr")]
    pub with_s3: bool,

    #[serde(flatten)]
    pub s3_inner: Option<S3Common>,

    #[serde(rename = "stage")]
    pub stage: Option<String>,
}

fn default_target_interval_schedule() -> u64 {
    3600 // Default to 1 hour
}

fn default_intermediate_interval_schedule() -> u64 {
    1800 // Default to 30 minutes
}

fn default_with_s3() -> bool {
    true
}

impl SnowflakeV2Config {
    /// Build JDBC Properties for the Snowflake JDBC connection (no URL parameters).
    /// Returns (`jdbc_url`, `driver_properties`).
    /// - `driver_properties` are transformed/used by the Java runner and passed to `DriverManager::getConnection(url, props)`
    ///
    /// Note: This method assumes the config has been validated by `from_btreemap`.
    pub fn build_jdbc_connection_properties(&self) -> Result<(String, Vec<(String, String)>)> {
        let jdbc_url = self
            .jdbc_url
            .clone()
            .ok_or(SinkError::Config(anyhow!("jdbc.url is required")))?;
        let username = self
            .username
            .clone()
            .ok_or(SinkError::Config(anyhow!("username is required")))?;

        let mut connection_properties: Vec<(String, String)> = vec![("user".to_owned(), username)];

        // auth_method is guaranteed to be Some after validation in from_btreemap
        match self.auth_method.as_deref().unwrap() {
            AUTH_METHOD_PASSWORD => {
                // password is guaranteed to exist by from_btreemap validation
                connection_properties.push(("password".to_owned(), self.password.clone().unwrap()));
            }
            AUTH_METHOD_KEY_PAIR_FILE => {
                // private_key_file is guaranteed to exist by from_btreemap validation
                connection_properties.push((
                    "private_key_file".to_owned(),
                    self.private_key_file.clone().unwrap(),
                ));
                if let Some(pwd) = self.private_key_file_pwd.clone() {
                    connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
                }
            }
            AUTH_METHOD_KEY_PAIR_OBJECT => {
                connection_properties.push((
                    PROP_AUTH_METHOD.to_owned(),
                    AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
                ));
                // private_key_pem is guaranteed to exist by from_btreemap validation
                connection_properties.push((
                    "private_key_pem".to_owned(),
                    self.private_key_pem.clone().unwrap(),
                ));
                if let Some(pwd) = self.private_key_file_pwd.clone() {
                    connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
                }
            }
            _ => {
                // This should never happen since from_btreemap validates auth_method
                unreachable!(
                    "Invalid auth_method - should have been caught during config validation"
                )
            }
        }

        Ok((jdbc_url, connection_properties))
    }

    pub fn from_btreemap(properties: &BTreeMap<String, String>) -> Result<Self> {
        let mut config =
            serde_json::from_value::<SnowflakeV2Config>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }

        // Normalize and validate authentication method
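        // Accepted combinations (illustrative):
        //   only `password`                          -> inferred as `password`
        //   only `private_key_file` (+ optional pwd) -> inferred as `key_pair_file`
        //   only `private_key_pem`                   -> inferred as `key_pair_object`
        // Mixing credentials without an explicit `auth.method` is rejected as ambiguous below.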
        let has_password = config.password.is_some();
        let has_file = config.private_key_file.is_some();
        let has_pem = config.private_key_pem.is_some();

        let normalized_auth_method = match config
            .auth_method
            .as_deref()
            .map(|s| s.trim().to_ascii_lowercase())
        {
            Some(method) if method == AUTH_METHOD_PASSWORD => {
                if !has_password {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=password requires `password`"
                    )));
                }
                if has_file || has_pem {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=password must not set `private_key_file`/`private_key_pem`"
                    )));
                }
                AUTH_METHOD_PASSWORD.to_owned()
            }
            Some(method) if method == AUTH_METHOD_KEY_PAIR_FILE => {
                if !has_file {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_file requires `private_key_file`"
                    )));
                }
                if has_password {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_file must not set `password`"
                    )));
                }
                if has_pem {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_file must not set `private_key_pem`"
                    )));
                }
                AUTH_METHOD_KEY_PAIR_FILE.to_owned()
            }
            Some(method) if method == AUTH_METHOD_KEY_PAIR_OBJECT => {
                if !has_pem {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_object requires `private_key_pem`"
                    )));
                }
                if has_password {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_object must not set `password`"
                    )));
                }
                AUTH_METHOD_KEY_PAIR_OBJECT.to_owned()
            }
            Some(other) => {
                return Err(SinkError::Config(anyhow!(
                    "invalid auth.method: {} (allowed: password | key_pair_file | key_pair_object)",
                    other
                )));
            }
            None => {
                // Infer auth method from supplied fields
                match (has_password, has_file, has_pem) {
                    (true, false, false) => AUTH_METHOD_PASSWORD.to_owned(),
                    (false, true, false) => AUTH_METHOD_KEY_PAIR_FILE.to_owned(),
                    (false, false, true) => AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
                    (true, true, _) | (true, _, true) | (false, true, true) => {
                        return Err(SinkError::Config(anyhow!(
                            "ambiguous auth: multiple auth options provided; remove one or set `auth.method`"
                        )));
                    }
                    _ => {
                        return Err(SinkError::Config(anyhow!(
                            "no authentication configured: set either `password`, or `private_key_file`, or `private_key_pem` (or provide `auth.method`)"
                        )));
                    }
                }
            }
        };
        config.auth_method = Some(normalized_auth_method);
        Ok(config)
    }

    pub fn build_snowflake_task_ctx_jdbc_client(
        &self,
        is_append_only: bool,
        schema: &Schema,
        pk_indices: &Vec<usize>,
    ) -> Result<Option<(SnowflakeTaskContext, JdbcJniClient)>> {
        if !self.auto_schema_change
            && is_append_only
            && !self.create_table_if_not_exists
            && !self.with_s3
        {
            // An append-only sink without auto schema change, table creation, or S3 staging
            // does not need a client.
            return Ok(None);
        }
        let target_table_name = self
            .snowflake_target_table_name
            .clone()
            .ok_or(SinkError::Config(anyhow!("table.name is required")))?;
        let database = self
            .snowflake_database
            .clone()
            .ok_or(SinkError::Config(anyhow!("database is required")))?;
        let schema_name = self
            .snowflake_schema
            .clone()
            .ok_or(SinkError::Config(anyhow!("schema is required")))?;
        let mut snowflake_task_ctx = SnowflakeTaskContext {
            target_table_name: target_table_name.clone(),
            database,
            schema_name,
            schema: schema.clone(),
            ..Default::default()
        };

        let (jdbc_url, connection_properties) = self.build_jdbc_connection_properties()?;
        let client = JdbcJniClient::new_with_props(jdbc_url, connection_properties)?;

        if self.with_s3 {
            let stage = self
                .stage
                .clone()
                .ok_or(SinkError::Config(anyhow!("stage is required")))?;
            snowflake_task_ctx.stage = Some(stage);
            snowflake_task_ctx.pipe_name = Some(format!("{}_pipe", target_table_name));
        }
        if !is_append_only {
            let cdc_table_name = self
                .snowflake_cdc_table_name
                .clone()
                .ok_or(SinkError::Config(anyhow!(
                    "intermediate.table.name is required"
                )))?;
            snowflake_task_ctx.cdc_table_name = Some(cdc_table_name.clone());
            snowflake_task_ctx.writer_target_interval_seconds = self.writer_target_interval_seconds;
            snowflake_task_ctx.warehouse = Some(
                self.snowflake_warehouse
                    .clone()
                    .ok_or(SinkError::Config(anyhow!("warehouse is required")))?,
            );
            let pk_column_names: Vec<_> = schema
                .fields
                .iter()
                .enumerate()
                .filter(|(index, _)| pk_indices.contains(index))
                .map(|(_, field)| field.name.clone())
                .collect();
            if pk_column_names.is_empty() {
                return Err(SinkError::Config(anyhow!(
                    "Primary key columns not found. Please set the `primary_key` column in the sink properties, or ensure that the sink contains the primary key columns from the upstream."
                )));
            }
            snowflake_task_ctx.pk_column_names = Some(pk_column_names);
            snowflake_task_ctx.all_column_names = Some(
                schema
                    .fields
                    .iter()
                    .map(|field| field.name.clone())
                    .collect(),
            );
            snowflake_task_ctx.task_name = Some(format!(
                "rw_snowflake_sink_from_{cdc_table_name}_to_{target_table_name}"
            ));
        }
        Ok(Some((snowflake_task_ctx, client)))
    }
}

impl EnforceSecret for SnowflakeV2Config {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "username",
        "password",
        "jdbc.url",
        // Key-pair authentication secrets
        "private_key_file_pwd",
        "private_key_pem",
    };
}

#[derive(Clone, Debug)]
pub struct SnowflakeV2Sink {
    config: SnowflakeV2Config,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
    param: SinkParam,
}

impl EnforceSecret for SnowflakeV2Sink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::sink::ConnectorResult<()> {
        for prop in prop_iter {
            SnowflakeV2Config::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl TryFrom<SinkParam> for SnowflakeV2Sink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = SnowflakeV2Config::from_btreemap(&param.properties)?;
        let is_append_only = param.sink_type.is_append_only();
        let pk_indices = param.downstream_pk_or_empty();
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
            param,
        })
    }
}

impl Sink for SnowflakeV2Sink {
    type LogSinker = CoordinatedLogSinker<SnowflakeSinkWriter>;

    const SINK_NAME: &'static str = SNOWFLAKE_SINK_V2;

    async fn validate(&self) -> Result<()> {
        risingwave_common::license::Feature::SnowflakeSink
            .check_available()
            .map_err(|e| anyhow::anyhow!(e))?;
        if let Some((snowflake_task_ctx, client)) =
            self.config.build_snowflake_task_ctx_jdbc_client(
                self.is_append_only,
                &self.schema,
                &self.pk_indices,
            )?
        {
            let client = SnowflakeJniClient::new(client, snowflake_task_ctx);
            client.execute_create_table().await?;
            client.execute_create_pipe().await?;
        }

        Ok(())
    }

    fn support_schema_change() -> bool {
        true
    }

    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
        SnowflakeV2Config::from_btreemap(config)?;
        Ok(())
    }

    async fn new_log_sinker(
        &self,
        writer_param: crate::sink::SinkWriterParam,
    ) -> Result<Self::LogSinker> {
        let writer = SnowflakeSinkWriter::new(
            self.config.clone(),
            self.is_append_only,
            writer_param.clone(),
            self.param.clone(),
        )
        .await?;

        let commit_checkpoint_interval =
            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
            );

        CoordinatedLogSinker::new(
            &writer_param,
            self.param.clone(),
            writer,
            commit_checkpoint_interval,
        )
        .await
    }

    fn is_coordinated_sink(&self) -> bool {
        true
    }

    async fn new_coordinator(
        &self,
        _iceberg_compact_stat_sender: Option<UnboundedSender<IcebergSinkCompactionUpdate>>,
    ) -> Result<SinkCommitCoordinator> {
        let coordinator = SnowflakeSinkCommitter::new(
            self.config.clone(),
            &self.schema,
            &self.pk_indices,
            self.is_append_only,
            self.param.sink_id,
        )?;
        Ok(SinkCommitCoordinator::SinglePhase(Box::new(coordinator)))
    }
}

pub enum SnowflakeSinkWriter {
    S3(SnowflakeRedshiftSinkS3Writer),
    Jdbc(SnowflakeRedshiftSinkJdbcWriter),
}

impl SnowflakeSinkWriter {
    pub async fn new(
        config: SnowflakeV2Config,
        is_append_only: bool,
        writer_param: SinkWriterParam,
        param: SinkParam,
    ) -> Result<Self> {
        let schema = param.schema();
        let database = config.snowflake_database.ok_or_else(|| {
            SinkError::Config(anyhow!("database is required for Snowflake JDBC sink"))
        })?;
        let schema_name = config.snowflake_schema.ok_or_else(|| {
            SinkError::Config(anyhow!("schema is required for Snowflake JDBC sink"))
        })?;
        let table_name = config.snowflake_target_table_name.ok_or_else(|| {
            SinkError::Config(anyhow!("table.name is required for Snowflake JDBC sink"))
        })?;
        if config.with_s3 {
            let s3_writer = SnowflakeRedshiftSinkS3Writer::new(
                config.s3_inner.ok_or_else(|| {
                    SinkError::Config(anyhow!(
                        "S3 configuration is required for Snowflake S3 sink"
                    ))
                })?,
                schema,
                is_append_only,
                table_name,
            )?;
            Ok(Self::S3(s3_writer))
        } else {
            let jdbc_writer = SnowflakeRedshiftSinkJdbcWriter::new(
                is_append_only,
                writer_param,
                param,
                build_full_table_name(&database, &schema_name, &table_name),
            )
            .await?;
            Ok(Self::Jdbc(jdbc_writer))
        }
    }
}

#[async_trait]
impl SinkWriter for SnowflakeSinkWriter {
    type CommitMetadata = Option<SinkMetadata>;

    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
        match self {
            Self::S3(writer) => writer.begin_epoch(epoch),
            Self::Jdbc(writer) => writer.begin_epoch(epoch).await,
        }
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        match self {
            Self::S3(writer) => writer.write_batch(chunk).await,
            Self::Jdbc(writer) => writer.write_batch(chunk).await,
        }
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Option<SinkMetadata>> {
        match self {
            Self::S3(writer) => {
                writer.barrier(is_checkpoint).await?;
            }
            Self::Jdbc(writer) => {
                writer.barrier(is_checkpoint).await?;
            }
        }
        Ok(Some(SinkMetadata {
            metadata: Some(sink_metadata::Metadata::Serialized(
                risingwave_pb::connector_service::sink_metadata::SerializedMetadata {
                    metadata: vec![],
                },
            )),
        }))
    }

    async fn abort(&mut self) -> Result<()> {
        if let Self::Jdbc(writer) = self {
            writer.abort().await
        } else {
            Ok(())
        }
    }
}

#[derive(Default, Clone)]
pub struct SnowflakeTaskContext {
    // required for task creation
    pub target_table_name: String,
    pub database: String,
    pub schema_name: String,
    pub schema: Schema,

    // only upsert
    pub task_name: Option<String>,
    pub cdc_table_name: Option<String>,
    pub writer_target_interval_seconds: u64,
    pub warehouse: Option<String>,
    pub pk_column_names: Option<Vec<String>>,
    pub all_column_names: Option<Vec<String>>,

    // only s3 writer
    pub stage: Option<String>,
    pub pipe_name: Option<String>,
}

pub struct SnowflakeSinkCommitter {
    client: Option<SnowflakeJniClient>,
    _periodic_task_handle: Option<tokio::task::JoinHandle<()>>,
    shutdown_sender: Option<tokio::sync::mpsc::UnboundedSender<()>>,
}

impl SnowflakeSinkCommitter {
    pub fn new(
        config: SnowflakeV2Config,
        schema: &Schema,
        pk_indices: &Vec<usize>,
        is_append_only: bool,
        sink_id: SinkId,
    ) -> Result<Self> {
        let (client, periodic_task_handle, shutdown_sender) =
            if let Some((snowflake_task_ctx, client)) =
                config.build_snowflake_task_ctx_jdbc_client(is_append_only, schema, pk_indices)?
            {
                let (shutdown_sender, shutdown_receiver) = unbounded_channel();
                let snowflake_client =
                    SnowflakeJniClient::new(client.clone(), snowflake_task_ctx.clone());
                let periodic_task_handle = tokio::spawn(async move {
                    Self::run_periodic_query_task(
                        snowflake_client,
                        config.write_intermediate_interval_seconds,
                        sink_id,
                        shutdown_receiver,
                    )
                    .await;
                });
                (
                    Some(SnowflakeJniClient::new(client, snowflake_task_ctx)),
                    Some(periodic_task_handle),
                    Some(shutdown_sender),
                )
            } else {
                (None, None, None)
            };

        Ok(Self {
            client,
            _periodic_task_handle: periodic_task_handle,
            shutdown_sender,
        })
    }

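    /// Background loop spawned by the commit coordinator: every
    /// `write.intermediate.interval.seconds` it refreshes the Snowpipe so that staged S3 files
    /// are copied into Snowflake, until a shutdown signal arrives or the channel is closed.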
    async fn run_periodic_query_task(
        client: SnowflakeJniClient,
        write_intermediate_interval_seconds: u64,
        sink_id: SinkId,
        mut shutdown_receiver: tokio::sync::mpsc::UnboundedReceiver<()>,
    ) {
        let mut copy_timer = interval(Duration::from_secs(write_intermediate_interval_seconds));
        copy_timer.set_missed_tick_behavior(MissedTickBehavior::Skip);
        loop {
            tokio::select! {
                _ = shutdown_receiver.recv() => break,
                _ = copy_timer.tick() => {
                    if let Err(e) = async {
                        client.execute_flush_pipe().await?;
                        Ok::<(), SinkError>(())
                    }.await {
                        tracing::error!("Failed to refresh Snowflake pipe for sink id {}: {}", sink_id, e.as_report());
                    }
                }
            }
        }
        tracing::info!("Periodic query task stopped for sink id {}", sink_id);
    }
}

#[async_trait]
impl SinglePhaseCommitCoordinator for SnowflakeSinkCommitter {
    async fn init(&mut self) -> Result<()> {
        if let Some(client) = &self.client {
            // TODO: move this to `validate`
            client.execute_create_pipe().await?;
            client.execute_create_merge_into_task().await?;
        }
        Ok(())
    }

    async fn commit_data(&mut self, _epoch: u64, _metadata: Vec<SinkMetadata>) -> Result<()> {
        Ok(())
    }

    async fn commit_schema_change(
        &mut self,
        _epoch: u64,
        schema_change: PbSinkSchemaChange,
    ) -> Result<()> {
        use risingwave_pb::stream_plan::sink_schema_change::PbOp as SinkSchemaChangeOp;
        let schema_change_op = schema_change
            .op
            .ok_or_else(|| SinkError::Coordinator(anyhow!("Invalid schema change operation")))?;
        let SinkSchemaChangeOp::AddColumns(add_columns) = schema_change_op else {
            return Err(SinkError::Coordinator(anyhow!(
                "Only AddColumns schema change is supported for Snowflake sink"
            )));
        };
        let client = self.client.as_mut().ok_or_else(|| {
            SinkError::Config(anyhow!("Snowflake sink committer is not initialized."))
        })?;
        client
            .execute_alter_add_columns(
                &add_columns
                    .fields
                    .into_iter()
                    .map(|f| (f.name, DataType::from(f.data_type.unwrap()).to_string()))
                    .collect_vec(),
            )
            .await
    }
}

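// Best-effort cleanup on drop: stop the periodic flush loop and drop the Snowflake merge task.
// Failures are only logged, since `drop` cannot be async or return an error.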
impl Drop for SnowflakeSinkCommitter {
    fn drop(&mut self) {
        if let Some(client) = self.client.take() {
            if let Some(sender) = self.shutdown_sender.take() {
                let _ = sender.send(()); // Ignore the result, as the receiver may have been dropped.
            }
            tokio::spawn(async move {
                client.execute_drop_task().await.ok();
            });
        }
    }
}

pub struct SnowflakeJniClient {
    jdbc_client: JdbcJniClient,
    snowflake_task_context: SnowflakeTaskContext,
}

impl SnowflakeJniClient {
    pub fn new(jdbc_client: JdbcJniClient, snowflake_task_context: SnowflakeTaskContext) -> Self {
        Self {
            jdbc_client,
            snowflake_task_context,
        }
    }

    pub async fn execute_alter_add_columns(
        &mut self,
        columns: &Vec<(String, String)>,
    ) -> Result<()> {
        self.execute_drop_task().await?;
        if let Some(names) = self.snowflake_task_context.all_column_names.as_mut() {
            names.extend(columns.iter().map(|(name, _)| name.clone()));
        }
        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
            let alter_add_column_cdc_table_sql = build_alter_add_column_sql(
                cdc_table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                columns,
            );
            self.jdbc_client
                .execute_sql_sync(vec![alter_add_column_cdc_table_sql])
                .await?;
        }

        let alter_add_column_target_table_sql = build_alter_add_column_sql(
            &self.snowflake_task_context.target_table_name,
            &self.snowflake_task_context.database,
            &self.snowflake_task_context.schema_name,
            columns,
        );
        self.jdbc_client
            .execute_sql_sync(vec![alter_add_column_target_table_sql])
            .await?;

        self.execute_create_merge_into_task().await?;
        Ok(())
    }

    pub async fn execute_create_merge_into_task(&self) -> Result<()> {
        if self.snowflake_task_context.task_name.is_some() {
            let create_task_sql = build_create_merge_into_task_sql(&self.snowflake_task_context);
            let start_task_sql = build_start_task_sql(&self.snowflake_task_context);
            self.jdbc_client
                .execute_sql_sync(vec![create_task_sql])
                .await?;
            self.jdbc_client
                .execute_sql_sync(vec![start_task_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_drop_task(&self) -> Result<()> {
        if self.snowflake_task_context.task_name.is_some() {
            let sql = build_drop_task_sql(&self.snowflake_task_context);
            if let Err(e) = self.jdbc_client.execute_sql_sync(vec![sql]).await {
                tracing::error!(
                    "Failed to drop Snowflake sink task {:?}: {:?}",
                    self.snowflake_task_context.task_name,
                    e.as_report()
                );
            } else {
                tracing::info!(
                    "Snowflake sink task {:?} dropped",
                    self.snowflake_task_context.task_name
                );
            }
        }
        Ok(())
    }

    pub async fn execute_create_table(&self) -> Result<()> {
        // create target table
        let create_target_table_sql = build_create_table_sql(
            &self.snowflake_task_context.target_table_name,
            &self.snowflake_task_context.database,
            &self.snowflake_task_context.schema_name,
            &self.snowflake_task_context.schema,
            false,
        )?;
        self.jdbc_client
            .execute_sql_sync(vec![create_target_table_sql])
            .await?;
        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
            let create_cdc_table_sql = build_create_table_sql(
                cdc_table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                &self.snowflake_task_context.schema,
                true,
            )?;
            self.jdbc_client
                .execute_sql_sync(vec![create_cdc_table_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_create_pipe(&self) -> Result<()> {
        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
            let table_name =
                if let Some(table_name) = self.snowflake_task_context.cdc_table_name.as_ref() {
                    table_name
                } else {
                    &self.snowflake_task_context.target_table_name
                };
            let create_pipe_sql = build_create_pipe_sql(
                table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                self.snowflake_task_context.stage.as_ref().ok_or_else(|| {
                    SinkError::Config(anyhow!("stage is required for the S3 writer"))
                })?,
                pipe_name,
                &self.snowflake_task_context.target_table_name,
            );
            self.jdbc_client
                .execute_sql_sync(vec![create_pipe_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_flush_pipe(&self) -> Result<()> {
        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
            let flush_pipe_sql = build_flush_pipe_sql(
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                pipe_name,
            );
            self.jdbc_client
                .execute_sql_sync(vec![flush_pipe_sql])
                .await?;
        }
        Ok(())
    }
}

fn build_create_table_sql(
    table_name: &str,
    database: &str,
    schema_name: &str,
    schema: &Schema,
    need_op_and_row_id: bool,
) -> Result<String> {
    let full_table_name = build_full_table_name(database, schema_name, table_name);
    let mut columns: Vec<String> = schema
        .fields
        .iter()
        .map(|field| {
            let data_type = convert_snowflake_data_type(&field.data_type)?;
            Ok(format!(r#""{}" {}"#, field.name, data_type))
        })
        .collect::<Result<Vec<String>>>()?;
    if need_op_and_row_id {
        columns.push(format!(r#""{}" STRING"#, __ROW_ID));
        columns.push(format!(r#""{}" INT"#, __OP));
    }
    let columns_str = columns.join(", ");
    Ok(format!(
        "CREATE TABLE IF NOT EXISTS {} ({}) ENABLE_SCHEMA_EVOLUTION = true",
        full_table_name, columns_str
    ))
}

fn convert_snowflake_data_type(data_type: &DataType) -> Result<String> {
    let data_type = match data_type {
        DataType::Int16 => "SMALLINT".to_owned(),
        DataType::Int32 => "INTEGER".to_owned(),
        DataType::Int64 => "BIGINT".to_owned(),
        DataType::Float32 => "FLOAT4".to_owned(),
        DataType::Float64 => "FLOAT8".to_owned(),
        DataType::Boolean => "BOOLEAN".to_owned(),
        DataType::Varchar => "STRING".to_owned(),
        DataType::Date => "DATE".to_owned(),
        DataType::Timestamp => "TIMESTAMP".to_owned(),
        DataType::Timestamptz => "TIMESTAMP_TZ".to_owned(),
        DataType::Jsonb => "STRING".to_owned(),
        DataType::Decimal => "DECIMAL".to_owned(),
        DataType::Bytea => "BINARY".to_owned(),
        DataType::Time => "TIME".to_owned(),
        _ => {
            return Err(SinkError::Config(anyhow!(
                "Auto table creation does not support data type: {}",
                data_type
            )));
        }
    };
    Ok(data_type)
}

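/// Builds a Snowpipe that copies JSON files staged under `<stage>/<target_table_name>` into the
/// given table (the intermediate table for upsert sinks, otherwise the target table).
/// `AUTO_INGEST` is disabled, so ingestion is triggered explicitly via `ALTER PIPE ... REFRESH`
/// (see `build_flush_pipe_sql`).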
fn build_create_pipe_sql(
    table_name: &str,
    database: &str,
    schema: &str,
    stage: &str,
    pipe_name: &str,
    target_table_name: &str,
) -> String {
    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    let stage = format!(
        r#""{}"."{}"."{}"/{}"#,
        database, schema, stage, target_table_name
    );
    let table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
    format!(
        "CREATE OR REPLACE PIPE {} AUTO_INGEST = FALSE AS COPY INTO {} FROM @{} MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE FILE_FORMAT = (type = 'JSON');",
        pipe_name, table_name, stage
    )
}

fn build_flush_pipe_sql(database: &str, schema: &str, pipe_name: &str) -> String {
    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    format!("ALTER PIPE {} REFRESH;", pipe_name)
}

fn build_alter_add_column_sql(
    table_name: &str,
    database: &str,
    schema: &str,
    columns: &Vec<(String, String)>,
) -> String {
    let full_table_name = build_full_table_name(database, schema, table_name);
    jdbc_jni_client::build_alter_add_column_sql(&full_table_name, columns, true)
}

fn build_start_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        database,
        schema_name: schema,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema,
        task_name.as_ref().unwrap()
    );
    format!("ALTER TASK {} RESUME", full_task_name)
}

fn build_drop_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        database,
        schema_name: schema,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema,
        task_name.as_ref().unwrap()
    );
    format!("DROP TASK IF EXISTS {}", full_task_name)
}

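/// Builds the scheduled Snowflake task that merges the intermediate (CDC) table into the target
/// table: rows are deduplicated per primary key by the largest `__row_id`, delete-like ops
/// (2, 4) remove the matched target row, insert-like ops (1, 3) upsert it, and the consumed
/// intermediate rows are purged at the end.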
fn build_create_merge_into_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        cdc_table_name,
        target_table_name,
        writer_target_interval_seconds,
        warehouse,
        pk_column_names,
        all_column_names,
        database,
        schema_name,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema_name,
        task_name.as_ref().unwrap()
    );
    let full_cdc_table_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema_name,
        cdc_table_name.as_ref().unwrap()
    );
    let full_target_table_name = format!(
        r#""{}"."{}"."{}""#,
        database, schema_name, target_table_name
    );

    let pk_names_str = pk_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#""{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");
    let pk_names_eq_str = pk_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
        .collect::<Vec<String>>()
        .join(" AND ");
    let all_column_names_set_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
        .collect::<Vec<String>>()
        .join(", ");
    let all_column_names_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#""{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");
    let all_column_names_insert_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"source."{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");

    format!(
        r#"CREATE OR REPLACE TASK {task_name}
WAREHOUSE = {warehouse}
SCHEDULE = '{writer_target_interval_seconds} SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("{snowflake_sink_row_id}"), '0') INTO :max_row_id
    FROM {cdc_table_name};

    MERGE INTO {target_table_name} AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY {pk_names_str} ORDER BY "{snowflake_sink_row_id}" DESC) AS dedupe_id
            FROM {cdc_table_name}
            WHERE "{snowflake_sink_row_id}" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON {pk_names_eq_str}
    WHEN MATCHED AND source."{snowflake_sink_op}" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN UPDATE SET {all_column_names_set_str}
    WHEN NOT MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN INSERT ({all_column_names_str}) VALUES ({all_column_names_insert_str});

    DELETE FROM {cdc_table_name}
    WHERE "{snowflake_sink_row_id}" <= :max_row_id;
END;"#,
        task_name = full_task_name,
        warehouse = warehouse.as_ref().unwrap(),
        writer_target_interval_seconds = writer_target_interval_seconds,
        cdc_table_name = full_cdc_table_name,
        target_table_name = full_target_table_name,
        pk_names_str = pk_names_str,
        pk_names_eq_str = pk_names_eq_str,
        all_column_names_set_str = all_column_names_set_str,
        all_column_names_str = all_column_names_str,
        all_column_names_insert_str = all_column_names_insert_str,
        snowflake_sink_row_id = __ROW_ID,
        snowflake_sink_op = __OP,
    )
}

#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::*;
    use crate::sink::jdbc_jni_client::normalize_sql;

    fn base_properties() -> BTreeMap<String, String> {
        BTreeMap::from([
            ("type".to_owned(), "append-only".to_owned()),
            ("jdbc.url".to_owned(), "jdbc:snowflake://account".to_owned()),
            ("username".to_owned(), "RW_USER".to_owned()),
        ])
    }

    #[test]
    fn test_build_jdbc_props_password() {
        let mut props = base_properties();
        props.insert("password".to_owned(), "secret".to_owned());
        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
        assert_eq!(url, "jdbc:snowflake://account");
        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
        assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
        assert_eq!(map.get("password"), Some(&"secret".to_owned()));
        assert!(!map.contains_key("authenticator"));
    }

    #[test]
    fn test_build_jdbc_props_key_pair_file() {
        let mut props = base_properties();
        props.insert(
            "auth.method".to_owned(),
            AUTH_METHOD_KEY_PAIR_FILE.to_owned(),
        );
        props.insert("private_key_file".to_owned(), "/tmp/rsa_key.p8".to_owned());
        props.insert("private_key_file_pwd".to_owned(), "dummy".to_owned());
        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
        assert_eq!(url, "jdbc:snowflake://account");
        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
        assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
        assert_eq!(
            map.get("private_key_file"),
            Some(&"/tmp/rsa_key.p8".to_owned())
        );
        assert_eq!(map.get("private_key_file_pwd"), Some(&"dummy".to_owned()));
    }

    #[test]
    fn test_build_jdbc_props_key_pair_object() {
        let mut props = base_properties();
        props.insert(
            "auth.method".to_owned(),
            AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
        );
        props.insert(
            "private_key_pem".to_owned(),
            "-----BEGIN PRIVATE KEY-----
...
-----END PRIVATE KEY-----"
                .to_owned(),
        );
        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
        assert_eq!(url, "jdbc:snowflake://account");
        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
        assert_eq!(
            map.get("private_key_pem"),
            Some(
                &"-----BEGIN PRIVATE KEY-----
...
-----END PRIVATE KEY-----"
                    .to_owned()
            )
        );
        assert!(!map.contains_key("private_key_file"));
    }
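
    // A small additional check (sketch): supplying more than one credential without an explicit
    // `auth.method` is ambiguous and should be rejected by `from_btreemap`.
    #[test]
    fn test_ambiguous_auth_rejected() {
        let mut props = base_properties();
        props.insert("password".to_owned(), "secret".to_owned());
        props.insert("private_key_file".to_owned(), "/tmp/rsa_key.p8".to_owned());
        assert!(SnowflakeV2Config::from_btreemap(&props).is_err());
    }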

    #[test]
    fn test_snowflake_sink_commit_coordinator() {
        let snowflake_task_context = SnowflakeTaskContext {
            task_name: Some("test_task".to_owned()),
            cdc_table_name: Some("test_cdc_table".to_owned()),
            target_table_name: "test_target_table".to_owned(),
            writer_target_interval_seconds: 3600,
            warehouse: Some("test_warehouse".to_owned()),
            pk_column_names: Some(vec!["v1".to_owned()]),
            all_column_names: Some(vec!["v1".to_owned(), "v2".to_owned()]),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            schema: Schema { fields: vec![] },
            stage: None,
            pipe_name: None,
        };
        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task"
WAREHOUSE = test_warehouse
SCHEDULE = '3600 SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
    FROM "test_db"."test_schema"."test_cdc_table";

    MERGE INTO "test_db"."test_schema"."test_target_table" AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY "v1" ORDER BY "__row_id" DESC) AS dedupe_id
            FROM "test_db"."test_schema"."test_cdc_table"
            WHERE "__row_id" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON target."v1" = source."v1"
    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."v1" = source."v1", target."v2" = source."v2"
    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("v1", "v2") VALUES (source."v1", source."v2");

    DELETE FROM "test_db"."test_schema"."test_cdc_table"
    WHERE "__row_id" <= :max_row_id;
END;"#;
        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
    }

    #[test]
    fn test_snowflake_sink_commit_coordinator_multi_pk() {
        let snowflake_task_context = SnowflakeTaskContext {
            task_name: Some("test_task_multi_pk".to_owned()),
            cdc_table_name: Some("cdc_multi_pk".to_owned()),
            target_table_name: "target_multi_pk".to_owned(),
            writer_target_interval_seconds: 300,
            warehouse: Some("multi_pk_warehouse".to_owned()),
            pk_column_names: Some(vec!["id1".to_owned(), "id2".to_owned()]),
            all_column_names: Some(vec!["id1".to_owned(), "id2".to_owned(), "val".to_owned()]),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            schema: Schema { fields: vec![] },
            stage: None,
            pipe_name: None,
        };
        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task_multi_pk"
WAREHOUSE = multi_pk_warehouse
SCHEDULE = '300 SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
    FROM "test_db"."test_schema"."cdc_multi_pk";

    MERGE INTO "test_db"."test_schema"."target_multi_pk" AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY "id1", "id2" ORDER BY "__row_id" DESC) AS dedupe_id
            FROM "test_db"."test_schema"."cdc_multi_pk"
            WHERE "__row_id" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON target."id1" = source."id1" AND target."id2" = source."id2"
    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."id1" = source."id1", target."id2" = source."id2", target."val" = source."val"
    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("id1", "id2", "val") VALUES (source."id1", source."id2", source."val");

    DELETE FROM "test_db"."test_schema"."cdc_multi_pk"
    WHERE "__row_id" <= :max_row_id;
END;"#;
        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
    }
}