risingwave_connector/sink/snowflake_redshift/snowflake.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use core::num::NonZeroU64;
16use std::collections::BTreeMap;
17
18use anyhow::anyhow;
19use phf::{Set, phf_set};
20use risingwave_common::array::StreamChunk;
21use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema};
22use risingwave_common::types::DataType;
23use risingwave_pb::connector_service::{SinkMetadata, sink_metadata};
24use sea_orm::DatabaseConnection;
25use serde::Deserialize;
26use serde_with::{DisplayFromStr, serde_as};
27use thiserror_ext::AsReport;
28use tokio::sync::mpsc::UnboundedSender;
29use tonic::async_trait;
30use with_options::WithOptions;
31
32use crate::connector_common::IcebergSinkCompactionUpdate;
33use crate::enforce_secret::EnforceSecret;
34use crate::sink::coordinate::CoordinatedLogSinker;
35use crate::sink::decouple_checkpoint_log_sink::default_commit_checkpoint_interval;
36use crate::sink::file_sink::s3::S3Common;
37use crate::sink::jdbc_jni_client::{self, JdbcJniClient};
38use crate::sink::remote::CoordinatedRemoteSinkWriter;
39use crate::sink::snowflake_redshift::{AugmentedChunk, SnowflakeRedshiftSinkS3Writer};
40use crate::sink::writer::SinkWriter;
41use crate::sink::{
42    Result, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, Sink, SinkCommitCoordinator,
43    SinkCommittedEpochSubscriber, SinkError, SinkParam, SinkWriterMetrics, SinkWriterParam,
44};
45
46pub const SNOWFLAKE_SINK_V2: &str = "snowflake_v2";
47pub const SNOWFLAKE_SINK_ROW_ID: &str = "__row_id";
48pub const SNOWFLAKE_SINK_OP: &str = "__op";
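// Note (added for clarity; assumed, not stated in the original source): the `__op` column written
// to the intermediate table is taken to follow RisingWave's change-operation encoding
// (1 = insert, 2 = delete, 3 = update-insert, 4 = update-delete), which is what the
// `IN (1, 3)` / `IN (2, 4)` filters in the generated MERGE task below rely on.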
49
50const AUTH_METHOD_PASSWORD: &str = "password";
51const AUTH_METHOD_KEY_PAIR_FILE: &str = "key_pair_file";
52const AUTH_METHOD_KEY_PAIR_OBJECT: &str = "key_pair_object";
53const PROP_AUTH_METHOD: &str = "auth.method";
54
55#[serde_as]
56#[derive(Debug, Clone, Deserialize, WithOptions)]
57pub struct SnowflakeV2Config {
58    #[serde(rename = "type")]
59    pub r#type: String,
60
61    #[serde(rename = "intermediate.table.name")]
62    pub snowflake_cdc_table_name: Option<String>,
63
64    #[serde(rename = "table.name")]
65    pub snowflake_target_table_name: Option<String>,
66
67    #[serde(rename = "database")]
68    pub snowflake_database: Option<String>,
69
70    #[serde(rename = "schema")]
71    pub snowflake_schema: Option<String>,
72
73    #[serde(default = "default_schedule")]
74    #[serde(rename = "write.target.interval.seconds")]
75    #[serde_as(as = "DisplayFromStr")]
76    pub snowflake_schedule_seconds: u64,
77
78    #[serde(rename = "warehouse")]
79    pub snowflake_warehouse: Option<String>,
80
81    #[serde(rename = "jdbc.url")]
82    pub jdbc_url: Option<String>,
83
84    #[serde(rename = "username")]
85    pub username: Option<String>,
86
87    #[serde(rename = "password")]
88    pub password: Option<String>,
89
90    // Authentication method control (password | key_pair_file | key_pair_object)
91    #[serde(rename = "auth.method")]
92    pub auth_method: Option<String>,
93
94    // Key-pair authentication via connection properties (file-based: path to the private key file)
95    #[serde(rename = "private_key_file")]
96    pub private_key_file: Option<String>,
97
98    #[serde(rename = "private_key_file_pwd")]
99    pub private_key_file_pwd: Option<String>,
100
101    // Key-pair authentication via connection properties (object-based: PEM content of the private key)
102    #[serde(rename = "private_key_pem")]
103    pub private_key_pem: Option<String>,
104
105    /// Commit every n (> 0) checkpoints; the default is 10.
106    #[serde(default = "default_commit_checkpoint_interval")]
107    #[serde_as(as = "DisplayFromStr")]
108    #[with_option(allow_alter_on_fly)]
109    pub commit_checkpoint_interval: u64,
110
111    /// Enable automatic schema change for the upsert sink.
112    /// If enabled, the sink will automatically alter the target table to add new columns.
113    #[serde(default)]
114    #[serde(rename = "auto.schema.change")]
115    #[serde_as(as = "DisplayFromStr")]
116    pub auto_schema_change: bool,
117
118    #[serde(default)]
119    #[serde(rename = "create_table_if_not_exists")]
120    #[serde_as(as = "DisplayFromStr")]
121    pub create_table_if_not_exists: bool,
122
123    #[serde(default = "default_with_s3")]
124    #[serde(rename = "with_s3")]
125    #[serde_as(as = "DisplayFromStr")]
126    pub with_s3: bool,
127
128    #[serde(flatten)]
129    pub s3_inner: Option<S3Common>,
130
131    #[serde(rename = "stage")]
132    pub stage: Option<String>,
133}
134
135fn default_schedule() -> u64 {
136    3600 // Default to 1 hour
137}
138
139fn default_with_s3() -> bool {
140    true
141}
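// Illustrative configuration (not taken from the original source; all values are placeholders):
// a hypothetical upsert sink using password authentication and the S3/Snowpipe path. Property
// names mirror the serde renames above; the flattened `S3Common` options are omitted here.
//
//     CREATE SINK snowflake_sink FROM mv WITH (
//         connector = 'snowflake_v2',
//         type = 'upsert',
//         primary_key = 'id',
//         jdbc.url = 'jdbc:snowflake://<account>.snowflakecomputing.com',
//         username = 'RW_USER',
//         password = '...',
//         database = 'MY_DB',
//         schema = 'PUBLIC',
//         table.name = 'TARGET_TABLE',
//         intermediate.table.name = 'TARGET_TABLE_CDC',
//         warehouse = 'MY_WH',
//         stage = 'MY_STAGE',
//         with_s3 = 'true'
//     );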
142
143impl SnowflakeV2Config {
144    /// Build the JDBC properties for the Snowflake JDBC connection (no URL parameters).
145    /// Returns (`jdbc_url`, `driver_properties`).
146    /// - `driver_properties` are consumed by the Java runner and passed to `DriverManager::getConnection(url, props)`.
147    ///
148    /// Note: This method assumes the config has been validated by `from_btreemap`.
149    pub fn build_jdbc_connection_properties(&self) -> Result<(String, Vec<(String, String)>)> {
150        let jdbc_url = self
151            .jdbc_url
152            .clone()
153            .ok_or(SinkError::Config(anyhow!("jdbc.url is required")))?;
154        let username = self
155            .username
156            .clone()
157            .ok_or(SinkError::Config(anyhow!("username is required")))?;
158
159        let mut connection_properties: Vec<(String, String)> = vec![("user".to_owned(), username)];
160
161        // auth_method is guaranteed to be Some after validation in from_btreemap
162        match self.auth_method.as_deref().unwrap() {
163            AUTH_METHOD_PASSWORD => {
164                // password is guaranteed to exist by from_btreemap validation
165                connection_properties.push(("password".to_owned(), self.password.clone().unwrap()));
166            }
167            AUTH_METHOD_KEY_PAIR_FILE => {
168                // private_key_file is guaranteed to exist by from_btreemap validation
169                connection_properties.push((
170                    "private_key_file".to_owned(),
171                    self.private_key_file.clone().unwrap(),
172                ));
173                if let Some(pwd) = self.private_key_file_pwd.clone() {
174                    connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
175                }
176            }
177            AUTH_METHOD_KEY_PAIR_OBJECT => {
178                connection_properties.push((
179                    PROP_AUTH_METHOD.to_owned(),
180                    AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
181                ));
182                // private_key_pem is guaranteed to exist by from_btreemap validation
183                connection_properties.push((
184                    "private_key_pem".to_owned(),
185                    self.private_key_pem.clone().unwrap(),
186                ));
187                if let Some(pwd) = self.private_key_file_pwd.clone() {
188                    connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
189                }
190            }
191            _ => {
192                // This should never happen since from_btreemap validates auth_method
193                unreachable!(
194                    "Invalid auth_method - should have been caught during config validation"
195                )
196            }
197        }
198
199        Ok((jdbc_url, connection_properties))
200    }
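    // For reference (values hypothetical, derived from the logic above): with
    // `jdbc.url = "jdbc:snowflake://account"`, `username = "RW_USER"`, and password auth, this
    // returns ("jdbc:snowflake://account", [("user", "RW_USER"), ("password", "...")]). With
    // key-pair (object) auth the vector instead carries ("auth.method", "key_pair_object") and
    // ("private_key_pem", "<PEM content>"), plus the optional key passphrase.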
201
202    pub fn from_btreemap(properties: &BTreeMap<String, String>) -> Result<Self> {
203        let mut config =
204            serde_json::from_value::<SnowflakeV2Config>(serde_json::to_value(properties).unwrap())
205                .map_err(|e| SinkError::Config(anyhow!(e)))?;
206        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
207            return Err(SinkError::Config(anyhow!(
208                "`{}` must be {} or {}",
209                SINK_TYPE_OPTION,
210                SINK_TYPE_APPEND_ONLY,
211                SINK_TYPE_UPSERT
212            )));
213        }
214
215        // Normalize and validate authentication method
216        let has_password = config.password.is_some();
217        let has_file = config.private_key_file.is_some();
218        let has_pem = config.private_key_pem.as_deref().is_some();
219
220        let normalized_auth_method = match config
221            .auth_method
222            .as_deref()
223            .map(|s| s.trim().to_ascii_lowercase())
224        {
225            Some(method) if method == AUTH_METHOD_PASSWORD => {
226                if !has_password {
227                    return Err(SinkError::Config(anyhow!(
228                        "auth.method=password requires `password`"
229                    )));
230                }
231                if has_file || has_pem {
232                    return Err(SinkError::Config(anyhow!(
233                        "auth.method=password must not set `private_key_file`/`private_key_pem`"
234                    )));
235                }
236                AUTH_METHOD_PASSWORD.to_owned()
237            }
238            Some(method) if method == AUTH_METHOD_KEY_PAIR_FILE => {
239                if !has_file {
240                    return Err(SinkError::Config(anyhow!(
241                        "auth.method=key_pair_file requires `private_key_file`"
242                    )));
243                }
244                if has_password {
245                    return Err(SinkError::Config(anyhow!(
246                        "auth.method=key_pair_file must not set `password`"
247                    )));
248                }
249                if has_pem {
250                    return Err(SinkError::Config(anyhow!(
251                        "auth.method=key_pair_file must not set `private_key_pem`"
252                    )));
253                }
254                AUTH_METHOD_KEY_PAIR_FILE.to_owned()
255            }
256            Some(method) if method == AUTH_METHOD_KEY_PAIR_OBJECT => {
257                if !has_pem {
258                    return Err(SinkError::Config(anyhow!(
259                        "auth.method=key_pair_object requires `private_key_pem`"
260                    )));
261                }
262                if has_password {
263                    return Err(SinkError::Config(anyhow!(
264                        "auth.method=key_pair_object must not set `password`"
265                    )));
266                }
267                AUTH_METHOD_KEY_PAIR_OBJECT.to_owned()
268            }
269            Some(other) => {
270                return Err(SinkError::Config(anyhow!(
271                    "invalid auth.method: {} (allowed: password | key_pair_file | key_pair_object)",
272                    other
273                )));
274            }
275            None => {
276                // Infer auth method from supplied fields
277                match (has_password, has_file, has_pem) {
278                    (true, false, false) => AUTH_METHOD_PASSWORD.to_owned(),
279                    (false, true, false) => AUTH_METHOD_KEY_PAIR_FILE.to_owned(),
280                    (false, false, true) => AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
281                    (true, true, _) | (true, _, true) | (false, true, true) => {
282                        return Err(SinkError::Config(anyhow!(
283                            "ambiguous auth: multiple auth options provided; remove one or set `auth.method`"
284                        )));
285                    }
286                    _ => {
287                        return Err(SinkError::Config(anyhow!(
288                            "no authentication configured: set one of `password`, `private_key_file`, or `private_key_pem` (or set `auth.method`)"
289                        )));
290                    }
291                }
292            }
293        };
294        config.auth_method = Some(normalized_auth_method);
295        Ok(config)
296    }
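    // Example of the inference above: a property map that supplies only `password` resolves to
    // auth.method = "password"; supplying both `password` and `private_key_pem` without an
    // explicit `auth.method` is rejected as ambiguous (see `test_auth_method_inference` in the
    // tests below, a sketch added for illustration).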
297
298    pub fn build_snowflake_task_ctx_jdbc_client(
299        &self,
300        is_append_only: bool,
301        schema: &Schema,
302        pk_indices: &Vec<usize>,
303    ) -> Result<Option<(SnowflakeTaskContext, JdbcJniClient)>> {
304        if !self.auto_schema_change && is_append_only && !self.create_table_if_not_exists {
305            // Append-only without auto schema change or table creation does not need a client.
306            return Ok(None);
307        }
308        let target_table_name = self
309            .snowflake_target_table_name
310            .clone()
311            .ok_or(SinkError::Config(anyhow!("table.name is required")))?;
312        let database = self
313            .snowflake_database
314            .clone()
315            .ok_or(SinkError::Config(anyhow!("database is required")))?;
316        let schema_name = self
317            .snowflake_schema
318            .clone()
319            .ok_or(SinkError::Config(anyhow!("schema is required")))?;
320        let mut snowflake_task_ctx = SnowflakeTaskContext {
321            target_table_name: target_table_name.clone(),
322            database,
323            schema_name,
324            schema: schema.clone(),
325            ..Default::default()
326        };
327
328        let (jdbc_url, connection_properties) = self.build_jdbc_connection_properties()?;
329        let client = JdbcJniClient::new_with_props(jdbc_url, connection_properties)?;
330
331        if self.with_s3 {
332            let stage = self
333                .stage
334                .clone()
335                .ok_or(SinkError::Config(anyhow!("stage is required")))?;
336            snowflake_task_ctx.stage = Some(stage);
337            snowflake_task_ctx.pipe_name = Some(format!("{}_pipe", target_table_name));
338        }
339        if !is_append_only {
340            let cdc_table_name = self
341                .snowflake_cdc_table_name
342                .clone()
343                .ok_or(SinkError::Config(anyhow!(
344                    "intermediate.table.name is required"
345                )))?;
346            snowflake_task_ctx.cdc_table_name = Some(cdc_table_name.clone());
347            snowflake_task_ctx.schedule_seconds = self.snowflake_schedule_seconds;
348            snowflake_task_ctx.warehouse = Some(
349                self.snowflake_warehouse
350                    .clone()
351                    .ok_or(SinkError::Config(anyhow!("warehouse is required")))?,
352            );
353            let pk_column_names: Vec<_> = schema
354                .fields
355                .iter()
356                .enumerate()
357                .filter(|(index, _)| pk_indices.contains(index))
358                .map(|(_, field)| field.name.clone())
359                .collect();
360            if pk_column_names.is_empty() {
361                return Err(SinkError::Config(anyhow!(
362                    "Primary key columns not found. Please set the `primary_key` column in the sink properties, or ensure that the sink contains the primary key columns from the upstream."
363                )));
364            }
365            snowflake_task_ctx.pk_column_names = Some(pk_column_names);
366            snowflake_task_ctx.all_column_names = Some(
367                schema
368                    .fields
369                    .iter()
370                    .map(|field| field.name.clone())
371                    .collect(),
372            );
373            snowflake_task_ctx.task_name = Some(format!(
374                "rw_snowflake_sink_from_{cdc_table_name}_to_{target_table_name}"
375            ));
376        }
377        Ok(Some((snowflake_task_ctx, client)))
378    }
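    // Note (derived from the branches above): for upsert sinks the context also carries the CDC
    // table, warehouse, primary-key/all column names, and a task named
    // `rw_snowflake_sink_from_{cdc_table}_to_{target_table}`; when `with_s3` is set it carries
    // the stage and a pipe named `{target_table}_pipe`.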
379}
380
381impl EnforceSecret for SnowflakeV2Config {
382    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
383        "username",
384        "password",
385        "jdbc.url",
386        // Key-pair authentication secrets
387        "private_key_file_pwd",
388        "private_key_pem",
389    };
390}
391
392#[derive(Clone, Debug)]
393pub struct SnowflakeV2Sink {
394    config: SnowflakeV2Config,
395    schema: Schema,
396    pk_indices: Vec<usize>,
397    is_append_only: bool,
398    param: SinkParam,
399}
400
401impl EnforceSecret for SnowflakeV2Sink {
402    fn enforce_secret<'a>(
403        prop_iter: impl Iterator<Item = &'a str>,
404    ) -> crate::sink::ConnectorResult<()> {
405        for prop in prop_iter {
406            SnowflakeV2Config::enforce_one(prop)?;
407        }
408        Ok(())
409    }
410}
411
412impl TryFrom<SinkParam> for SnowflakeV2Sink {
413    type Error = SinkError;
414
415    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
416        let schema = param.schema();
417        let config = SnowflakeV2Config::from_btreemap(&param.properties)?;
418        let is_append_only = param.sink_type.is_append_only();
419        let pk_indices = param.downstream_pk_or_empty();
420        Ok(Self {
421            config,
422            schema,
423            pk_indices,
424            is_append_only,
425            param,
426        })
427    }
428}
429
430impl Sink for SnowflakeV2Sink {
431    type Coordinator = SnowflakeSinkCommitter;
432    type LogSinker = CoordinatedLogSinker<SnowflakeSinkWriter>;
433
434    const SINK_NAME: &'static str = SNOWFLAKE_SINK_V2;
435
436    async fn validate(&self) -> Result<()> {
437        risingwave_common::license::Feature::SnowflakeSink
438            .check_available()
439            .map_err(|e| anyhow::anyhow!(e))?;
440        if let Some((snowflake_task_ctx, client)) =
441            self.config.build_snowflake_task_ctx_jdbc_client(
442                self.is_append_only,
443                &self.schema,
444                &self.pk_indices,
445            )?
446        {
447            let client = SnowflakeJniClient::new(client, snowflake_task_ctx);
448            client.execute_create_table().await?;
449        }
450
451        Ok(())
452    }
453
454    fn support_schema_change() -> bool {
455        true
456    }
457
458    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
459        SnowflakeV2Config::from_btreemap(config)?;
460        Ok(())
461    }
462
463    async fn new_log_sinker(
464        &self,
465        writer_param: crate::sink::SinkWriterParam,
466    ) -> Result<Self::LogSinker> {
467        let writer = SnowflakeSinkWriter::new(
468            self.config.clone(),
469            self.is_append_only,
470            writer_param.clone(),
471            self.param.clone(),
472        )
473        .await?;
474
475        let commit_checkpoint_interval =
476            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
477                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
478            );
479
480        CoordinatedLogSinker::new(
481            &writer_param,
482            self.param.clone(),
483            writer,
484            commit_checkpoint_interval,
485        )
486        .await
487    }
488
489    fn is_coordinated_sink(&self) -> bool {
490        true
491    }
492
493    async fn new_coordinator(
494        &self,
495        _db: DatabaseConnection,
496        _iceberg_compact_stat_sender: Option<UnboundedSender<IcebergSinkCompactionUpdate>>,
497    ) -> Result<Self::Coordinator> {
498        let coordinator = SnowflakeSinkCommitter::new(
499            self.config.clone(),
500            &self.schema,
501            &self.pk_indices,
502            self.is_append_only,
503        )?;
504        Ok(coordinator)
505    }
506}
507
508pub enum SnowflakeSinkWriter {
509    S3(SnowflakeRedshiftSinkS3Writer),
510    Jdbc(SnowflakeSinkJdbcWriter),
511}
512
513impl SnowflakeSinkWriter {
514    pub async fn new(
515        config: SnowflakeV2Config,
516        is_append_only: bool,
517        writer_param: SinkWriterParam,
518        param: SinkParam,
519    ) -> Result<Self> {
520        let schema = param.schema();
521        if config.with_s3 {
522            let executor_id = writer_param.executor_id;
523            let s3_writer = SnowflakeRedshiftSinkS3Writer::new(
524                config.s3_inner.ok_or_else(|| {
525                    SinkError::Config(anyhow!(
526                        "S3 configuration is required for Snowflake S3 sink"
527                    ))
528                })?,
529                schema,
530                is_append_only,
531                executor_id,
532                config.snowflake_target_table_name,
533            )?;
534            Ok(Self::S3(s3_writer))
535        } else {
536            let jdbc_writer =
537                SnowflakeSinkJdbcWriter::new(config, is_append_only, writer_param, param).await?;
538            Ok(Self::Jdbc(jdbc_writer))
539        }
540    }
541}
542
543#[async_trait]
544impl SinkWriter for SnowflakeSinkWriter {
545    type CommitMetadata = Option<SinkMetadata>;
546
547    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
548        match self {
549            Self::S3(writer) => writer.begin_epoch(epoch),
550            Self::Jdbc(writer) => writer.begin_epoch(epoch).await,
551        }
552    }
553
554    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
555        match self {
556            Self::S3(writer) => writer.write_batch(chunk).await,
557            Self::Jdbc(writer) => writer.write_batch(chunk).await,
558        }
559    }
560
561    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Option<SinkMetadata>> {
562        match self {
563            Self::S3(writer) => {
564                writer.barrier(is_checkpoint).await?;
565            }
566            Self::Jdbc(writer) => {
567                writer.barrier(is_checkpoint).await?;
568            }
569        }
570        Ok(Some(SinkMetadata {
571            metadata: Some(sink_metadata::Metadata::Serialized(
572                risingwave_pb::connector_service::sink_metadata::SerializedMetadata {
573                    metadata: vec![],
574                },
575            )),
576        }))
577    }
578
579    async fn abort(&mut self) -> Result<()> {
580        if let Self::Jdbc(writer) = self {
581            writer.abort().await
582        } else {
583            Ok(())
584        }
585    }
586}
587
588pub struct SnowflakeSinkJdbcWriter {
589    augmented_row: AugmentedChunk,
590    jdbc_sink_writer: CoordinatedRemoteSinkWriter,
591}
592
593impl SnowflakeSinkJdbcWriter {
594    pub async fn new(
595        config: SnowflakeV2Config,
596        is_append_only: bool,
597        writer_param: SinkWriterParam,
598        mut param: SinkParam,
599    ) -> Result<Self> {
600        let metrics = SinkWriterMetrics::new(&writer_param);
601        let properties = &param.properties;
602        let column_descs = &mut param.columns;
603        let full_table_name = if is_append_only {
604            format!(
605                r#""{}"."{}"."{}""#,
606                config.snowflake_database.clone().unwrap_or_default(),
607                config.snowflake_schema.clone().unwrap_or_default(),
608                config
609                    .snowflake_target_table_name
610                    .clone()
611                    .unwrap_or_default()
612            )
613        } else {
614            let max_column_id = column_descs
615                .iter()
616                .map(|column| column.column_id.get_id())
617                .max()
618                .unwrap_or(0);
619            column_descs.push(ColumnDesc::named(
620                SNOWFLAKE_SINK_ROW_ID,
621                ColumnId::new(max_column_id + 1),
622                DataType::Varchar,
623            ));
624            column_descs.push(ColumnDesc::named(
625                SNOWFLAKE_SINK_OP,
626                ColumnId::new(max_column_id + 2),
627                DataType::Int32,
628            ));
629            format!(
630                r#""{}"."{}"."{}""#,
631                config.snowflake_database.clone().unwrap_or_default(),
632                config.snowflake_schema.clone().unwrap_or_default(),
633                config.snowflake_cdc_table_name.clone().unwrap_or_default()
634            )
635        };
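        // The fully qualified name uses quoted identifiers, e.g. `"MY_DB"."PUBLIC"."T"`
        // (hypothetical names): the target table for append-only sinks, the CDC
        // (intermediate) table otherwise.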
636        let mut new_properties = BTreeMap::from([
637            ("table.name".to_owned(), full_table_name),
638            ("connector".to_owned(), "snowflake_v2".to_owned()),
639            (
640                "jdbc.url".to_owned(),
641                config.jdbc_url.clone().unwrap_or_default(),
642            ),
643            ("type".to_owned(), "append-only".to_owned()),
644            (
645                "primary_key".to_owned(),
646                properties.get("primary_key").cloned().unwrap_or_default(),
647            ),
648            (
649                "schema.name".to_owned(),
650                config.snowflake_schema.clone().unwrap_or_default(),
651            ),
652            (
653                "database.name".to_owned(),
654                config.snowflake_database.clone().unwrap_or_default(),
655            ),
656        ]);
657
658        // Reuse build_jdbc_connection_properties to get driver properties (auth, user, etc.)
659        let (_jdbc_url, connection_properties) = config.build_jdbc_connection_properties()?;
660        for (key, value) in connection_properties {
661            new_properties.insert(key, value);
662        }
663
664        param.properties = new_properties;
665
666        let jdbc_sink_writer =
667            CoordinatedRemoteSinkWriter::new(param.clone(), metrics.clone()).await?;
668        Ok(Self {
669            augmented_row: AugmentedChunk::new(0, is_append_only),
670            jdbc_sink_writer,
671        })
672    }
673}
674
675impl SnowflakeSinkJdbcWriter {
676    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
677        self.augmented_row.reset_epoch(epoch);
678        self.jdbc_sink_writer.begin_epoch(epoch).await?;
679        Ok(())
680    }
681
682    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
683        let chunk = self.augmented_row.augmented_chunk(chunk)?;
684        self.jdbc_sink_writer.write_batch(chunk).await?;
685        Ok(())
686    }
687
688    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
689        self.jdbc_sink_writer.barrier(is_checkpoint).await?;
690        Ok(())
691    }
692
693    async fn abort(&mut self) -> Result<()> {
694        // TODO: abort should clean up all the data written in this epoch.
695        self.jdbc_sink_writer.abort().await?;
696        Ok(())
697    }
698}
699
700#[derive(Default)]
701pub struct SnowflakeTaskContext {
702    // required for task creation
703    pub target_table_name: String,
704    pub database: String,
705    pub schema_name: String,
706    pub schema: Schema,
707
708    // only upsert
709    pub task_name: Option<String>,
710    pub cdc_table_name: Option<String>,
711    pub schedule_seconds: u64,
712    pub warehouse: Option<String>,
713    pub pk_column_names: Option<Vec<String>>,
714    pub all_column_names: Option<Vec<String>>,
715
716    // only s3 writer
717    pub stage: Option<String>,
718    pub pipe_name: Option<String>,
719}
720pub struct SnowflakeSinkCommitter {
721    client: Option<SnowflakeJniClient>,
722}
723
724impl SnowflakeSinkCommitter {
725    pub fn new(
726        config: SnowflakeV2Config,
727        schema: &Schema,
728        pk_indices: &Vec<usize>,
729        is_append_only: bool,
730    ) -> Result<Self> {
731        let client = if let Some((snowflake_task_ctx, client)) =
732            config.build_snowflake_task_ctx_jdbc_client(is_append_only, schema, pk_indices)?
733        {
734            Some(SnowflakeJniClient::new(client, snowflake_task_ctx))
735        } else {
736            None
737        };
738        Ok(Self { client })
739    }
740}
741
742#[async_trait]
743impl SinkCommitCoordinator for SnowflakeSinkCommitter {
744    async fn init(&mut self, _subscriber: SinkCommittedEpochSubscriber) -> Result<Option<u64>> {
745        if let Some(client) = &self.client {
746            // TODO: move this to `validate`.
747            client.execute_create_pipe().await?;
748            client.execute_create_merge_into_task().await?;
749        }
750        Ok(None)
751    }
752
753    async fn commit(
754        &mut self,
755        _epoch: u64,
756        _metadata: Vec<SinkMetadata>,
757        add_columns: Option<Vec<Field>>,
758    ) -> Result<()> {
759        let client = self.client.as_mut().ok_or_else(|| {
760            SinkError::Config(anyhow!("Snowflake sink committer is not initialized."))
761        })?;
762        client.execute_flush_pipe().await?;
763
764        if let Some(add_columns) = add_columns {
765            client
766                .execute_alter_add_columns(
767                    &add_columns
768                        .iter()
769                        .map(|f| (f.name.clone(), f.data_type.to_string()))
770                        .collect::<Vec<_>>(),
771                )
772                .await?;
773        }
774        Ok(())
775    }
776}
777
778impl Drop for SnowflakeSinkCommitter {
779    fn drop(&mut self) {
780        if let Some(client) = self.client.take() {
781            tokio::spawn(async move {
782                client.execute_drop_task().await.ok();
783            });
784        }
785    }
786}
787
788pub struct SnowflakeJniClient {
789    jdbc_client: JdbcJniClient,
790    snowflake_task_context: SnowflakeTaskContext,
791}
792
793impl SnowflakeJniClient {
794    pub fn new(jdbc_client: JdbcJniClient, snowflake_task_context: SnowflakeTaskContext) -> Self {
795        Self {
796            jdbc_client,
797            snowflake_task_context,
798        }
799    }
800
801    pub async fn execute_alter_add_columns(
802        &mut self,
803        columns: &Vec<(String, String)>,
804    ) -> Result<()> {
805        self.execute_drop_task().await?;
806        if let Some(names) = self.snowflake_task_context.all_column_names.as_mut() {
807            names.extend(columns.iter().map(|(name, _)| name.clone()));
808        }
809        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
810            let alter_add_column_cdc_table_sql = build_alter_add_column_sql(
811                cdc_table_name,
812                &self.snowflake_task_context.database,
813                &self.snowflake_task_context.schema_name,
814                columns,
815            );
816            self.jdbc_client
817                .execute_sql_sync(vec![alter_add_column_cdc_table_sql])
818                .await?;
819        }
820
821        let alter_add_column_target_table_sql = build_alter_add_column_sql(
822            &self.snowflake_task_context.target_table_name,
823            &self.snowflake_task_context.database,
824            &self.snowflake_task_context.schema_name,
825            columns,
826        );
827        self.jdbc_client
828            .execute_sql_sync(vec![alter_add_column_target_table_sql])
829            .await?;
830
831        self.execute_create_merge_into_task().await?;
832        Ok(())
833    }
834
835    pub async fn execute_create_merge_into_task(&self) -> Result<()> {
836        if self.snowflake_task_context.task_name.is_some() {
837            let create_task_sql = build_create_merge_into_task_sql(&self.snowflake_task_context);
838            let start_task_sql = build_start_task_sql(&self.snowflake_task_context);
839            self.jdbc_client
840                .execute_sql_sync(vec![create_task_sql])
841                .await?;
842            self.jdbc_client
843                .execute_sql_sync(vec![start_task_sql])
844                .await?;
845        }
846        Ok(())
847    }
848
849    pub async fn execute_drop_task(&self) -> Result<()> {
850        if self.snowflake_task_context.task_name.is_some() {
851            let sql = build_drop_task_sql(&self.snowflake_task_context);
852            if let Err(e) = self.jdbc_client.execute_sql_sync(vec![sql]).await {
853                tracing::error!(
854                    "Failed to drop Snowflake sink task {:?}: {:?}",
855                    self.snowflake_task_context.task_name,
856                    e.as_report()
857                );
858            } else {
859                tracing::info!(
860                    "Snowflake sink task {:?} dropped",
861                    self.snowflake_task_context.task_name
862                );
863            }
864        }
865        Ok(())
866    }
867
868    pub async fn execute_create_table(&self) -> Result<()> {
869        // create target table
870        let create_target_table_sql = build_create_table_sql(
871            &self.snowflake_task_context.target_table_name,
872            &self.snowflake_task_context.database,
873            &self.snowflake_task_context.schema_name,
874            &self.snowflake_task_context.schema,
875            false,
876        )?;
877        self.jdbc_client
878            .execute_sql_sync(vec![create_target_table_sql])
879            .await?;
880        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
881            let create_cdc_table_sql = build_create_table_sql(
882                cdc_table_name,
883                &self.snowflake_task_context.database,
884                &self.snowflake_task_context.schema_name,
885                &self.snowflake_task_context.schema,
886                true,
887            )?;
888            self.jdbc_client
889                .execute_sql_sync(vec![create_cdc_table_sql])
890                .await?;
891        }
892        Ok(())
893    }
894
895    pub async fn execute_create_pipe(&self) -> Result<()> {
896        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
897            let table_name =
898                if let Some(table_name) = self.snowflake_task_context.cdc_table_name.as_ref() {
899                    table_name
900                } else {
901                    &self.snowflake_task_context.target_table_name
902                };
903            let create_pipe_sql = build_create_pipe_sql(
904                table_name,
905                &self.snowflake_task_context.database,
906                &self.snowflake_task_context.schema_name,
907                self.snowflake_task_context.stage.as_ref().ok_or_else(|| {
908                    SinkError::Config(anyhow!("stage is required for the S3 writer"))
909                })?,
910                pipe_name,
911                &self.snowflake_task_context.target_table_name,
912            );
913            self.jdbc_client
914                .execute_sql_sync(vec![create_pipe_sql])
915                .await?;
916        }
917        Ok(())
918    }
919
920    pub async fn execute_flush_pipe(&self) -> Result<()> {
921        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
922            let flush_pipe_sql = build_flush_pipe_sql(
923                &self.snowflake_task_context.database,
924                &self.snowflake_task_context.schema_name,
925                pipe_name,
926            );
927            self.jdbc_client
928                .execute_sql_sync(vec![flush_pipe_sql])
929                .await?;
930        }
931        Ok(())
932    }
933}
934
935fn build_create_table_sql(
936    table_name: &str,
937    database: &str,
938    schema_name: &str,
939    schema: &Schema,
940    need_op_and_row_id: bool,
941) -> Result<String> {
942    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema_name, table_name);
943    let mut columns: Vec<String> = schema
944        .fields
945        .iter()
946        .map(|field| {
947            let data_type = convert_snowflake_data_type(&field.data_type)?;
948            Ok(format!(r#""{}" {}"#, field.name, data_type))
949        })
950        .collect::<Result<Vec<String>>>()?;
951    if need_op_and_row_id {
952        columns.push(format!(r#""{}" STRING"#, SNOWFLAKE_SINK_ROW_ID));
953        columns.push(format!(r#""{}" INT"#, SNOWFLAKE_SINK_OP));
954    }
955    let columns_str = columns.join(", ");
956    Ok(format!(
957        "CREATE TABLE IF NOT EXISTS {} ({}) ENABLE_SCHEMA_EVOLUTION  = true",
958        full_table_name, columns_str
959    ))
960}
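// For example (hypothetical names), a schema with columns `v1 INT` and `v2 VARCHAR` and
// `need_op_and_row_id = true` produces:
//   CREATE TABLE IF NOT EXISTS "DB"."SCHEMA"."T"
//     ("v1" INTEGER, "v2" STRING, "__row_id" STRING, "__op" INT) ENABLE_SCHEMA_EVOLUTION = true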
961
962fn convert_snowflake_data_type(data_type: &DataType) -> Result<String> {
963    let data_type = match data_type {
964        DataType::Int16 => "SMALLINT".to_owned(),
965        DataType::Int32 => "INTEGER".to_owned(),
966        DataType::Int64 => "BIGINT".to_owned(),
967        DataType::Float32 => "FLOAT4".to_owned(),
968        DataType::Float64 => "FLOAT8".to_owned(),
969        DataType::Boolean => "BOOLEAN".to_owned(),
970        DataType::Varchar => "STRING".to_owned(),
971        DataType::Date => "DATE".to_owned(),
972        DataType::Timestamp => "TIMESTAMP".to_owned(),
973        DataType::Timestamptz => "TIMESTAMP_TZ".to_owned(),
974        DataType::Jsonb => "STRING".to_owned(),
975        DataType::Decimal => "DECIMAL".to_owned(),
976        DataType::Bytea => "BINARY".to_owned(),
977        DataType::Time => "TIME".to_owned(),
978        _ => {
979            return Err(SinkError::Config(anyhow!(
980                "Auto table creation is not supported for data type: {}",
981                data_type
982            )));
983        }
984    };
985    Ok(data_type)
986}
987
988fn build_create_pipe_sql(
989    table_name: &str,
990    database: &str,
991    schema: &str,
992    stage: &str,
993    pipe_name: &str,
994    target_table_name: &str,
995) -> String {
996    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
997    let stage = format!(
998        r#""{}"."{}"."{}"/{}"#,
999        database, schema, stage, target_table_name
1000    );
1001    let table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
1002    format!(
1003        "CREATE OR REPLACE PIPE {} AUTO_INGEST = FALSE AS COPY INTO {} FROM @{} MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE FILE_FORMAT = (type = 'JSON');",
1004        pipe_name, table_name, stage
1005    )
1006}
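// For example (hypothetical names), `build_create_pipe_sql("T_CDC", "DB", "SCH", "MY_STAGE",
// "T_pipe", "T")` produces (wrapped here for readability):
//   CREATE OR REPLACE PIPE "DB"."SCH"."T_pipe" AUTO_INGEST = FALSE AS
//   COPY INTO "DB"."SCH"."T_CDC" FROM @"DB"."SCH"."MY_STAGE"/T
//   MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE FILE_FORMAT = (type = 'JSON');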
1007
1008fn build_flush_pipe_sql(database: &str, schema: &str, pipe_name: &str) -> String {
1009    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
1010    format!("ALTER PIPE {} REFRESH;", pipe_name)
1011}
1012
1013fn build_alter_add_column_sql(
1014    table_name: &str,
1015    database: &str,
1016    schema: &str,
1017    columns: &Vec<(String, String)>,
1018) -> String {
1019    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
1020    jdbc_jni_client::build_alter_add_column_sql(&full_table_name, columns, true)
1021}
1022
1023fn build_start_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
1024    let SnowflakeTaskContext {
1025        task_name,
1026        database,
1027        schema_name: schema,
1028        ..
1029    } = snowflake_task_context;
1030    let full_task_name = format!(
1031        r#""{}"."{}"."{}""#,
1032        database,
1033        schema,
1034        task_name.as_ref().unwrap()
1035    );
1036    format!("ALTER TASK {} RESUME", full_task_name)
1037}
1038
1039fn build_drop_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
1040    let SnowflakeTaskContext {
1041        task_name,
1042        database,
1043        schema_name: schema,
1044        ..
1045    } = snowflake_task_context;
1046    let full_task_name = format!(
1047        r#""{}"."{}"."{}""#,
1048        database,
1049        schema,
1050        task_name.as_ref().unwrap()
1051    );
1052    format!("DROP TASK IF EXISTS {}", full_task_name)
1053}
1054
1055fn build_create_merge_into_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
1056    let SnowflakeTaskContext {
1057        task_name,
1058        cdc_table_name,
1059        target_table_name,
1060        schedule_seconds,
1061        warehouse,
1062        pk_column_names,
1063        all_column_names,
1064        database,
1065        schema_name,
1066        ..
1067    } = snowflake_task_context;
1068    let full_task_name = format!(
1069        r#""{}"."{}"."{}""#,
1070        database,
1071        schema_name,
1072        task_name.as_ref().unwrap()
1073    );
1074    let full_cdc_table_name = format!(
1075        r#""{}"."{}"."{}""#,
1076        database,
1077        schema_name,
1078        cdc_table_name.as_ref().unwrap()
1079    );
1080    let full_target_table_name = format!(
1081        r#""{}"."{}"."{}""#,
1082        database, schema_name, target_table_name
1083    );
1084
1085    let pk_names_str = pk_column_names
1086        .as_ref()
1087        .unwrap()
1088        .iter()
1089        .map(|name| format!(r#""{}""#, name))
1090        .collect::<Vec<String>>()
1091        .join(", ");
1092    let pk_names_eq_str = pk_column_names
1093        .as_ref()
1094        .unwrap()
1095        .iter()
1096        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
1097        .collect::<Vec<String>>()
1098        .join(" AND ");
1099    let all_column_names_set_str = all_column_names
1100        .as_ref()
1101        .unwrap()
1102        .iter()
1103        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
1104        .collect::<Vec<String>>()
1105        .join(", ");
1106    let all_column_names_str = all_column_names
1107        .as_ref()
1108        .unwrap()
1109        .iter()
1110        .map(|name| format!(r#""{}""#, name))
1111        .collect::<Vec<String>>()
1112        .join(", ");
1113    let all_column_names_insert_str = all_column_names
1114        .as_ref()
1115        .unwrap()
1116        .iter()
1117        .map(|name| format!(r#"source."{}""#, name))
1118        .collect::<Vec<String>>()
1119        .join(", ");
1120
1121    format!(
1122        r#"CREATE OR REPLACE TASK {task_name}
1123WAREHOUSE = {warehouse}
1124SCHEDULE = '{schedule_seconds} SECONDS'
1125AS
1126BEGIN
1127    LET max_row_id STRING;
1128
1129    SELECT COALESCE(MAX("{snowflake_sink_row_id}"), '0') INTO :max_row_id
1130    FROM {cdc_table_name};
1131
1132    MERGE INTO {target_table_name} AS target
1133    USING (
1134        SELECT *
1135        FROM (
1136            SELECT *, ROW_NUMBER() OVER (PARTITION BY {pk_names_str} ORDER BY "{snowflake_sink_row_id}" DESC) AS dedupe_id
1137            FROM {cdc_table_name}
1138            WHERE "{snowflake_sink_row_id}" <= :max_row_id
1139        ) AS subquery
1140        WHERE dedupe_id = 1
1141    ) AS source
1142    ON {pk_names_eq_str}
1143    WHEN MATCHED AND source."{snowflake_sink_op}" IN (2, 4) THEN DELETE
1144    WHEN MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN UPDATE SET {all_column_names_set_str}
1145    WHEN NOT MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN INSERT ({all_column_names_str}) VALUES ({all_column_names_insert_str});
1146
1147    DELETE FROM {cdc_table_name}
1148    WHERE "{snowflake_sink_row_id}" <= :max_row_id;
1149END;"#,
1150        task_name = full_task_name,
1151        warehouse = warehouse.as_ref().unwrap(),
1152        schedule_seconds = schedule_seconds,
1153        cdc_table_name = full_cdc_table_name,
1154        target_table_name = full_target_table_name,
1155        pk_names_str = pk_names_str,
1156        pk_names_eq_str = pk_names_eq_str,
1157        all_column_names_set_str = all_column_names_set_str,
1158        all_column_names_str = all_column_names_str,
1159        all_column_names_insert_str = all_column_names_insert_str,
1160        snowflake_sink_row_id = SNOWFLAKE_SINK_ROW_ID,
1161        snowflake_sink_op = SNOWFLAKE_SINK_OP,
1162    )
1163}
1164
1165#[cfg(test)]
1166mod tests {
1167    use std::collections::BTreeMap;
1168
1169    use super::*;
1170    use crate::sink::jdbc_jni_client::normalize_sql;
1171
1172    fn base_properties() -> BTreeMap<String, String> {
1173        BTreeMap::from([
1174            ("type".to_owned(), "append-only".to_owned()),
1175            ("jdbc.url".to_owned(), "jdbc:snowflake://account".to_owned()),
1176            ("username".to_owned(), "RW_USER".to_owned()),
1177        ])
1178    }
1179
1180    #[test]
1181    fn test_build_jdbc_props_password() {
1182        let mut props = base_properties();
1183        props.insert("password".to_owned(), "secret".to_owned());
1184        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
1185        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
1186        assert_eq!(url, "jdbc:snowflake://account");
1187        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
1188        assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
1189        assert_eq!(map.get("password"), Some(&"secret".to_owned()));
1190        assert!(!map.contains_key("authenticator"));
1191    }
1192
1193    #[test]
1194    fn test_build_jdbc_props_key_pair_file() {
1195        let mut props = base_properties();
1196        props.insert(
1197            "auth.method".to_owned(),
1198            AUTH_METHOD_KEY_PAIR_FILE.to_owned(),
1199        );
1200        props.insert("private_key_file".to_owned(), "/tmp/rsa_key.p8".to_owned());
1201        props.insert("private_key_file_pwd".to_owned(), "dummy".to_owned());
1202        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
1203        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
1204        assert_eq!(url, "jdbc:snowflake://account");
1205        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
1206        assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
1207        assert_eq!(
1208            map.get("private_key_file"),
1209            Some(&"/tmp/rsa_key.p8".to_owned())
1210        );
1211        assert_eq!(map.get("private_key_file_pwd"), Some(&"dummy".to_owned()));
1212    }
1213
1214    #[test]
1215    fn test_build_jdbc_props_key_pair_object() {
1216        let mut props = base_properties();
1217        props.insert(
1218            "auth.method".to_owned(),
1219            AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
1220        );
1221        props.insert(
1222            "private_key_pem".to_owned(),
1223            "-----BEGIN PRIVATE KEY-----
1224...
1225-----END PRIVATE KEY-----"
1226                .to_owned(),
1227        );
1228        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
1229        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
1230        assert_eq!(url, "jdbc:snowflake://account");
1231        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
1232        assert_eq!(
1233            map.get("private_key_pem"),
1234            Some(
1235                &"-----BEGIN PRIVATE KEY-----
1236...
1237-----END PRIVATE KEY-----"
1238                    .to_owned()
1239            )
1240        );
1241        assert!(!map.contains_key("private_key_file"));
1242    }
1243
1244    #[test]
1245    fn test_snowflake_sink_commit_coordinator() {
1246        let snowflake_task_context = SnowflakeTaskContext {
1247            task_name: Some("test_task".to_owned()),
1248            cdc_table_name: Some("test_cdc_table".to_owned()),
1249            target_table_name: "test_target_table".to_owned(),
1250            schedule_seconds: 3600,
1251            warehouse: Some("test_warehouse".to_owned()),
1252            pk_column_names: Some(vec!["v1".to_owned()]),
1253            all_column_names: Some(vec!["v1".to_owned(), "v2".to_owned()]),
1254            database: "test_db".to_owned(),
1255            schema_name: "test_schema".to_owned(),
1256            schema: Schema { fields: vec![] },
1257            stage: None,
1258            pipe_name: None,
1259        };
1260        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
1261        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task"
1262WAREHOUSE = test_warehouse
1263SCHEDULE = '3600 SECONDS'
1264AS
1265BEGIN
1266    LET max_row_id STRING;
1267
1268    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
1269    FROM "test_db"."test_schema"."test_cdc_table";
1270
1271    MERGE INTO "test_db"."test_schema"."test_target_table" AS target
1272    USING (
1273        SELECT *
1274        FROM (
1275            SELECT *, ROW_NUMBER() OVER (PARTITION BY "v1" ORDER BY "__row_id" DESC) AS dedupe_id
1276            FROM "test_db"."test_schema"."test_cdc_table"
1277            WHERE "__row_id" <= :max_row_id
1278        ) AS subquery
1279        WHERE dedupe_id = 1
1280    ) AS source
1281    ON target."v1" = source."v1"
1282    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
1283    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."v1" = source."v1", target."v2" = source."v2"
1284    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("v1", "v2") VALUES (source."v1", source."v2");
1285
1286    DELETE FROM "test_db"."test_schema"."test_cdc_table"
1287    WHERE "__row_id" <= :max_row_id;
1288END;"#;
1289        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
1290    }
1291
1292    #[test]
1293    fn test_snowflake_sink_commit_coordinator_multi_pk() {
1294        let snowflake_task_context = SnowflakeTaskContext {
1295            task_name: Some("test_task_multi_pk".to_owned()),
1296            cdc_table_name: Some("cdc_multi_pk".to_owned()),
1297            target_table_name: "target_multi_pk".to_owned(),
1298            schedule_seconds: 300,
1299            warehouse: Some("multi_pk_warehouse".to_owned()),
1300            pk_column_names: Some(vec!["id1".to_owned(), "id2".to_owned()]),
1301            all_column_names: Some(vec!["id1".to_owned(), "id2".to_owned(), "val".to_owned()]),
1302            database: "test_db".to_owned(),
1303            schema_name: "test_schema".to_owned(),
1304            schema: Schema { fields: vec![] },
1305            stage: None,
1306            pipe_name: None,
1307        };
1308        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
1309        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task_multi_pk"
1310WAREHOUSE = multi_pk_warehouse
1311SCHEDULE = '300 SECONDS'
1312AS
1313BEGIN
1314    LET max_row_id STRING;
1315
1316    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
1317    FROM "test_db"."test_schema"."cdc_multi_pk";
1318
1319    MERGE INTO "test_db"."test_schema"."target_multi_pk" AS target
1320    USING (
1321        SELECT *
1322        FROM (
1323            SELECT *, ROW_NUMBER() OVER (PARTITION BY "id1", "id2" ORDER BY "__row_id" DESC) AS dedupe_id
1324            FROM "test_db"."test_schema"."cdc_multi_pk"
1325            WHERE "__row_id" <= :max_row_id
1326        ) AS subquery
1327        WHERE dedupe_id = 1
1328    ) AS source
1329    ON target."id1" = source."id1" AND target."id2" = source."id2"
1330    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
1331    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."id1" = source."id1", target."id2" = source."id2", target."val" = source."val"
1332    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("id1", "id2", "val") VALUES (source."id1", source."id2", source."val");
1333
1334    DELETE FROM "test_db"."test_schema"."cdc_multi_pk"
1335    WHERE "__row_id" <= :max_row_id;
1336END;"#;
1337        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
1338    }
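
    // The following tests are illustrative sketches added alongside the original suite: they
    // exercise the small SQL builders and the auth-method inference; all object names are
    // hypothetical.
    #[test]
    fn test_build_pipe_and_task_helper_sql() {
        assert_eq!(
            build_flush_pipe_sql("DB", "SCH", "T_pipe"),
            r#"ALTER PIPE "DB"."SCH"."T_pipe" REFRESH;"#
        );

        let ctx = SnowflakeTaskContext {
            task_name: Some("t_task".to_owned()),
            database: "DB".to_owned(),
            schema_name: "SCH".to_owned(),
            ..Default::default()
        };
        assert_eq!(
            build_start_task_sql(&ctx),
            r#"ALTER TASK "DB"."SCH"."t_task" RESUME"#
        );
        assert_eq!(
            build_drop_task_sql(&ctx),
            r#"DROP TASK IF EXISTS "DB"."SCH"."t_task""#
        );
    }

    #[test]
    fn test_auth_method_inference() {
        // With `auth.method` omitted, the single supplied credential selects the method.
        let mut props = base_properties();
        props.insert("password".to_owned(), "secret".to_owned());
        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
        assert_eq!(config.auth_method.as_deref(), Some(AUTH_METHOD_PASSWORD));

        // Adding a second credential without an explicit `auth.method` is ambiguous.
        props.insert(
            "private_key_pem".to_owned(),
            "-----BEGIN PRIVATE KEY-----".to_owned(),
        );
        assert!(SnowflakeV2Config::from_btreemap(&props).is_err());
    }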
1339}