risingwave_connector/sink/snowflake_redshift/snowflake.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use core::num::NonZeroU64;
16use std::collections::BTreeMap;
17
18use anyhow::anyhow;
19use itertools::Itertools;
20use phf::{Set, phf_set};
21use risingwave_common::array::StreamChunk;
22use risingwave_common::catalog::{ColumnDesc, ColumnId, Schema};
23use risingwave_common::types::DataType;
24use risingwave_pb::connector_service::{SinkMetadata, sink_metadata};
25use risingwave_pb::stream_plan::PbSinkSchemaChange;
26use serde::Deserialize;
27use serde_with::{DisplayFromStr, serde_as};
28use thiserror_ext::AsReport;
29use tokio::sync::mpsc::UnboundedSender;
30use tonic::async_trait;
31use with_options::WithOptions;
32
33use crate::connector_common::IcebergSinkCompactionUpdate;
34use crate::enforce_secret::EnforceSecret;
35use crate::sink::coordinate::CoordinatedLogSinker;
36use crate::sink::decouple_checkpoint_log_sink::default_commit_checkpoint_interval;
37use crate::sink::file_sink::s3::S3Common;
38use crate::sink::jdbc_jni_client::{self, JdbcJniClient};
39use crate::sink::remote::CoordinatedRemoteSinkWriter;
40use crate::sink::snowflake_redshift::{AugmentedChunk, SnowflakeRedshiftSinkS3Writer};
41use crate::sink::writer::SinkWriter;
42use crate::sink::{
43    Result, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT,
44    SinglePhaseCommitCoordinator, Sink, SinkCommitCoordinator, SinkError, SinkParam,
45    SinkWriterMetrics, SinkWriterParam,
46};
47
48pub const SNOWFLAKE_SINK_V2: &str = "snowflake_v2";
49pub const SNOWFLAKE_SINK_ROW_ID: &str = "__row_id";
50pub const SNOWFLAKE_SINK_OP: &str = "__op";
51
52const AUTH_METHOD_PASSWORD: &str = "password";
53const AUTH_METHOD_KEY_PAIR_FILE: &str = "key_pair_file";
54const AUTH_METHOD_KEY_PAIR_OBJECT: &str = "key_pair_object";
55const PROP_AUTH_METHOD: &str = "auth.method";
56
57#[serde_as]
58#[derive(Debug, Clone, Deserialize, WithOptions)]
59pub struct SnowflakeV2Config {
60    #[serde(rename = "type")]
61    pub r#type: String,
62
63    #[serde(rename = "intermediate.table.name")]
64    pub snowflake_cdc_table_name: Option<String>,
65
66    #[serde(rename = "table.name")]
67    pub snowflake_target_table_name: Option<String>,
68
69    #[serde(rename = "database")]
70    pub snowflake_database: Option<String>,
71
72    #[serde(rename = "schema")]
73    pub snowflake_schema: Option<String>,
74
75    #[serde(default = "default_schedule")]
76    #[serde(rename = "write.target.interval.seconds")]
77    #[serde_as(as = "DisplayFromStr")]
78    pub snowflake_schedule_seconds: u64,
79
80    #[serde(rename = "warehouse")]
81    pub snowflake_warehouse: Option<String>,
82
83    #[serde(rename = "jdbc.url")]
84    pub jdbc_url: Option<String>,
85
86    #[serde(rename = "username")]
87    pub username: Option<String>,
88
89    #[serde(rename = "password")]
90    pub password: Option<String>,
91
92    // Authentication method control (password | key_pair_file | key_pair_object)
93    #[serde(rename = "auth.method")]
94    pub auth_method: Option<String>,
95
96    // Key-pair authentication via connection properties (file-based: a path to the private key file)
97    #[serde(rename = "private_key_file")]
98    pub private_key_file: Option<String>,
99
100    #[serde(rename = "private_key_file_pwd")]
101    pub private_key_file_pwd: Option<String>,
102
103    // Key-pair authentication via connection properties (object-based: PEM-encoded private key content)
104    #[serde(rename = "private_key_pem")]
105    pub private_key_pem: Option<String>,
106
107    /// Commit every n (> 0) checkpoints; the default is 10.
108    #[serde(default = "default_commit_checkpoint_interval")]
109    #[serde_as(as = "DisplayFromStr")]
110    #[with_option(allow_alter_on_fly)]
111    pub commit_checkpoint_interval: u64,
112
113    /// Enable auto schema change for upsert sink.
114    /// If enabled, the sink will automatically alter the target table to add new columns.
115    #[serde(default)]
116    #[serde(rename = "auto.schema.change")]
117    #[serde_as(as = "DisplayFromStr")]
118    pub auto_schema_change: bool,
119
120    #[serde(default)]
121    #[serde(rename = "create_table_if_not_exists")]
122    #[serde_as(as = "DisplayFromStr")]
123    pub create_table_if_not_exists: bool,
124
125    #[serde(default = "default_with_s3")]
126    #[serde(rename = "with_s3")]
127    #[serde_as(as = "DisplayFromStr")]
128    pub with_s3: bool,
129
130    #[serde(flatten)]
131    pub s3_inner: Option<S3Common>,
132
133    #[serde(rename = "stage")]
134    pub stage: Option<String>,
135}
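
// Illustrative only: a hypothetical `CREATE SINK` WITH clause wired to the options above.
// Property names come from the serde renames on `SnowflakeV2Config`; all values below are
// placeholders, and the exact DDL shape may differ in your deployment:
//
//   CREATE SINK snowflake_sink FROM mv WITH (
//       connector = 'snowflake_v2',
//       type = 'upsert',
//       primary_key = 'id',
//       database = 'RW_DB',
//       schema = 'PUBLIC',
//       table.name = 'target_table',
//       intermediate.table.name = 'target_table_cdc',
//       warehouse = 'RW_WH',
//       jdbc.url = 'jdbc:snowflake://<account>.snowflakecomputing.com',
//       username = 'RW_USER',
//       auth.method = 'key_pair_file',
//       private_key_file = '/path/to/rsa_key.p8',
//       with_s3 = 'false'
//   );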
136
137fn default_schedule() -> u64 {
138    3600 // Default to 1 hour
139}
140
141fn default_with_s3() -> bool {
142    true
143}
144
145impl SnowflakeV2Config {
146    /// Build the JDBC connection properties for Snowflake (credentials go into driver properties rather than URL parameters).
147    /// Returns (`jdbc_url`, `driver_properties`).
148    /// - `driver_properties` are transformed/used by the Java runner and passed to `DriverManager::getConnection(url, props)`
149    ///
150    /// Note: This method assumes the config has been validated by `from_btreemap`.
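    ///
    /// Illustrative sketch (password auth; values are placeholders): the returned pair is
    /// roughly `("jdbc:snowflake://...", vec![("user", "RW_USER"), ("password", "***")])`.
    /// For key-pair auth the password entry is replaced by `private_key_file` or
    /// `private_key_pem` (plus `auth.method` for the object-based variant).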
151    pub fn build_jdbc_connection_properties(&self) -> Result<(String, Vec<(String, String)>)> {
152        let jdbc_url = self
153            .jdbc_url
154            .clone()
155            .ok_or(SinkError::Config(anyhow!("jdbc.url is required")))?;
156        let username = self
157            .username
158            .clone()
159            .ok_or(SinkError::Config(anyhow!("username is required")))?;
160
161        let mut connection_properties: Vec<(String, String)> = vec![("user".to_owned(), username)];
162
163        // auth_method is guaranteed to be Some after validation in from_btreemap
164        match self.auth_method.as_deref().unwrap() {
165            AUTH_METHOD_PASSWORD => {
166                // password is guaranteed to exist by from_btreemap validation
167                connection_properties.push(("password".to_owned(), self.password.clone().unwrap()));
168            }
169            AUTH_METHOD_KEY_PAIR_FILE => {
170                // private_key_file is guaranteed to exist by from_btreemap validation
171                connection_properties.push((
172                    "private_key_file".to_owned(),
173                    self.private_key_file.clone().unwrap(),
174                ));
175                if let Some(pwd) = self.private_key_file_pwd.clone() {
176                    connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
177                }
178            }
179            AUTH_METHOD_KEY_PAIR_OBJECT => {
180                connection_properties.push((
181                    PROP_AUTH_METHOD.to_owned(),
182                    AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
183                ));
184                // private_key_pem is guaranteed to exist by from_btreemap validation
185                connection_properties.push((
186                    "private_key_pem".to_owned(),
187                    self.private_key_pem.clone().unwrap(),
188                ));
189                if let Some(pwd) = self.private_key_file_pwd.clone() {
190                    connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
191                }
192            }
193            _ => {
194                // This should never happen since from_btreemap validates auth_method
195                unreachable!(
196                    "Invalid auth_method - should have been caught during config validation"
197                )
198            }
199        }
200
201        Ok((jdbc_url, connection_properties))
202    }
203
204    pub fn from_btreemap(properties: &BTreeMap<String, String>) -> Result<Self> {
205        let mut config =
206            serde_json::from_value::<SnowflakeV2Config>(serde_json::to_value(properties).unwrap())
207                .map_err(|e| SinkError::Config(anyhow!(e)))?;
208        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
209            return Err(SinkError::Config(anyhow!(
210                "`{}` must be {} or {}",
211                SINK_TYPE_OPTION,
212                SINK_TYPE_APPEND_ONLY,
213                SINK_TYPE_UPSERT
214            )));
215        }
216
217        // Normalize and validate authentication method
218        let has_password = config.password.is_some();
219        let has_file = config.private_key_file.is_some();
220        let has_pem = config.private_key_pem.as_deref().is_some();
221
222        let normalized_auth_method = match config
223            .auth_method
224            .as_deref()
225            .map(|s| s.trim().to_ascii_lowercase())
226        {
227            Some(method) if method == AUTH_METHOD_PASSWORD => {
228                if !has_password {
229                    return Err(SinkError::Config(anyhow!(
230                        "auth.method=password requires `password`"
231                    )));
232                }
233                if has_file || has_pem {
234                    return Err(SinkError::Config(anyhow!(
235                        "auth.method=password must not set `private_key_file`/`private_key_pem`"
236                    )));
237                }
238                AUTH_METHOD_PASSWORD.to_owned()
239            }
240            Some(method) if method == AUTH_METHOD_KEY_PAIR_FILE => {
241                if !has_file {
242                    return Err(SinkError::Config(anyhow!(
243                        "auth.method=key_pair_file requires `private_key_file`"
244                    )));
245                }
246                if has_password {
247                    return Err(SinkError::Config(anyhow!(
248                        "auth.method=key_pair_file must not set `password`"
249                    )));
250                }
251                if has_pem {
252                    return Err(SinkError::Config(anyhow!(
253                        "auth.method=key_pair_file must not set `private_key_pem`"
254                    )));
255                }
256                AUTH_METHOD_KEY_PAIR_FILE.to_owned()
257            }
258            Some(method) if method == AUTH_METHOD_KEY_PAIR_OBJECT => {
259                if !has_pem {
260                    return Err(SinkError::Config(anyhow!(
261                        "auth.method=key_pair_object requires `private_key_pem`"
262                    )));
263                }
264                if has_password {
265                    return Err(SinkError::Config(anyhow!(
266                        "auth.method=key_pair_object must not set `password`"
267                    )));
268                }
269                AUTH_METHOD_KEY_PAIR_OBJECT.to_owned()
270            }
271            Some(other) => {
272                return Err(SinkError::Config(anyhow!(
273                    "invalid auth.method: {} (allowed: password | key_pair_file | key_pair_object)",
274                    other
275                )));
276            }
277            None => {
278                // Infer auth method from supplied fields
279                match (has_password, has_file, has_pem) {
280                    (true, false, false) => AUTH_METHOD_PASSWORD.to_owned(),
281                    (false, true, false) => AUTH_METHOD_KEY_PAIR_FILE.to_owned(),
282                    (false, false, true) => AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
283                    (true, true, _) | (true, _, true) | (false, true, true) => {
284                        return Err(SinkError::Config(anyhow!(
285                            "ambiguous auth: multiple auth options provided; remove one or set `auth.method`"
286                        )));
287                    }
288                    _ => {
289                        return Err(SinkError::Config(anyhow!(
290                            "no authentication configured: set either `password`, or `private_key_file`, or `private_key_pem` (or provide `auth.method`)"
291                        )));
292                    }
293                }
294            }
295        };
296        config.auth_method = Some(normalized_auth_method);
297        Ok(config)
298    }
299
300    pub fn build_snowflake_task_ctx_jdbc_client(
301        &self,
302        is_append_only: bool,
303        schema: &Schema,
304        pk_indices: &Vec<usize>,
305    ) -> Result<Option<(SnowflakeTaskContext, JdbcJniClient)>> {
306        if !self.auto_schema_change && is_append_only && !self.create_table_if_not_exists {
307            // Append-only sinks without auto schema change or table auto-creation do not need a JDBC client.
308            return Ok(None);
309        }
310        let target_table_name = self
311            .snowflake_target_table_name
312            .clone()
313            .ok_or(SinkError::Config(anyhow!("table.name is required")))?;
314        let database = self
315            .snowflake_database
316            .clone()
317            .ok_or(SinkError::Config(anyhow!("database is required")))?;
318        let schema_name = self
319            .snowflake_schema
320            .clone()
321            .ok_or(SinkError::Config(anyhow!("schema is required")))?;
322        let mut snowflake_task_ctx = SnowflakeTaskContext {
323            target_table_name: target_table_name.clone(),
324            database,
325            schema_name,
326            schema: schema.clone(),
327            ..Default::default()
328        };
329
330        let (jdbc_url, connection_properties) = self.build_jdbc_connection_properties()?;
331        let client = JdbcJniClient::new_with_props(jdbc_url, connection_properties)?;
332
333        if self.with_s3 {
334            let stage = self
335                .stage
336                .clone()
337                .ok_or(SinkError::Config(anyhow!("stage is required")))?;
338            snowflake_task_ctx.stage = Some(stage);
339            snowflake_task_ctx.pipe_name = Some(format!("{}_pipe", target_table_name));
340        }
341        if !is_append_only {
342            let cdc_table_name = self
343                .snowflake_cdc_table_name
344                .clone()
345                .ok_or(SinkError::Config(anyhow!(
346                    "intermediate.table.name is required"
347                )))?;
348            snowflake_task_ctx.cdc_table_name = Some(cdc_table_name.clone());
349            snowflake_task_ctx.schedule_seconds = self.snowflake_schedule_seconds;
350            snowflake_task_ctx.warehouse = Some(
351                self.snowflake_warehouse
352                    .clone()
353                    .ok_or(SinkError::Config(anyhow!("warehouse is required")))?,
354            );
355            let pk_column_names: Vec<_> = schema
356                .fields
357                .iter()
358                .enumerate()
359                .filter(|(index, _)| pk_indices.contains(index))
360                .map(|(_, field)| field.name.clone())
361                .collect();
362            if pk_column_names.is_empty() {
363                return Err(SinkError::Config(anyhow!(
364                    "Primary key columns not found. Please set the `primary_key` column in the sink properties, or ensure that the sink contains the primary key columns from the upstream."
365                )));
366            }
367            snowflake_task_ctx.pk_column_names = Some(pk_column_names);
368            snowflake_task_ctx.all_column_names = Some(
369                schema
370                    .fields
371                    .iter()
372                    .map(|field| field.name.clone())
373                    .collect(),
374            );
375            snowflake_task_ctx.task_name = Some(format!(
376                "rw_snowflake_sink_from_{cdc_table_name}_to_{target_table_name}"
377            ));
378        }
379        Ok(Some((snowflake_task_ctx, client)))
380    }
381}
382
383impl EnforceSecret for SnowflakeV2Config {
384    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
385        "username",
386        "password",
387        "jdbc.url",
388        // Key-pair authentication secrets
389        "private_key_file_pwd",
390        "private_key_pem",
391    };
392}
393
394#[derive(Clone, Debug)]
395pub struct SnowflakeV2Sink {
396    config: SnowflakeV2Config,
397    schema: Schema,
398    pk_indices: Vec<usize>,
399    is_append_only: bool,
400    param: SinkParam,
401}
402
403impl EnforceSecret for SnowflakeV2Sink {
404    fn enforce_secret<'a>(
405        prop_iter: impl Iterator<Item = &'a str>,
406    ) -> crate::sink::ConnectorResult<()> {
407        for prop in prop_iter {
408            SnowflakeV2Config::enforce_one(prop)?;
409        }
410        Ok(())
411    }
412}
413
414impl TryFrom<SinkParam> for SnowflakeV2Sink {
415    type Error = SinkError;
416
417    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
418        let schema = param.schema();
419        let config = SnowflakeV2Config::from_btreemap(&param.properties)?;
420        let is_append_only = param.sink_type.is_append_only();
421        let pk_indices = param.downstream_pk_or_empty();
422        Ok(Self {
423            config,
424            schema,
425            pk_indices,
426            is_append_only,
427            param,
428        })
429    }
430}
431
432impl Sink for SnowflakeV2Sink {
433    type LogSinker = CoordinatedLogSinker<SnowflakeSinkWriter>;
434
435    const SINK_NAME: &'static str = SNOWFLAKE_SINK_V2;
436
437    async fn validate(&self) -> Result<()> {
438        risingwave_common::license::Feature::SnowflakeSink
439            .check_available()
440            .map_err(|e| anyhow::anyhow!(e))?;
441        if let Some((snowflake_task_ctx, client)) =
442            self.config.build_snowflake_task_ctx_jdbc_client(
443                self.is_append_only,
444                &self.schema,
445                &self.pk_indices,
446            )?
447        {
448            let client = SnowflakeJniClient::new(client, snowflake_task_ctx);
449            client.execute_create_table().await?;
450        }
451
452        Ok(())
453    }
454
455    fn support_schema_change() -> bool {
456        true
457    }
458
459    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
460        SnowflakeV2Config::from_btreemap(config)?;
461        Ok(())
462    }
463
464    async fn new_log_sinker(
465        &self,
466        writer_param: crate::sink::SinkWriterParam,
467    ) -> Result<Self::LogSinker> {
468        let writer = SnowflakeSinkWriter::new(
469            self.config.clone(),
470            self.is_append_only,
471            writer_param.clone(),
472            self.param.clone(),
473        )
474        .await?;
475
476        let commit_checkpoint_interval =
477            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
478                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
479            );
480
481        CoordinatedLogSinker::new(
482            &writer_param,
483            self.param.clone(),
484            writer,
485            commit_checkpoint_interval,
486        )
487        .await
488    }
489
490    fn is_coordinated_sink(&self) -> bool {
491        true
492    }
493
494    async fn new_coordinator(
495        &self,
496        _iceberg_compact_stat_sender: Option<UnboundedSender<IcebergSinkCompactionUpdate>>,
497    ) -> Result<SinkCommitCoordinator> {
498        let coordinator = SnowflakeSinkCommitter::new(
499            self.config.clone(),
500            &self.schema,
501            &self.pk_indices,
502            self.is_append_only,
503        )?;
504        Ok(SinkCommitCoordinator::SinglePhase(Box::new(coordinator)))
505    }
506}
507
508pub enum SnowflakeSinkWriter {
509    S3(SnowflakeRedshiftSinkS3Writer),
510    Jdbc(SnowflakeSinkJdbcWriter),
511}
512
513impl SnowflakeSinkWriter {
514    pub async fn new(
515        config: SnowflakeV2Config,
516        is_append_only: bool,
517        writer_param: SinkWriterParam,
518        param: SinkParam,
519    ) -> Result<Self> {
520        let schema = param.schema();
521        if config.with_s3 {
522            let executor_id = writer_param.executor_id;
523            let s3_writer = SnowflakeRedshiftSinkS3Writer::new(
524                config.s3_inner.ok_or_else(|| {
525                    SinkError::Config(anyhow!(
526                        "S3 configuration is required for Snowflake S3 sink"
527                    ))
528                })?,
529                schema,
530                is_append_only,
531                executor_id,
532                config.snowflake_target_table_name,
533            )?;
534            Ok(Self::S3(s3_writer))
535        } else {
536            let jdbc_writer =
537                SnowflakeSinkJdbcWriter::new(config, is_append_only, writer_param, param).await?;
538            Ok(Self::Jdbc(jdbc_writer))
539        }
540    }
541}
542
543#[async_trait]
544impl SinkWriter for SnowflakeSinkWriter {
545    type CommitMetadata = Option<SinkMetadata>;
546
547    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
548        match self {
549            Self::S3(writer) => writer.begin_epoch(epoch),
550            Self::Jdbc(writer) => writer.begin_epoch(epoch).await,
551        }
552    }
553
554    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
555        match self {
556            Self::S3(writer) => writer.write_batch(chunk).await,
557            Self::Jdbc(writer) => writer.write_batch(chunk).await,
558        }
559    }
560
561    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Option<SinkMetadata>> {
562        match self {
563            Self::S3(writer) => {
564                writer.barrier(is_checkpoint).await?;
565            }
566            Self::Jdbc(writer) => {
567                writer.barrier(is_checkpoint).await?;
568            }
569        }
570        Ok(Some(SinkMetadata {
571            metadata: Some(sink_metadata::Metadata::Serialized(
572                risingwave_pb::connector_service::sink_metadata::SerializedMetadata {
573                    metadata: vec![],
574                },
575            )),
576        }))
577    }
578
579    async fn abort(&mut self) -> Result<()> {
580        if let Self::Jdbc(writer) = self {
581            writer.abort().await
582        } else {
583            Ok(())
584        }
585    }
586}
587
588pub struct SnowflakeSinkJdbcWriter {
589    augmented_row: AugmentedChunk,
590    jdbc_sink_writer: CoordinatedRemoteSinkWriter,
591}
592
593impl SnowflakeSinkJdbcWriter {
594    pub async fn new(
595        config: SnowflakeV2Config,
596        is_append_only: bool,
597        writer_param: SinkWriterParam,
598        mut param: SinkParam,
599    ) -> Result<Self> {
600        let metrics = SinkWriterMetrics::new(&writer_param);
601        let properties = &param.properties;
602        let column_descs = &mut param.columns;
603        let full_table_name = if is_append_only {
604            format!(
605                r#""{}"."{}"."{}""#,
606                config.snowflake_database.clone().unwrap_or_default(),
607                config.snowflake_schema.clone().unwrap_or_default(),
608                config
609                    .snowflake_target_table_name
610                    .clone()
611                    .unwrap_or_default()
612            )
613        } else {
614            let max_column_id = column_descs
615                .iter()
616                .map(|column| column.column_id.get_id())
617                .max()
618                .unwrap_or(0);
619            (*column_descs).push(ColumnDesc::named(
620                SNOWFLAKE_SINK_ROW_ID,
621                ColumnId::new(max_column_id + 1),
622                DataType::Varchar,
623            ));
624            (*column_descs).push(ColumnDesc::named(
625                SNOWFLAKE_SINK_OP,
626                ColumnId::new(max_column_id + 2),
627                DataType::Int32,
628            ));
629            format!(
630                r#""{}"."{}"."{}""#,
631                config.snowflake_database.clone().unwrap_or_default(),
632                config.snowflake_schema.clone().unwrap_or_default(),
633                config.snowflake_cdc_table_name.clone().unwrap_or_default()
634            )
635        };
636        let mut new_properties = BTreeMap::from([
637            ("table.name".to_owned(), full_table_name),
638            ("connector".to_owned(), "snowflake_v2".to_owned()),
639            (
640                "jdbc.url".to_owned(),
641                config.jdbc_url.clone().unwrap_or_default(),
642            ),
643            ("type".to_owned(), "append-only".to_owned()),
644            (
645                "primary_key".to_owned(),
646                properties.get("primary_key").cloned().unwrap_or_default(),
647            ),
648            (
649                "schema.name".to_owned(),
650                config.snowflake_schema.clone().unwrap_or_default(),
651            ),
652            (
653                "database.name".to_owned(),
654                config.snowflake_database.clone().unwrap_or_default(),
655            ),
656        ]);
657
658        // Reuse build_jdbc_connection_properties to get driver properties (auth, user, etc.)
659        let (_jdbc_url, connection_properties) = config.build_jdbc_connection_properties()?;
660        for (key, value) in connection_properties {
661            new_properties.insert(key, value);
662        }
663
664        param.properties = new_properties;
665
666        let jdbc_sink_writer =
667            CoordinatedRemoteSinkWriter::new(param.clone(), metrics.clone()).await?;
668        Ok(Self {
669            augmented_row: AugmentedChunk::new(0, is_append_only),
670            jdbc_sink_writer,
671        })
672    }
673}
674
675impl SnowflakeSinkJdbcWriter {
676    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
677        self.augmented_row.reset_epoch(epoch);
678        self.jdbc_sink_writer.begin_epoch(epoch).await?;
679        Ok(())
680    }
681
682    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
683        let chunk = self.augmented_row.augmented_chunk(chunk)?;
684        self.jdbc_sink_writer.write_batch(chunk).await?;
685        Ok(())
686    }
687
688    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
689        self.jdbc_sink_writer.barrier(is_checkpoint).await?;
690        Ok(())
691    }
692
693    async fn abort(&mut self) -> Result<()> {
694        // TODO: abort should clean up all the data written in this epoch.
695        self.jdbc_sink_writer.abort().await?;
696        Ok(())
697    }
698}
699
700#[derive(Default)]
701pub struct SnowflakeTaskContext {
702    // required for all generated DDL (table / pipe / task)
703    pub target_table_name: String,
704    pub database: String,
705    pub schema_name: String,
706    pub schema: Schema,
707
708    // only upsert
709    pub task_name: Option<String>,
710    pub cdc_table_name: Option<String>,
711    pub schedule_seconds: u64,
712    pub warehouse: Option<String>,
713    pub pk_column_names: Option<Vec<String>>,
714    pub all_column_names: Option<Vec<String>>,
715
716    // only s3 writer
717    pub stage: Option<String>,
718    pub pipe_name: Option<String>,
719}
720pub struct SnowflakeSinkCommitter {
721    client: Option<SnowflakeJniClient>,
722}
723
724impl SnowflakeSinkCommitter {
725    pub fn new(
726        config: SnowflakeV2Config,
727        schema: &Schema,
728        pk_indices: &Vec<usize>,
729        is_append_only: bool,
730    ) -> Result<Self> {
731        let client = if let Some((snowflake_task_ctx, client)) =
732            config.build_snowflake_task_ctx_jdbc_client(is_append_only, schema, pk_indices)?
733        {
734            Some(SnowflakeJniClient::new(client, snowflake_task_ctx))
735        } else {
736            None
737        };
738        Ok(Self { client })
739    }
740}
741
742#[async_trait]
743impl SinglePhaseCommitCoordinator for SnowflakeSinkCommitter {
744    async fn init(&mut self) -> Result<()> {
745        if let Some(client) = &self.client {
746            // TODO: move this to validate.
747            client.execute_create_pipe().await?;
748            client.execute_create_merge_into_task().await?;
749        }
750        Ok(())
751    }
752
753    async fn commit(
754        &mut self,
755        _epoch: u64,
756        _metadata: Vec<SinkMetadata>,
757        schema_change: Option<PbSinkSchemaChange>,
758    ) -> Result<()> {
759        let client = self.client.as_mut().ok_or_else(|| {
760            SinkError::Config(anyhow!("Snowflake sink committer is not initialized."))
761        })?;
762        client.execute_flush_pipe().await?;
763
764        if let Some(schema_change) = schema_change {
765            use risingwave_pb::stream_plan::sink_schema_change::PbOp as SinkSchemaChangeOp;
766            let schema_change_op = schema_change.op.ok_or_else(|| {
767                SinkError::Coordinator(anyhow!("Invalid schema change operation"))
768            })?;
769            let SinkSchemaChangeOp::AddColumns(add_columns) = schema_change_op else {
770                return Err(SinkError::Coordinator(anyhow!(
771                    "Only AddColumns schema change is supported for Snowflake sink"
772                )));
773            };
774            client
775                .execute_alter_add_columns(
776                    &add_columns
777                        .fields
778                        .into_iter()
779                        .map(|f| (f.name, DataType::from(f.data_type.unwrap()).to_string()))
780                        .collect_vec(),
781                )
782                .await?;
783        }
784        Ok(())
785    }
786}
787
788impl Drop for SnowflakeSinkCommitter {
789    fn drop(&mut self) {
790        if let Some(client) = self.client.take() {
791            tokio::spawn(async move {
792                client.execute_drop_task().await.ok();
793            });
794        }
795    }
796}
797
798pub struct SnowflakeJniClient {
799    jdbc_client: JdbcJniClient,
800    snowflake_task_context: SnowflakeTaskContext,
801}
802
803impl SnowflakeJniClient {
804    pub fn new(jdbc_client: JdbcJniClient, snowflake_task_context: SnowflakeTaskContext) -> Self {
805        Self {
806            jdbc_client,
807            snowflake_task_context,
808        }
809    }
810
811    pub async fn execute_alter_add_columns(
812        &mut self,
813        columns: &Vec<(String, String)>,
814    ) -> Result<()> {
815        self.execute_drop_task().await?;
816        if let Some(names) = self.snowflake_task_context.all_column_names.as_mut() {
817            names.extend(columns.iter().map(|(name, _)| name.clone()));
818        }
819        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
820            let alter_add_column_cdc_table_sql = build_alter_add_column_sql(
821                cdc_table_name,
822                &self.snowflake_task_context.database,
823                &self.snowflake_task_context.schema_name,
824                columns,
825            );
826            self.jdbc_client
827                .execute_sql_sync(vec![alter_add_column_cdc_table_sql])
828                .await?;
829        }
830
831        let alter_add_column_target_table_sql = build_alter_add_column_sql(
832            &self.snowflake_task_context.target_table_name,
833            &self.snowflake_task_context.database,
834            &self.snowflake_task_context.schema_name,
835            columns,
836        );
837        self.jdbc_client
838            .execute_sql_sync(vec![alter_add_column_target_table_sql])
839            .await?;
840
841        self.execute_create_merge_into_task().await?;
842        Ok(())
843    }
844
845    pub async fn execute_create_merge_into_task(&self) -> Result<()> {
846        if self.snowflake_task_context.task_name.is_some() {
847            let create_task_sql = build_create_merge_into_task_sql(&self.snowflake_task_context);
848            let start_task_sql = build_start_task_sql(&self.snowflake_task_context);
849            self.jdbc_client
850                .execute_sql_sync(vec![create_task_sql])
851                .await?;
852            self.jdbc_client
853                .execute_sql_sync(vec![start_task_sql])
854                .await?;
855        }
856        Ok(())
857    }
858
859    pub async fn execute_drop_task(&self) -> Result<()> {
860        if self.snowflake_task_context.task_name.is_some() {
861            let sql = build_drop_task_sql(&self.snowflake_task_context);
862            if let Err(e) = self.jdbc_client.execute_sql_sync(vec![sql]).await {
863                tracing::error!(
864                    "Failed to drop Snowflake sink task {:?}: {:?}",
865                    self.snowflake_task_context.task_name,
866                    e.as_report()
867                );
868            } else {
869                tracing::info!(
870                    "Snowflake sink task {:?} dropped",
871                    self.snowflake_task_context.task_name
872                );
873            }
874        }
875        Ok(())
876    }
877
878    pub async fn execute_create_table(&self) -> Result<()> {
879        // create target table
880        let create_target_table_sql = build_create_table_sql(
881            &self.snowflake_task_context.target_table_name,
882            &self.snowflake_task_context.database,
883            &self.snowflake_task_context.schema_name,
884            &self.snowflake_task_context.schema,
885            false,
886        )?;
887        self.jdbc_client
888            .execute_sql_sync(vec![create_target_table_sql])
889            .await?;
890        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
891            let create_cdc_table_sql = build_create_table_sql(
892                cdc_table_name,
893                &self.snowflake_task_context.database,
894                &self.snowflake_task_context.schema_name,
895                &self.snowflake_task_context.schema,
896                true,
897            )?;
898            self.jdbc_client
899                .execute_sql_sync(vec![create_cdc_table_sql])
900                .await?;
901        }
902        Ok(())
903    }
904
905    pub async fn execute_create_pipe(&self) -> Result<()> {
906        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
907            let table_name =
908                if let Some(table_name) = self.snowflake_task_context.cdc_table_name.as_ref() {
909                    table_name
910                } else {
911                    &self.snowflake_task_context.target_table_name
912                };
913            let create_pipe_sql = build_create_pipe_sql(
914                table_name,
915                &self.snowflake_task_context.database,
916                &self.snowflake_task_context.schema_name,
917                self.snowflake_task_context.stage.as_ref().ok_or_else(|| {
918                    SinkError::Config(anyhow!("`stage` is required for the S3 writer"))
919                })?,
920                pipe_name,
921                &self.snowflake_task_context.target_table_name,
922            );
923            self.jdbc_client
924                .execute_sql_sync(vec![create_pipe_sql])
925                .await?;
926        }
927        Ok(())
928    }
929
930    pub async fn execute_flush_pipe(&self) -> Result<()> {
931        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
932            let flush_pipe_sql = build_flush_pipe_sql(
933                &self.snowflake_task_context.database,
934                &self.snowflake_task_context.schema_name,
935                pipe_name,
936            );
937            self.jdbc_client
938                .execute_sql_sync(vec![flush_pipe_sql])
939                .await?;
940        }
941        Ok(())
942    }
943}
944
945fn build_create_table_sql(
946    table_name: &str,
947    database: &str,
948    schema_name: &str,
949    schema: &Schema,
950    need_op_and_row_id: bool,
951) -> Result<String> {
952    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema_name, table_name);
953    let mut columns: Vec<String> = schema
954        .fields
955        .iter()
956        .map(|field| {
957            let data_type = convert_snowflake_data_type(&field.data_type)?;
958            Ok(format!(r#""{}" {}"#, field.name, data_type))
959        })
960        .collect::<Result<Vec<String>>>()?;
961    if need_op_and_row_id {
962        columns.push(format!(r#""{}" STRING"#, SNOWFLAKE_SINK_ROW_ID));
963        columns.push(format!(r#""{}" INT"#, SNOWFLAKE_SINK_OP));
964    }
965    let columns_str = columns.join(", ");
966    Ok(format!(
967        "CREATE TABLE IF NOT EXISTS {} ({}) ENABLE_SCHEMA_EVOLUTION = true",
968        full_table_name, columns_str
969    ))
970}
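
// Illustrative only (hypothetical names): for a schema with an INT column `v1` and a VARCHAR
// column `v2`, and `need_op_and_row_id = true`, the statement produced above is roughly:
//
//   CREATE TABLE IF NOT EXISTS "db"."sch"."t"
//       ("v1" INTEGER, "v2" STRING, "__row_id" STRING, "__op" INT)
//       ENABLE_SCHEMA_EVOLUTION = true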
971
972fn convert_snowflake_data_type(data_type: &DataType) -> Result<String> {
973    let data_type = match data_type {
974        DataType::Int16 => "SMALLINT".to_owned(),
975        DataType::Int32 => "INTEGER".to_owned(),
976        DataType::Int64 => "BIGINT".to_owned(),
977        DataType::Float32 => "FLOAT4".to_owned(),
978        DataType::Float64 => "FLOAT8".to_owned(),
979        DataType::Boolean => "BOOLEAN".to_owned(),
980        DataType::Varchar => "STRING".to_owned(),
981        DataType::Date => "DATE".to_owned(),
982        DataType::Timestamp => "TIMESTAMP".to_owned(),
983        DataType::Timestamptz => "TIMESTAMP_TZ".to_owned(),
984        DataType::Jsonb => "STRING".to_owned(),
985        DataType::Decimal => "DECIMAL".to_owned(),
986        DataType::Bytea => "BINARY".to_owned(),
987        DataType::Time => "TIME".to_owned(),
988        _ => {
989            return Err(SinkError::Config(anyhow!(
990                "auto table creation is not supported for data type: {}",
991                data_type
992            )));
993        }
994    };
995    Ok(data_type)
996}
997
998fn build_create_pipe_sql(
999    table_name: &str,
1000    database: &str,
1001    schema: &str,
1002    stage: &str,
1003    pipe_name: &str,
1004    target_table_name: &str,
1005) -> String {
1006    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
1007    let stage = format!(
1008        r#""{}"."{}"."{}"/{}"#,
1009        database, schema, stage, target_table_name
1010    );
1011    let table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
1012    format!(
1013        "CREATE OR REPLACE PIPE {} AUTO_INGEST = FALSE AS COPY INTO {} FROM @{} MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE FILE_FORMAT = (type = 'JSON');",
1014        pipe_name, table_name, stage
1015    )
1016}
1017
1018fn build_flush_pipe_sql(database: &str, schema: &str, pipe_name: &str) -> String {
1019    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
1020    format!("ALTER PIPE {} REFRESH;", pipe_name)
1021}
1022
1023fn build_alter_add_column_sql(
1024    table_name: &str,
1025    database: &str,
1026    schema: &str,
1027    columns: &Vec<(String, String)>,
1028) -> String {
1029    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
1030    jdbc_jni_client::build_alter_add_column_sql(&full_table_name, columns, true)
1031}
1032
1033fn build_start_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
1034    let SnowflakeTaskContext {
1035        task_name,
1036        database,
1037        schema_name: schema,
1038        ..
1039    } = snowflake_task_context;
1040    let full_task_name = format!(
1041        r#""{}"."{}"."{}""#,
1042        database,
1043        schema,
1044        task_name.as_ref().unwrap()
1045    );
1046    format!("ALTER TASK {} RESUME", full_task_name)
1047}
1048
1049fn build_drop_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
1050    let SnowflakeTaskContext {
1051        task_name,
1052        database,
1053        schema_name: schema,
1054        ..
1055    } = snowflake_task_context;
1056    let full_task_name = format!(
1057        r#""{}"."{}"."{}""#,
1058        database,
1059        schema,
1060        task_name.as_ref().unwrap()
1061    );
1062    format!("DROP TASK IF EXISTS {}", full_task_name)
1063}
1064
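/// Builds the periodic `MERGE INTO` task that folds rows from the intermediate (CDC) table
/// into the target table: it snapshots the current max `__row_id`, deduplicates by primary
/// key (latest `__row_id` wins), deletes or upserts rows based on `__op`, and finally purges
/// the consumed rows from the intermediate table. The tests below show the full generated SQL.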
1065fn build_create_merge_into_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
1066    let SnowflakeTaskContext {
1067        task_name,
1068        cdc_table_name,
1069        target_table_name,
1070        schedule_seconds,
1071        warehouse,
1072        pk_column_names,
1073        all_column_names,
1074        database,
1075        schema_name,
1076        ..
1077    } = snowflake_task_context;
1078    let full_task_name = format!(
1079        r#""{}"."{}"."{}""#,
1080        database,
1081        schema_name,
1082        task_name.as_ref().unwrap()
1083    );
1084    let full_cdc_table_name = format!(
1085        r#""{}"."{}"."{}""#,
1086        database,
1087        schema_name,
1088        cdc_table_name.as_ref().unwrap()
1089    );
1090    let full_target_table_name = format!(
1091        r#""{}"."{}"."{}""#,
1092        database, schema_name, target_table_name
1093    );
1094
1095    let pk_names_str = pk_column_names
1096        .as_ref()
1097        .unwrap()
1098        .iter()
1099        .map(|name| format!(r#""{}""#, name))
1100        .collect::<Vec<String>>()
1101        .join(", ");
1102    let pk_names_eq_str = pk_column_names
1103        .as_ref()
1104        .unwrap()
1105        .iter()
1106        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
1107        .collect::<Vec<String>>()
1108        .join(" AND ");
1109    let all_column_names_set_str = all_column_names
1110        .as_ref()
1111        .unwrap()
1112        .iter()
1113        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
1114        .collect::<Vec<String>>()
1115        .join(", ");
1116    let all_column_names_str = all_column_names
1117        .as_ref()
1118        .unwrap()
1119        .iter()
1120        .map(|name| format!(r#""{}""#, name))
1121        .collect::<Vec<String>>()
1122        .join(", ");
1123    let all_column_names_insert_str = all_column_names
1124        .as_ref()
1125        .unwrap()
1126        .iter()
1127        .map(|name| format!(r#"source."{}""#, name))
1128        .collect::<Vec<String>>()
1129        .join(", ");
1130
1131    format!(
1132        r#"CREATE OR REPLACE TASK {task_name}
1133WAREHOUSE = {warehouse}
1134SCHEDULE = '{schedule_seconds} SECONDS'
1135AS
1136BEGIN
1137    LET max_row_id STRING;
1138
1139    SELECT COALESCE(MAX("{snowflake_sink_row_id}"), '0') INTO :max_row_id
1140    FROM {cdc_table_name};
1141
1142    MERGE INTO {target_table_name} AS target
1143    USING (
1144        SELECT *
1145        FROM (
1146            SELECT *, ROW_NUMBER() OVER (PARTITION BY {pk_names_str} ORDER BY "{snowflake_sink_row_id}" DESC) AS dedupe_id
1147            FROM {cdc_table_name}
1148            WHERE "{snowflake_sink_row_id}" <= :max_row_id
1149        ) AS subquery
1150        WHERE dedupe_id = 1
1151    ) AS source
1152    ON {pk_names_eq_str}
1153    WHEN MATCHED AND source."{snowflake_sink_op}" IN (2, 4) THEN DELETE
1154    WHEN MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN UPDATE SET {all_column_names_set_str}
1155    WHEN NOT MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN INSERT ({all_column_names_str}) VALUES ({all_column_names_insert_str});
1156
1157    DELETE FROM {cdc_table_name}
1158    WHERE "{snowflake_sink_row_id}" <= :max_row_id;
1159END;"#,
1160        task_name = full_task_name,
1161        warehouse = warehouse.as_ref().unwrap(),
1162        schedule_seconds = schedule_seconds,
1163        cdc_table_name = full_cdc_table_name,
1164        target_table_name = full_target_table_name,
1165        pk_names_str = pk_names_str,
1166        pk_names_eq_str = pk_names_eq_str,
1167        all_column_names_set_str = all_column_names_set_str,
1168        all_column_names_str = all_column_names_str,
1169        all_column_names_insert_str = all_column_names_insert_str,
1170        snowflake_sink_row_id = SNOWFLAKE_SINK_ROW_ID,
1171        snowflake_sink_op = SNOWFLAKE_SINK_OP,
1172    )
1173}
1174
1175#[cfg(test)]
1176mod tests {
1177    use std::collections::BTreeMap;
1178
1179    use super::*;
1180    use crate::sink::jdbc_jni_client::normalize_sql;
1181
1182    fn base_properties() -> BTreeMap<String, String> {
1183        BTreeMap::from([
1184            ("type".to_owned(), "append-only".to_owned()),
1185            ("jdbc.url".to_owned(), "jdbc:snowflake://account".to_owned()),
1186            ("username".to_owned(), "RW_USER".to_owned()),
1187        ])
1188    }
1189
1190    #[test]
1191    fn test_build_jdbc_props_password() {
1192        let mut props = base_properties();
1193        props.insert("password".to_owned(), "secret".to_owned());
1194        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
1195        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
1196        assert_eq!(url, "jdbc:snowflake://account");
1197        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
1198        assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
1199        assert_eq!(map.get("password"), Some(&"secret".to_owned()));
1200        assert!(!map.contains_key("authenticator"));
1201    }
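
    // Hedged sketch: exercise auth validation paths in `from_btreemap` not covered above.
    // Property names follow the serde renames on `SnowflakeV2Config`; the expectations mirror
    // the match arms in `from_btreemap`.
    #[test]
    fn test_auth_config_conflicts_rejected() {
        // Both `password` and `private_key_file` without an explicit `auth.method` is ambiguous.
        let mut props = base_properties();
        props.insert("password".to_owned(), "secret".to_owned());
        props.insert("private_key_file".to_owned(), "/tmp/rsa_key.p8".to_owned());
        assert!(SnowflakeV2Config::from_btreemap(&props).is_err());

        // No credential at all is rejected as well.
        let props = base_properties();
        assert!(SnowflakeV2Config::from_btreemap(&props).is_err());
    }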
1202
1203    #[test]
1204    fn test_build_jdbc_props_key_pair_file() {
1205        let mut props = base_properties();
1206        props.insert(
1207            "auth.method".to_owned(),
1208            AUTH_METHOD_KEY_PAIR_FILE.to_owned(),
1209        );
1210        props.insert("private_key_file".to_owned(), "/tmp/rsa_key.p8".to_owned());
1211        props.insert("private_key_file_pwd".to_owned(), "dummy".to_owned());
1212        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
1213        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
1214        assert_eq!(url, "jdbc:snowflake://account");
1215        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
1216        assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
1217        assert_eq!(
1218            map.get("private_key_file"),
1219            Some(&"/tmp/rsa_key.p8".to_owned())
1220        );
1221        assert_eq!(map.get("private_key_file_pwd"), Some(&"dummy".to_owned()));
1222    }
1223
1224    #[test]
1225    fn test_build_jdbc_props_key_pair_object() {
1226        let mut props = base_properties();
1227        props.insert(
1228            "auth.method".to_owned(),
1229            AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
1230        );
1231        props.insert(
1232            "private_key_pem".to_owned(),
1233            "-----BEGIN PRIVATE KEY-----
1234...
1235-----END PRIVATE KEY-----"
1236                .to_owned(),
1237        );
1238        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
1239        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
1240        assert_eq!(url, "jdbc:snowflake://account");
1241        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
1242        assert_eq!(
1243            map.get("private_key_pem"),
1244            Some(
1245                &"-----BEGIN PRIVATE KEY-----
1246...
1247-----END PRIVATE KEY-----"
1248                    .to_owned()
1249            )
1250        );
1251        assert!(!map.contains_key("private_key_file"));
1252    }
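
    // Hedged sketch: when `auth.method` is omitted, `from_btreemap` infers it from whichever
    // credential field is present (mirrors the inference arm in `from_btreemap`).
    #[test]
    fn test_auth_method_inferred_from_fields() {
        let mut props = base_properties();
        props.insert("private_key_file".to_owned(), "/tmp/rsa_key.p8".to_owned());
        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
        assert_eq!(config.auth_method.as_deref(), Some(AUTH_METHOD_KEY_PAIR_FILE));

        let mut props = base_properties();
        props.insert(
            "private_key_pem".to_owned(),
            "-----BEGIN PRIVATE KEY-----...".to_owned(),
        );
        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
        assert_eq!(
            config.auth_method.as_deref(),
            Some(AUTH_METHOD_KEY_PAIR_OBJECT)
        );
    }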
1253
1254    #[test]
1255    fn test_snowflake_sink_commit_coordinator() {
1256        let snowflake_task_context = SnowflakeTaskContext {
1257            task_name: Some("test_task".to_owned()),
1258            cdc_table_name: Some("test_cdc_table".to_owned()),
1259            target_table_name: "test_target_table".to_owned(),
1260            schedule_seconds: 3600,
1261            warehouse: Some("test_warehouse".to_owned()),
1262            pk_column_names: Some(vec!["v1".to_owned()]),
1263            all_column_names: Some(vec!["v1".to_owned(), "v2".to_owned()]),
1264            database: "test_db".to_owned(),
1265            schema_name: "test_schema".to_owned(),
1266            schema: Schema { fields: vec![] },
1267            stage: None,
1268            pipe_name: None,
1269        };
1270        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
1271        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task"
1272WAREHOUSE = test_warehouse
1273SCHEDULE = '3600 SECONDS'
1274AS
1275BEGIN
1276    LET max_row_id STRING;
1277
1278    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
1279    FROM "test_db"."test_schema"."test_cdc_table";
1280
1281    MERGE INTO "test_db"."test_schema"."test_target_table" AS target
1282    USING (
1283        SELECT *
1284        FROM (
1285            SELECT *, ROW_NUMBER() OVER (PARTITION BY "v1" ORDER BY "__row_id" DESC) AS dedupe_id
1286            FROM "test_db"."test_schema"."test_cdc_table"
1287            WHERE "__row_id" <= :max_row_id
1288        ) AS subquery
1289        WHERE dedupe_id = 1
1290    ) AS source
1291    ON target."v1" = source."v1"
1292    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
1293    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."v1" = source."v1", target."v2" = source."v2"
1294    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("v1", "v2") VALUES (source."v1", source."v2");
1295
1296    DELETE FROM "test_db"."test_schema"."test_cdc_table"
1297    WHERE "__row_id" <= :max_row_id;
1298END;"#;
1299        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
1300    }
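
    // Hedged sketch with hypothetical object names: spot-check the Snowpipe DDL emitted by
    // `build_create_pipe_sql` and `build_flush_pipe_sql` against the format strings above.
    #[test]
    fn test_build_pipe_sql() {
        let create_sql =
            build_create_pipe_sql("cdc_t", "test_db", "test_schema", "rw_stage", "t_pipe", "t");
        assert_eq!(
            create_sql,
            r#"CREATE OR REPLACE PIPE "test_db"."test_schema"."t_pipe" AUTO_INGEST = FALSE AS COPY INTO "test_db"."test_schema"."cdc_t" FROM @"test_db"."test_schema"."rw_stage"/t MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE FILE_FORMAT = (type = 'JSON');"#
        );

        let flush_sql = build_flush_pipe_sql("test_db", "test_schema", "t_pipe");
        assert_eq!(
            flush_sql,
            r#"ALTER PIPE "test_db"."test_schema"."t_pipe" REFRESH;"#
        );
    }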
1301
1302    #[test]
1303    fn test_snowflake_sink_commit_coordinator_multi_pk() {
1304        let snowflake_task_context = SnowflakeTaskContext {
1305            task_name: Some("test_task_multi_pk".to_owned()),
1306            cdc_table_name: Some("cdc_multi_pk".to_owned()),
1307            target_table_name: "target_multi_pk".to_owned(),
1308            schedule_seconds: 300,
1309            warehouse: Some("multi_pk_warehouse".to_owned()),
1310            pk_column_names: Some(vec!["id1".to_owned(), "id2".to_owned()]),
1311            all_column_names: Some(vec!["id1".to_owned(), "id2".to_owned(), "val".to_owned()]),
1312            database: "test_db".to_owned(),
1313            schema_name: "test_schema".to_owned(),
1314            schema: Schema { fields: vec![] },
1315            stage: None,
1316            pipe_name: None,
1317        };
1318        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
1319        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task_multi_pk"
1320WAREHOUSE = multi_pk_warehouse
1321SCHEDULE = '300 SECONDS'
1322AS
1323BEGIN
1324    LET max_row_id STRING;
1325
1326    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
1327    FROM "test_db"."test_schema"."cdc_multi_pk";
1328
1329    MERGE INTO "test_db"."test_schema"."target_multi_pk" AS target
1330    USING (
1331        SELECT *
1332        FROM (
1333            SELECT *, ROW_NUMBER() OVER (PARTITION BY "id1", "id2" ORDER BY "__row_id" DESC) AS dedupe_id
1334            FROM "test_db"."test_schema"."cdc_multi_pk"
1335            WHERE "__row_id" <= :max_row_id
1336        ) AS subquery
1337        WHERE dedupe_id = 1
1338    ) AS source
1339    ON target."id1" = source."id1" AND target."id2" = source."id2"
1340    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
1341    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."id1" = source."id1", target."id2" = source."id2", target."val" = source."val"
1342    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("id1", "id2", "val") VALUES (source."id1", source."id2", source."val");
1343
1344    DELETE FROM "test_db"."test_schema"."cdc_multi_pk"
1345    WHERE "__row_id" <= :max_row_id;
1346END;"#;
1347        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
1348    }
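
    // Hedged sketches: spot-check the remaining helpers with hypothetical names, mirroring
    // the type mapping and format strings defined above.
    #[test]
    fn test_convert_snowflake_data_type() {
        assert_eq!(
            convert_snowflake_data_type(&DataType::Varchar).unwrap(),
            "STRING"
        );
        assert_eq!(
            convert_snowflake_data_type(&DataType::Timestamptz).unwrap(),
            "TIMESTAMP_TZ"
        );
        // Types outside the mapping (e.g. intervals) are rejected.
        assert!(convert_snowflake_data_type(&DataType::Interval).is_err());
    }

    #[test]
    fn test_build_task_lifecycle_sql() {
        let ctx = SnowflakeTaskContext {
            task_name: Some("test_task".to_owned()),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            ..Default::default()
        };
        assert_eq!(
            build_start_task_sql(&ctx),
            r#"ALTER TASK "test_db"."test_schema"."test_task" RESUME"#
        );
        assert_eq!(
            build_drop_task_sql(&ctx),
            r#"DROP TASK IF EXISTS "test_db"."test_schema"."test_task""#
        );
    }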
1349}