risingwave_connector/sink/snowflake_redshift/snowflake.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use core::num::NonZeroU64;
use std::collections::BTreeMap;

use anyhow::anyhow;
use itertools::Itertools;
use phf::{Set, phf_set};
use risingwave_common::array::StreamChunk;
use risingwave_common::catalog::{ColumnDesc, ColumnId, Schema};
use risingwave_common::types::DataType;
use risingwave_pb::connector_service::{SinkMetadata, sink_metadata};
use risingwave_pb::stream_plan::PbSinkSchemaChange;
use serde::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use thiserror_ext::AsReport;
use tokio::sync::mpsc::UnboundedSender;
use tonic::async_trait;
use with_options::WithOptions;

use crate::connector_common::IcebergSinkCompactionUpdate;
use crate::enforce_secret::EnforceSecret;
use crate::sink::coordinate::CoordinatedLogSinker;
use crate::sink::decouple_checkpoint_log_sink::default_commit_checkpoint_interval;
use crate::sink::file_sink::s3::S3Common;
use crate::sink::jdbc_jni_client::{self, JdbcJniClient};
use crate::sink::remote::CoordinatedRemoteSinkWriter;
use crate::sink::snowflake_redshift::{AugmentedChunk, SnowflakeRedshiftSinkS3Writer};
use crate::sink::writer::SinkWriter;
use crate::sink::{
    Result, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT,
    SinglePhaseCommitCoordinator, Sink, SinkCommitCoordinator, SinkError, SinkParam,
    SinkWriterMetrics, SinkWriterParam,
};

pub const SNOWFLAKE_SINK_V2: &str = "snowflake_v2";
pub const SNOWFLAKE_SINK_ROW_ID: &str = "__row_id";
pub const SNOWFLAKE_SINK_OP: &str = "__op";

const AUTH_METHOD_PASSWORD: &str = "password";
const AUTH_METHOD_KEY_PAIR_FILE: &str = "key_pair_file";
const AUTH_METHOD_KEY_PAIR_OBJECT: &str = "key_pair_object";
const PROP_AUTH_METHOD: &str = "auth.method";
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct SnowflakeV2Config {
    #[serde(rename = "type")]
    pub r#type: String,

    #[serde(rename = "intermediate.table.name")]
    pub snowflake_cdc_table_name: Option<String>,

    #[serde(rename = "table.name")]
    pub snowflake_target_table_name: Option<String>,

    #[serde(rename = "database")]
    pub snowflake_database: Option<String>,

    #[serde(rename = "schema")]
    pub snowflake_schema: Option<String>,

    #[serde(default = "default_schedule")]
    #[serde(rename = "write.target.interval.seconds")]
    #[serde_as(as = "DisplayFromStr")]
    pub snowflake_schedule_seconds: u64,

    #[serde(rename = "warehouse")]
    pub snowflake_warehouse: Option<String>,

    #[serde(rename = "jdbc.url")]
    pub jdbc_url: Option<String>,

    #[serde(rename = "username")]
    pub username: Option<String>,

    #[serde(rename = "password")]
    pub password: Option<String>,

    // Authentication method control (password | key_pair_file | key_pair_object).
    #[serde(rename = "auth.method")]
    pub auth_method: Option<String>,

    // Key-pair authentication via connection properties, file-based: path to the private key file.
    #[serde(rename = "private_key_file")]
    pub private_key_file: Option<String>,

    #[serde(rename = "private_key_file_pwd")]
    pub private_key_file_pwd: Option<String>,

    // Key-pair authentication via connection properties, object-based: PEM content passed directly.
    #[serde(rename = "private_key_pem")]
    pub private_key_pem: Option<String>,

    /// Commit every n (> 0) checkpoints. Defaults to 10.
    #[serde(default = "default_commit_checkpoint_interval")]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub commit_checkpoint_interval: u64,

    /// Enable auto schema change for the upsert sink.
    /// If enabled, the sink will automatically alter the target table to add new columns.
    #[serde(default)]
    #[serde(rename = "auto.schema.change")]
    #[serde_as(as = "DisplayFromStr")]
    pub auto_schema_change: bool,

    #[serde(default)]
    #[serde(rename = "create_table_if_not_exists")]
    #[serde_as(as = "DisplayFromStr")]
    pub create_table_if_not_exists: bool,

    #[serde(default = "default_with_s3")]
    #[serde(rename = "with_s3")]
    #[serde_as(as = "DisplayFromStr")]
    pub with_s3: bool,

    #[serde(flatten)]
    pub s3_inner: Option<S3Common>,

    #[serde(rename = "stage")]
    pub stage: Option<String>,
}

fn default_schedule() -> u64 {
    3600 // Default to 1 hour
}

fn default_with_s3() -> bool {
    true
}
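
// For illustration only: a plausible set of sink properties accepted by this
// config. The keys match the `serde(rename)` attributes above; the concrete
// values are placeholders, not recommendations or defaults enforced here.
//
//   type                       = 'upsert',
//   jdbc.url                   = 'jdbc:snowflake://<account>.snowflakecomputing.com',
//   username                   = 'RW_USER',
//   auth.method                = 'key_pair_file',
//   private_key_file           = '/path/to/rsa_key.p8',
//   database                   = 'MY_DB',
//   schema                     = 'PUBLIC',
//   table.name                 = 'TARGET_TABLE',
//   intermediate.table.name    = 'TARGET_TABLE_CDC',
//   warehouse                  = 'MY_WH',
//   stage                      = 'MY_STAGE',
//   commit_checkpoint_interval = '10'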

impl SnowflakeV2Config {
    /// Build JDBC Properties for the Snowflake JDBC connection (no URL parameters).
    /// Returns (`jdbc_url`, `driver_properties`).
    /// - `driver_properties` are transformed/used by the Java runner and passed to `DriverManager::getConnection(url, props)`
    ///
    /// Note: This method assumes the config has been validated by `from_btreemap`.
    pub fn build_jdbc_connection_properties(&self) -> Result<(String, Vec<(String, String)>)> {
        let jdbc_url = self
            .jdbc_url
            .clone()
            .ok_or(SinkError::Config(anyhow!("jdbc.url is required")))?;
        let username = self
            .username
            .clone()
            .ok_or(SinkError::Config(anyhow!("username is required")))?;

        let mut connection_properties: Vec<(String, String)> = vec![("user".to_owned(), username)];

        // auth_method is guaranteed to be Some after validation in from_btreemap
        match self.auth_method.as_deref().unwrap() {
            AUTH_METHOD_PASSWORD => {
                // password is guaranteed to exist by from_btreemap validation
                connection_properties.push(("password".to_owned(), self.password.clone().unwrap()));
            }
            AUTH_METHOD_KEY_PAIR_FILE => {
                // private_key_file is guaranteed to exist by from_btreemap validation
                connection_properties.push((
                    "private_key_file".to_owned(),
                    self.private_key_file.clone().unwrap(),
                ));
                if let Some(pwd) = self.private_key_file_pwd.clone() {
                    connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
                }
            }
            AUTH_METHOD_KEY_PAIR_OBJECT => {
                connection_properties.push((
                    PROP_AUTH_METHOD.to_owned(),
                    AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
                ));
                // private_key_pem is guaranteed to exist by from_btreemap validation
                connection_properties.push((
                    "private_key_pem".to_owned(),
                    self.private_key_pem.clone().unwrap(),
                ));
                if let Some(pwd) = self.private_key_file_pwd.clone() {
                    connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
                }
            }
            _ => {
                // This should never happen since from_btreemap validates auth_method
                unreachable!(
                    "Invalid auth_method - should have been caught during config validation"
                )
            }
        }

        Ok((jdbc_url, connection_properties))
    }

    pub fn from_btreemap(properties: &BTreeMap<String, String>) -> Result<Self> {
        let mut config =
            serde_json::from_value::<SnowflakeV2Config>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }

        // Normalize and validate authentication method
        let has_password = config.password.is_some();
        let has_file = config.private_key_file.is_some();
        let has_pem = config.private_key_pem.as_deref().is_some();

        let normalized_auth_method = match config
            .auth_method
            .as_deref()
            .map(|s| s.trim().to_ascii_lowercase())
        {
            Some(method) if method == AUTH_METHOD_PASSWORD => {
                if !has_password {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=password requires `password`"
                    )));
                }
                if has_file || has_pem {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=password must not set `private_key_file`/`private_key_pem`"
                    )));
                }
                AUTH_METHOD_PASSWORD.to_owned()
            }
            Some(method) if method == AUTH_METHOD_KEY_PAIR_FILE => {
                if !has_file {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_file requires `private_key_file`"
                    )));
                }
                if has_password {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_file must not set `password`"
                    )));
                }
                if has_pem {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_file must not set `private_key_pem`"
                    )));
                }
                AUTH_METHOD_KEY_PAIR_FILE.to_owned()
            }
            Some(method) if method == AUTH_METHOD_KEY_PAIR_OBJECT => {
                if !has_pem {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_object requires `private_key_pem`"
                    )));
                }
                if has_password {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_object must not set `password`"
                    )));
                }
                AUTH_METHOD_KEY_PAIR_OBJECT.to_owned()
            }
            Some(other) => {
                return Err(SinkError::Config(anyhow!(
                    "invalid auth.method: {} (allowed: password | key_pair_file | key_pair_object)",
                    other
                )));
            }
            None => {
                // Infer auth method from supplied fields
                match (has_password, has_file, has_pem) {
                    (true, false, false) => AUTH_METHOD_PASSWORD.to_owned(),
                    (false, true, false) => AUTH_METHOD_KEY_PAIR_FILE.to_owned(),
                    (false, false, true) => AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
                    (true, true, _) | (true, _, true) | (false, true, true) => {
                        return Err(SinkError::Config(anyhow!(
                            "ambiguous auth: multiple auth options provided; remove one or set `auth.method`"
                        )));
                    }
                    _ => {
                        return Err(SinkError::Config(anyhow!(
                            "no authentication configured: set either `password`, or `private_key_file`, or `private_key_pem` (or provide `auth.method`)"
                        )));
                    }
                }
            }
        };
        config.auth_method = Some(normalized_auth_method);
        Ok(config)
    }

    pub fn build_snowflake_task_ctx_jdbc_client(
        &self,
        is_append_only: bool,
        schema: &Schema,
        pk_indices: &Vec<usize>,
    ) -> Result<Option<(SnowflakeTaskContext, JdbcJniClient)>> {
        if !self.auto_schema_change && is_append_only && !self.create_table_if_not_exists {
            // Append-only sinks without auto schema change or table creation do not need a client.
            return Ok(None);
        }
        let target_table_name = self
            .snowflake_target_table_name
            .clone()
            .ok_or(SinkError::Config(anyhow!("table.name is required")))?;
        let database = self
            .snowflake_database
            .clone()
            .ok_or(SinkError::Config(anyhow!("database is required")))?;
        let schema_name = self
            .snowflake_schema
            .clone()
            .ok_or(SinkError::Config(anyhow!("schema is required")))?;
        let mut snowflake_task_ctx = SnowflakeTaskContext {
            target_table_name: target_table_name.clone(),
            database,
            schema_name,
            schema: schema.clone(),
            ..Default::default()
        };

        let (jdbc_url, connection_properties) = self.build_jdbc_connection_properties()?;
        let client = JdbcJniClient::new_with_props(jdbc_url, connection_properties)?;

        if self.with_s3 {
            let stage = self
                .stage
                .clone()
                .ok_or(SinkError::Config(anyhow!("stage is required")))?;
            snowflake_task_ctx.stage = Some(stage);
            snowflake_task_ctx.pipe_name = Some(format!("{}_pipe", target_table_name));
        }
        if !is_append_only {
            let cdc_table_name = self
                .snowflake_cdc_table_name
                .clone()
                .ok_or(SinkError::Config(anyhow!(
                    "intermediate.table.name is required"
                )))?;
            snowflake_task_ctx.cdc_table_name = Some(cdc_table_name.clone());
            snowflake_task_ctx.schedule_seconds = self.snowflake_schedule_seconds;
            snowflake_task_ctx.warehouse = Some(
                self.snowflake_warehouse
                    .clone()
                    .ok_or(SinkError::Config(anyhow!("warehouse is required")))?,
            );
            let pk_column_names: Vec<_> = schema
                .fields
                .iter()
                .enumerate()
                .filter(|(index, _)| pk_indices.contains(index))
                .map(|(_, field)| field.name.clone())
                .collect();
            if pk_column_names.is_empty() {
                return Err(SinkError::Config(anyhow!(
                    "Primary key columns not found. Please set the `primary_key` column in the sink properties, or ensure that the sink contains the primary key columns from the upstream."
                )));
            }
            snowflake_task_ctx.pk_column_names = Some(pk_column_names);
            snowflake_task_ctx.all_column_names = Some(
                schema
                    .fields
                    .iter()
                    .map(|field| field.name.clone())
                    .collect(),
            );
            snowflake_task_ctx.task_name = Some(format!(
                "rw_snowflake_sink_from_{cdc_table_name}_to_{target_table_name}"
            ));
        }
        Ok(Some((snowflake_task_ctx, client)))
    }
}

impl EnforceSecret for SnowflakeV2Config {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "username",
        "password",
        "jdbc.url",
        // Key-pair authentication secrets
        "private_key_file_pwd",
        "private_key_pem",
    };
}

#[derive(Clone, Debug)]
pub struct SnowflakeV2Sink {
    config: SnowflakeV2Config,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
    param: SinkParam,
}

impl EnforceSecret for SnowflakeV2Sink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::sink::ConnectorResult<()> {
        for prop in prop_iter {
            SnowflakeV2Config::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl TryFrom<SinkParam> for SnowflakeV2Sink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = SnowflakeV2Config::from_btreemap(&param.properties)?;
        let is_append_only = param.sink_type.is_append_only();
        let pk_indices = param.downstream_pk_or_empty();
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
            param,
        })
    }
}

impl Sink for SnowflakeV2Sink {
    type LogSinker = CoordinatedLogSinker<SnowflakeSinkWriter>;

    const SINK_NAME: &'static str = SNOWFLAKE_SINK_V2;

    async fn validate(&self) -> Result<()> {
        risingwave_common::license::Feature::SnowflakeSink
            .check_available()
            .map_err(|e| anyhow::anyhow!(e))?;
        if let Some((snowflake_task_ctx, client)) =
            self.config.build_snowflake_task_ctx_jdbc_client(
                self.is_append_only,
                &self.schema,
                &self.pk_indices,
            )?
        {
            let client = SnowflakeJniClient::new(client, snowflake_task_ctx);
            client.execute_create_table().await?;
        }

        Ok(())
    }

    fn support_schema_change() -> bool {
        true
    }

    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
        SnowflakeV2Config::from_btreemap(config)?;
        Ok(())
    }

    async fn new_log_sinker(
        &self,
        writer_param: crate::sink::SinkWriterParam,
    ) -> Result<Self::LogSinker> {
        let writer = SnowflakeSinkWriter::new(
            self.config.clone(),
            self.is_append_only,
            writer_param.clone(),
            self.param.clone(),
        )
        .await?;

        let commit_checkpoint_interval =
            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
            );

        CoordinatedLogSinker::new(
            &writer_param,
            self.param.clone(),
            writer,
            commit_checkpoint_interval,
        )
        .await
    }

    fn is_coordinated_sink(&self) -> bool {
        true
    }

    async fn new_coordinator(
        &self,
        _iceberg_compact_stat_sender: Option<UnboundedSender<IcebergSinkCompactionUpdate>>,
    ) -> Result<SinkCommitCoordinator> {
        let coordinator = SnowflakeSinkCommitter::new(
            self.config.clone(),
            &self.schema,
            &self.pk_indices,
            self.is_append_only,
        )?;
        Ok(SinkCommitCoordinator::SinglePhase(Box::new(coordinator)))
    }
}

pub enum SnowflakeSinkWriter {
    S3(SnowflakeRedshiftSinkS3Writer),
    Jdbc(SnowflakeSinkJdbcWriter),
}

impl SnowflakeSinkWriter {
    pub async fn new(
        config: SnowflakeV2Config,
        is_append_only: bool,
        writer_param: SinkWriterParam,
        param: SinkParam,
    ) -> Result<Self> {
        let schema = param.schema();
        if config.with_s3 {
            let executor_id = writer_param.executor_id;
            let s3_writer = SnowflakeRedshiftSinkS3Writer::new(
                config.s3_inner.ok_or_else(|| {
                    SinkError::Config(anyhow!(
                        "S3 configuration is required for Snowflake S3 sink"
                    ))
                })?,
                schema,
                is_append_only,
                executor_id,
                config.snowflake_target_table_name,
            )?;
            Ok(Self::S3(s3_writer))
        } else {
            let jdbc_writer =
                SnowflakeSinkJdbcWriter::new(config, is_append_only, writer_param, param).await?;
            Ok(Self::Jdbc(jdbc_writer))
        }
    }
}

#[async_trait]
impl SinkWriter for SnowflakeSinkWriter {
    type CommitMetadata = Option<SinkMetadata>;

    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
        match self {
            Self::S3(writer) => writer.begin_epoch(epoch),
            Self::Jdbc(writer) => writer.begin_epoch(epoch).await,
        }
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        match self {
            Self::S3(writer) => writer.write_batch(chunk).await,
            Self::Jdbc(writer) => writer.write_batch(chunk).await,
        }
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Option<SinkMetadata>> {
        match self {
            Self::S3(writer) => {
                writer.barrier(is_checkpoint).await?;
            }
            Self::Jdbc(writer) => {
                writer.barrier(is_checkpoint).await?;
            }
        }
        Ok(Some(SinkMetadata {
            metadata: Some(sink_metadata::Metadata::Serialized(
                risingwave_pb::connector_service::sink_metadata::SerializedMetadata {
                    metadata: vec![],
                },
            )),
        }))
    }

    async fn abort(&mut self) -> Result<()> {
        if let Self::Jdbc(writer) = self {
            writer.abort().await
        } else {
            Ok(())
        }
    }
}

pub struct SnowflakeSinkJdbcWriter {
    augmented_row: AugmentedChunk,
    jdbc_sink_writer: CoordinatedRemoteSinkWriter,
}

impl SnowflakeSinkJdbcWriter {
    pub async fn new(
        config: SnowflakeV2Config,
        is_append_only: bool,
        writer_param: SinkWriterParam,
        mut param: SinkParam,
    ) -> Result<Self> {
        let metrics = SinkWriterMetrics::new(&writer_param);
        let properties = &param.properties;
        let column_descs = &mut param.columns;
        let full_table_name = if is_append_only {
            format!(
                r#""{}"."{}"."{}""#,
                config.snowflake_database.clone().unwrap_or_default(),
                config.snowflake_schema.clone().unwrap_or_default(),
                config
                    .snowflake_target_table_name
                    .clone()
                    .unwrap_or_default()
            )
        } else {
            let max_column_id = column_descs
                .iter()
                .map(|column| column.column_id.get_id())
                .max()
                .unwrap_or(0);
            (*column_descs).push(ColumnDesc::named(
                SNOWFLAKE_SINK_ROW_ID,
                ColumnId::new(max_column_id + 1),
                DataType::Varchar,
            ));
            (*column_descs).push(ColumnDesc::named(
                SNOWFLAKE_SINK_OP,
                ColumnId::new(max_column_id + 2),
                DataType::Int32,
            ));
            format!(
                r#""{}"."{}"."{}""#,
                config.snowflake_database.clone().unwrap_or_default(),
                config.snowflake_schema.clone().unwrap_or_default(),
                config.snowflake_cdc_table_name.clone().unwrap_or_default()
            )
        };
        let mut new_properties = BTreeMap::from([
            ("table.name".to_owned(), full_table_name),
            ("connector".to_owned(), "snowflake_v2".to_owned()),
            (
                "jdbc.url".to_owned(),
                config.jdbc_url.clone().unwrap_or_default(),
            ),
            ("type".to_owned(), "append-only".to_owned()),
            (
                "primary_key".to_owned(),
                properties.get("primary_key").cloned().unwrap_or_default(),
            ),
            (
                "schema.name".to_owned(),
                config.snowflake_schema.clone().unwrap_or_default(),
            ),
            (
                "database.name".to_owned(),
                config.snowflake_database.clone().unwrap_or_default(),
            ),
        ]);

        // Reuse build_jdbc_connection_properties to get driver properties (auth, user, etc.)
        let (_jdbc_url, connection_properties) = config.build_jdbc_connection_properties()?;
        for (key, value) in connection_properties {
            new_properties.insert(key, value);
        }

        param.properties = new_properties;

        let jdbc_sink_writer =
            CoordinatedRemoteSinkWriter::new(param.clone(), metrics.clone()).await?;
        Ok(Self {
            augmented_row: AugmentedChunk::new(0, is_append_only),
            jdbc_sink_writer,
        })
    }
}

impl SnowflakeSinkJdbcWriter {
    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
        self.augmented_row.reset_epoch(epoch);
        self.jdbc_sink_writer.begin_epoch(epoch).await?;
        Ok(())
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        let chunk = self.augmented_row.augmented_chunk(chunk)?;
        self.jdbc_sink_writer.write_batch(chunk).await?;
        Ok(())
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
        self.jdbc_sink_writer.barrier(is_checkpoint).await?;
        Ok(())
    }

    async fn abort(&mut self) -> Result<()> {
        // TODO: abort should clean up all the data written in this epoch.
        self.jdbc_sink_writer.abort().await?;
        Ok(())
    }
}

#[derive(Default)]
pub struct SnowflakeTaskContext {
    // required for task creation
    pub target_table_name: String,
    pub database: String,
    pub schema_name: String,
    pub schema: Schema,

    // only upsert
    pub task_name: Option<String>,
    pub cdc_table_name: Option<String>,
    pub schedule_seconds: u64,
    pub warehouse: Option<String>,
    pub pk_column_names: Option<Vec<String>>,
    pub all_column_names: Option<Vec<String>>,

    // only s3 writer
    pub stage: Option<String>,
    pub pipe_name: Option<String>,
}

pub struct SnowflakeSinkCommitter {
    client: Option<SnowflakeJniClient>,
}

impl SnowflakeSinkCommitter {
    pub fn new(
        config: SnowflakeV2Config,
        schema: &Schema,
        pk_indices: &Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let client = if let Some((snowflake_task_ctx, client)) =
            config.build_snowflake_task_ctx_jdbc_client(is_append_only, schema, pk_indices)?
        {
            Some(SnowflakeJniClient::new(client, snowflake_task_ctx))
        } else {
            None
        };
        Ok(Self { client })
    }
}

#[async_trait]
impl SinglePhaseCommitCoordinator for SnowflakeSinkCommitter {
    async fn init(&mut self) -> Result<()> {
        if let Some(client) = &self.client {
            // TODO: move this to validate
            client.execute_create_pipe().await?;
            client.execute_create_merge_into_task().await?;
        }
        Ok(())
    }

    async fn commit_data(&mut self, _epoch: u64, _metadata: Vec<SinkMetadata>) -> Result<()> {
        let client = self.client.as_mut().ok_or_else(|| {
            SinkError::Config(anyhow!("Snowflake sink committer is not initialized."))
        })?;
        client.execute_flush_pipe().await
    }

    async fn commit_schema_change(
        &mut self,
        _epoch: u64,
        schema_change: PbSinkSchemaChange,
    ) -> Result<()> {
        use risingwave_pb::stream_plan::sink_schema_change::PbOp as SinkSchemaChangeOp;
        let schema_change_op = schema_change
            .op
            .ok_or_else(|| SinkError::Coordinator(anyhow!("Invalid schema change operation")))?;
        let SinkSchemaChangeOp::AddColumns(add_columns) = schema_change_op else {
            return Err(SinkError::Coordinator(anyhow!(
                "Only AddColumns schema change is supported for Snowflake sink"
            )));
        };
        let client = self.client.as_mut().ok_or_else(|| {
            SinkError::Config(anyhow!("Snowflake sink committer is not initialized."))
        })?;
        client
            .execute_alter_add_columns(
                &add_columns
                    .fields
                    .into_iter()
                    .map(|f| (f.name, DataType::from(f.data_type.unwrap()).to_string()))
                    .collect_vec(),
            )
            .await
    }
}

impl Drop for SnowflakeSinkCommitter {
    fn drop(&mut self) {
        if let Some(client) = self.client.take() {
            tokio::spawn(async move {
                client.execute_drop_task().await.ok();
            });
        }
    }
}

pub struct SnowflakeJniClient {
    jdbc_client: JdbcJniClient,
    snowflake_task_context: SnowflakeTaskContext,
}

impl SnowflakeJniClient {
    pub fn new(jdbc_client: JdbcJniClient, snowflake_task_context: SnowflakeTaskContext) -> Self {
        Self {
            jdbc_client,
            snowflake_task_context,
        }
    }

    pub async fn execute_alter_add_columns(
        &mut self,
        columns: &Vec<(String, String)>,
    ) -> Result<()> {
        self.execute_drop_task().await?;
        if let Some(names) = self.snowflake_task_context.all_column_names.as_mut() {
            names.extend(columns.iter().map(|(name, _)| name.clone()));
        }
        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
            let alter_add_column_cdc_table_sql = build_alter_add_column_sql(
                cdc_table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                columns,
            );
            self.jdbc_client
                .execute_sql_sync(vec![alter_add_column_cdc_table_sql])
                .await?;
        }

        let alter_add_column_target_table_sql = build_alter_add_column_sql(
            &self.snowflake_task_context.target_table_name,
            &self.snowflake_task_context.database,
            &self.snowflake_task_context.schema_name,
            columns,
        );
        self.jdbc_client
            .execute_sql_sync(vec![alter_add_column_target_table_sql])
            .await?;

        self.execute_create_merge_into_task().await?;
        Ok(())
    }

    pub async fn execute_create_merge_into_task(&self) -> Result<()> {
        if self.snowflake_task_context.task_name.is_some() {
            let create_task_sql = build_create_merge_into_task_sql(&self.snowflake_task_context);
            let start_task_sql = build_start_task_sql(&self.snowflake_task_context);
            self.jdbc_client
                .execute_sql_sync(vec![create_task_sql])
                .await?;
            self.jdbc_client
                .execute_sql_sync(vec![start_task_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_drop_task(&self) -> Result<()> {
        if self.snowflake_task_context.task_name.is_some() {
            let sql = build_drop_task_sql(&self.snowflake_task_context);
            if let Err(e) = self.jdbc_client.execute_sql_sync(vec![sql]).await {
                tracing::error!(
                    "Failed to drop Snowflake sink task {:?}: {:?}",
                    self.snowflake_task_context.task_name,
                    e.as_report()
                );
            } else {
                tracing::info!(
                    "Snowflake sink task {:?} dropped",
                    self.snowflake_task_context.task_name
                );
            }
        }
        Ok(())
    }

    pub async fn execute_create_table(&self) -> Result<()> {
        // create target table
        let create_target_table_sql = build_create_table_sql(
            &self.snowflake_task_context.target_table_name,
            &self.snowflake_task_context.database,
            &self.snowflake_task_context.schema_name,
            &self.snowflake_task_context.schema,
            false,
        )?;
        self.jdbc_client
            .execute_sql_sync(vec![create_target_table_sql])
            .await?;
        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
            let create_cdc_table_sql = build_create_table_sql(
                cdc_table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                &self.snowflake_task_context.schema,
                true,
            )?;
            self.jdbc_client
                .execute_sql_sync(vec![create_cdc_table_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_create_pipe(&self) -> Result<()> {
        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
            let table_name =
                if let Some(table_name) = self.snowflake_task_context.cdc_table_name.as_ref() {
                    table_name
                } else {
                    &self.snowflake_task_context.target_table_name
                };
            let create_pipe_sql = build_create_pipe_sql(
                table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                self.snowflake_task_context.stage.as_ref().ok_or_else(|| {
                    SinkError::Config(anyhow!("snowflake.stage is required for S3 writer"))
                })?,
                pipe_name,
                &self.snowflake_task_context.target_table_name,
            );
            self.jdbc_client
                .execute_sql_sync(vec![create_pipe_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_flush_pipe(&self) -> Result<()> {
        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
            let flush_pipe_sql = build_flush_pipe_sql(
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                pipe_name,
            );
            self.jdbc_client
                .execute_sql_sync(vec![flush_pipe_sql])
                .await?;
        }
        Ok(())
    }
}

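// For orientation, the DDL produced below has roughly this shape (illustrative
// identifiers; assumes a single INT column "v1" and `need_op_and_row_id = true`):
//
//   CREATE TABLE IF NOT EXISTS "DB"."SCH"."T"
//     ("v1" INTEGER, "__row_id" STRING, "__op" INT) ENABLE_SCHEMA_EVOLUTION = true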
fn build_create_table_sql(
    table_name: &str,
    database: &str,
    schema_name: &str,
    schema: &Schema,
    need_op_and_row_id: bool,
) -> Result<String> {
    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema_name, table_name);
    let mut columns: Vec<String> = schema
        .fields
        .iter()
        .map(|field| {
            let data_type = convert_snowflake_data_type(&field.data_type)?;
            Ok(format!(r#""{}" {}"#, field.name, data_type))
        })
        .collect::<Result<Vec<String>>>()?;
    if need_op_and_row_id {
        columns.push(format!(r#""{}" STRING"#, SNOWFLAKE_SINK_ROW_ID));
        columns.push(format!(r#""{}" INT"#, SNOWFLAKE_SINK_OP));
    }
    let columns_str = columns.join(", ");
    Ok(format!(
        "CREATE TABLE IF NOT EXISTS {} ({}) ENABLE_SCHEMA_EVOLUTION = true",
        full_table_name, columns_str
    ))
}

fn convert_snowflake_data_type(data_type: &DataType) -> Result<String> {
    let data_type = match data_type {
        DataType::Int16 => "SMALLINT".to_owned(),
        DataType::Int32 => "INTEGER".to_owned(),
        DataType::Int64 => "BIGINT".to_owned(),
        DataType::Float32 => "FLOAT4".to_owned(),
        DataType::Float64 => "FLOAT8".to_owned(),
        DataType::Boolean => "BOOLEAN".to_owned(),
        DataType::Varchar => "STRING".to_owned(),
        DataType::Date => "DATE".to_owned(),
        DataType::Timestamp => "TIMESTAMP".to_owned(),
        DataType::Timestamptz => "TIMESTAMP_TZ".to_owned(),
        DataType::Jsonb => "STRING".to_owned(),
        DataType::Decimal => "DECIMAL".to_owned(),
        DataType::Bytea => "BINARY".to_owned(),
        DataType::Time => "TIME".to_owned(),
        _ => {
            return Err(SinkError::Config(anyhow!(
                "Automatic table creation does not support data type: {}",
                data_type
            )));
        }
    };
    Ok(data_type)
}

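// The statement generated below has roughly this shape (illustrative identifiers;
// the stage path is suffixed with the target table name):
//
//   CREATE OR REPLACE PIPE "DB"."SCH"."T_pipe" AUTO_INGEST = FALSE AS
//   COPY INTO "DB"."SCH"."T_CDC" FROM @"DB"."SCH"."STAGE"/T
//   MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE FILE_FORMAT = (type = 'JSON');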
fn build_create_pipe_sql(
    table_name: &str,
    database: &str,
    schema: &str,
    stage: &str,
    pipe_name: &str,
    target_table_name: &str,
) -> String {
    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    let stage = format!(
        r#""{}"."{}"."{}"/{}"#,
        database, schema, stage, target_table_name
    );
    let table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
    format!(
        "CREATE OR REPLACE PIPE {} AUTO_INGEST = FALSE AS COPY INTO {} FROM @{} MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE FILE_FORMAT = (type = 'JSON');",
        pipe_name, table_name, stage
    )
}

fn build_flush_pipe_sql(database: &str, schema: &str, pipe_name: &str) -> String {
    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    format!("ALTER PIPE {} REFRESH;", pipe_name)
}

fn build_alter_add_column_sql(
    table_name: &str,
    database: &str,
    schema: &str,
    columns: &Vec<(String, String)>,
) -> String {
    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
    jdbc_jni_client::build_alter_add_column_sql(&full_table_name, columns, true)
}

fn build_start_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        database,
        schema_name: schema,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema,
        task_name.as_ref().unwrap()
    );
    format!("ALTER TASK {} RESUME", full_task_name)
}

fn build_drop_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        database,
        schema_name: schema,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema,
        task_name.as_ref().unwrap()
    );
    format!("DROP TASK IF EXISTS {}", full_task_name)
}

fn build_create_merge_into_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        cdc_table_name,
        target_table_name,
        schedule_seconds,
        warehouse,
        pk_column_names,
        all_column_names,
        database,
        schema_name,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema_name,
        task_name.as_ref().unwrap()
    );
    let full_cdc_table_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema_name,
        cdc_table_name.as_ref().unwrap()
    );
    let full_target_table_name = format!(
        r#""{}"."{}"."{}""#,
        database, schema_name, target_table_name
    );

    let pk_names_str = pk_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#""{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");
    let pk_names_eq_str = pk_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
        .collect::<Vec<String>>()
        .join(" AND ");
    let all_column_names_set_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
        .collect::<Vec<String>>()
        .join(", ");
    let all_column_names_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#""{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");
    let all_column_names_insert_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"source."{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");

    format!(
        r#"CREATE OR REPLACE TASK {task_name}
WAREHOUSE = {warehouse}
SCHEDULE = '{schedule_seconds} SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("{snowflake_sink_row_id}"), '0') INTO :max_row_id
    FROM {cdc_table_name};

    MERGE INTO {target_table_name} AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY {pk_names_str} ORDER BY "{snowflake_sink_row_id}" DESC) AS dedupe_id
            FROM {cdc_table_name}
            WHERE "{snowflake_sink_row_id}" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON {pk_names_eq_str}
    WHEN MATCHED AND source."{snowflake_sink_op}" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN UPDATE SET {all_column_names_set_str}
    WHEN NOT MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN INSERT ({all_column_names_str}) VALUES ({all_column_names_insert_str});

    DELETE FROM {cdc_table_name}
    WHERE "{snowflake_sink_row_id}" <= :max_row_id;
END;"#,
        task_name = full_task_name,
        warehouse = warehouse.as_ref().unwrap(),
        schedule_seconds = schedule_seconds,
        cdc_table_name = full_cdc_table_name,
        target_table_name = full_target_table_name,
        pk_names_str = pk_names_str,
        pk_names_eq_str = pk_names_eq_str,
        all_column_names_set_str = all_column_names_set_str,
        all_column_names_str = all_column_names_str,
        all_column_names_insert_str = all_column_names_insert_str,
        snowflake_sink_row_id = SNOWFLAKE_SINK_ROW_ID,
        snowflake_sink_op = SNOWFLAKE_SINK_OP,
    )
}

#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::*;
    use crate::sink::jdbc_jni_client::normalize_sql;

    fn base_properties() -> BTreeMap<String, String> {
        BTreeMap::from([
            ("type".to_owned(), "append-only".to_owned()),
            ("jdbc.url".to_owned(), "jdbc:snowflake://account".to_owned()),
            ("username".to_owned(), "RW_USER".to_owned()),
        ])
    }

    #[test]
    fn test_build_jdbc_props_password() {
        let mut props = base_properties();
        props.insert("password".to_owned(), "secret".to_owned());
        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
        assert_eq!(url, "jdbc:snowflake://account");
        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
        assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
        assert_eq!(map.get("password"), Some(&"secret".to_owned()));
        assert!(!map.contains_key("authenticator"));
    }

    #[test]
    fn test_build_jdbc_props_key_pair_file() {
        let mut props = base_properties();
        props.insert(
            "auth.method".to_owned(),
            AUTH_METHOD_KEY_PAIR_FILE.to_owned(),
        );
        props.insert("private_key_file".to_owned(), "/tmp/rsa_key.p8".to_owned());
        props.insert("private_key_file_pwd".to_owned(), "dummy".to_owned());
        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
        assert_eq!(url, "jdbc:snowflake://account");
        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
        assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
        assert_eq!(
            map.get("private_key_file"),
            Some(&"/tmp/rsa_key.p8".to_owned())
        );
        assert_eq!(map.get("private_key_file_pwd"), Some(&"dummy".to_owned()));
    }

    #[test]
    fn test_build_jdbc_props_key_pair_object() {
        let mut props = base_properties();
        props.insert(
            "auth.method".to_owned(),
            AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
        );
        props.insert(
            "private_key_pem".to_owned(),
            "-----BEGIN PRIVATE KEY-----
...
-----END PRIVATE KEY-----"
                .to_owned(),
        );
        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
        assert_eq!(url, "jdbc:snowflake://account");
        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
        assert_eq!(
            map.get("private_key_pem"),
            Some(
                &"-----BEGIN PRIVATE KEY-----
...
-----END PRIVATE KEY-----"
                    .to_owned()
            )
        );
        assert!(!map.contains_key("private_key_file"));
    }
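
    // A small sanity check of the auth handling in `from_btreemap`: with only
    // `password` supplied the method is inferred, while supplying both
    // `password` and `private_key_file` without `auth.method` is rejected as
    // ambiguous.
    #[test]
    fn test_from_btreemap_auth_inference_and_ambiguity() {
        // Only `password` is provided, so the method should be inferred.
        let mut props = base_properties();
        props.insert("password".to_owned(), "secret".to_owned());
        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
        assert_eq!(config.auth_method.as_deref(), Some(AUTH_METHOD_PASSWORD));

        // Both `password` and `private_key_file` without `auth.method` must fail.
        let mut props = base_properties();
        props.insert("password".to_owned(), "secret".to_owned());
        props.insert("private_key_file".to_owned(), "/tmp/rsa_key.p8".to_owned());
        assert!(SnowflakeV2Config::from_btreemap(&props).is_err());
    }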

    #[test]
    fn test_snowflake_sink_commit_coordinator() {
        let snowflake_task_context = SnowflakeTaskContext {
            task_name: Some("test_task".to_owned()),
            cdc_table_name: Some("test_cdc_table".to_owned()),
            target_table_name: "test_target_table".to_owned(),
            schedule_seconds: 3600,
            warehouse: Some("test_warehouse".to_owned()),
            pk_column_names: Some(vec!["v1".to_owned()]),
            all_column_names: Some(vec!["v1".to_owned(), "v2".to_owned()]),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            schema: Schema { fields: vec![] },
            stage: None,
            pipe_name: None,
        };
        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task"
WAREHOUSE = test_warehouse
SCHEDULE = '3600 SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
    FROM "test_db"."test_schema"."test_cdc_table";

    MERGE INTO "test_db"."test_schema"."test_target_table" AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY "v1" ORDER BY "__row_id" DESC) AS dedupe_id
            FROM "test_db"."test_schema"."test_cdc_table"
            WHERE "__row_id" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON target."v1" = source."v1"
    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."v1" = source."v1", target."v2" = source."v2"
    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("v1", "v2") VALUES (source."v1", source."v2");

    DELETE FROM "test_db"."test_schema"."test_cdc_table"
    WHERE "__row_id" <= :max_row_id;
END;"#;
        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
    }
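
    // A minimal check of the task lifecycle SQL builders, assuming the same
    // quoting scheme as the merge-into task above.
    #[test]
    fn test_build_start_and_drop_task_sql() {
        let ctx = SnowflakeTaskContext {
            task_name: Some("test_task".to_owned()),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            ..Default::default()
        };
        assert_eq!(
            build_start_task_sql(&ctx),
            r#"ALTER TASK "test_db"."test_schema"."test_task" RESUME"#
        );
        assert_eq!(
            build_drop_task_sql(&ctx),
            r#"DROP TASK IF EXISTS "test_db"."test_schema"."test_task""#
        );
    }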

    #[test]
    fn test_snowflake_sink_commit_coordinator_multi_pk() {
        let snowflake_task_context = SnowflakeTaskContext {
            task_name: Some("test_task_multi_pk".to_owned()),
            cdc_table_name: Some("cdc_multi_pk".to_owned()),
            target_table_name: "target_multi_pk".to_owned(),
            schedule_seconds: 300,
            warehouse: Some("multi_pk_warehouse".to_owned()),
            pk_column_names: Some(vec!["id1".to_owned(), "id2".to_owned()]),
            all_column_names: Some(vec!["id1".to_owned(), "id2".to_owned(), "val".to_owned()]),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            schema: Schema { fields: vec![] },
            stage: None,
            pipe_name: None,
        };
        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task_multi_pk"
WAREHOUSE = multi_pk_warehouse
SCHEDULE = '300 SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
    FROM "test_db"."test_schema"."cdc_multi_pk";

    MERGE INTO "test_db"."test_schema"."target_multi_pk" AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY "id1", "id2" ORDER BY "__row_id" DESC) AS dedupe_id
            FROM "test_db"."test_schema"."cdc_multi_pk"
            WHERE "__row_id" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON target."id1" = source."id1" AND target."id2" = source."id2"
    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."id1" = source."id1", target."id2" = source."id2", target."val" = source."val"
    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("id1", "id2", "val") VALUES (source."id1", source."id2", source."val");

    DELETE FROM "test_db"."test_schema"."cdc_multi_pk"
    WHERE "__row_id" <= :max_row_id;
END;"#;
        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
    }
}