risingwave_connector/sink/snowflake_redshift/snowflake.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use core::num::NonZeroU64;
16use std::collections::BTreeMap;
17
18use anyhow::anyhow;
19use phf::{Set, phf_set};
20use risingwave_common::array::StreamChunk;
21use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema};
22use risingwave_common::types::DataType;
23use risingwave_pb::connector_service::{SinkMetadata, sink_metadata};
24use serde::Deserialize;
25use serde_with::{DisplayFromStr, serde_as};
26use thiserror_ext::AsReport;
27use tokio::sync::mpsc::UnboundedSender;
28use tonic::async_trait;
29use with_options::WithOptions;
30
31use crate::connector_common::IcebergSinkCompactionUpdate;
32use crate::enforce_secret::EnforceSecret;
33use crate::sink::coordinate::CoordinatedLogSinker;
34use crate::sink::decouple_checkpoint_log_sink::default_commit_checkpoint_interval;
35use crate::sink::file_sink::s3::S3Common;
36use crate::sink::jdbc_jni_client::{self, JdbcJniClient};
37use crate::sink::remote::CoordinatedRemoteSinkWriter;
38use crate::sink::snowflake_redshift::{AugmentedChunk, SnowflakeRedshiftSinkS3Writer};
39use crate::sink::writer::SinkWriter;
40use crate::sink::{
41    Result, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT,
42    SinglePhaseCommitCoordinator, Sink, SinkCommitCoordinator, SinkError, SinkParam,
43    SinkWriterMetrics, SinkWriterParam,
44};
45
46pub const SNOWFLAKE_SINK_V2: &str = "snowflake_v2";
47pub const SNOWFLAKE_SINK_ROW_ID: &str = "__row_id";
48pub const SNOWFLAKE_SINK_OP: &str = "__op";
49
50const AUTH_METHOD_PASSWORD: &str = "password";
51const AUTH_METHOD_KEY_PAIR_FILE: &str = "key_pair_file";
52const AUTH_METHOD_KEY_PAIR_OBJECT: &str = "key_pair_object";
53const PROP_AUTH_METHOD: &str = "auth.method";
54
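/// Configuration for the `snowflake_v2` sink, deserialized from the sink's `WITH` options.
///
/// A minimal illustrative sketch of an upsert sink using key-pair (file) authentication;
/// the property names follow the `serde(rename = ...)` attributes below, all values are
/// placeholders, and the flattened S3 options (`S3Common`) are additionally required when
/// `with_s3` is enabled (the default):
///
/// ```text
/// connector = 'snowflake_v2',
/// type = 'upsert',
/// jdbc.url = 'jdbc:snowflake://<account>.snowflakecomputing.com',
/// username = 'RW_USER',
/// auth.method = 'key_pair_file',
/// private_key_file = '/path/to/rsa_key.p8',
/// database = 'MY_DB',
/// schema = 'PUBLIC',
/// table.name = 'TARGET_TABLE',
/// intermediate.table.name = 'TARGET_TABLE_CDC',
/// warehouse = 'MY_WAREHOUSE',
/// stage = 'MY_STAGE',
/// ```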
55#[serde_as]
56#[derive(Debug, Clone, Deserialize, WithOptions)]
57pub struct SnowflakeV2Config {
58    #[serde(rename = "type")]
59    pub r#type: String,
60
61    #[serde(rename = "intermediate.table.name")]
62    pub snowflake_cdc_table_name: Option<String>,
63
64    #[serde(rename = "table.name")]
65    pub snowflake_target_table_name: Option<String>,
66
67    #[serde(rename = "database")]
68    pub snowflake_database: Option<String>,
69
70    #[serde(rename = "schema")]
71    pub snowflake_schema: Option<String>,
72
73    #[serde(default = "default_schedule")]
74    #[serde(rename = "write.target.interval.seconds")]
75    #[serde_as(as = "DisplayFromStr")]
76    pub snowflake_schedule_seconds: u64,
77
78    #[serde(rename = "warehouse")]
79    pub snowflake_warehouse: Option<String>,
80
81    #[serde(rename = "jdbc.url")]
82    pub jdbc_url: Option<String>,
83
84    #[serde(rename = "username")]
85    pub username: Option<String>,
86
87    #[serde(rename = "password")]
88    pub password: Option<String>,
89
90    // Authentication method control (password | key_pair_file | key_pair_object)
91    #[serde(rename = "auth.method")]
92    pub auth_method: Option<String>,
93
    // Key-pair authentication via connection properties: private key file on disk
95    #[serde(rename = "private_key_file")]
96    pub private_key_file: Option<String>,
97
98    #[serde(rename = "private_key_file_pwd")]
99    pub private_key_file_pwd: Option<String>,
100
    // Key-pair authentication via connection properties: private key passed as PEM content
102    #[serde(rename = "private_key_pem")]
103    pub private_key_pem: Option<String>,
104
    /// Commit every n (> 0) checkpoints; the default is 10.
106    #[serde(default = "default_commit_checkpoint_interval")]
107    #[serde_as(as = "DisplayFromStr")]
108    #[with_option(allow_alter_on_fly)]
109    pub commit_checkpoint_interval: u64,
110
    /// Enable auto schema change for the upsert sink.
    /// If enabled, the sink automatically alters the target table to add new columns.
113    #[serde(default)]
114    #[serde(rename = "auto.schema.change")]
115    #[serde_as(as = "DisplayFromStr")]
116    pub auto_schema_change: bool,
117
118    #[serde(default)]
119    #[serde(rename = "create_table_if_not_exists")]
120    #[serde_as(as = "DisplayFromStr")]
121    pub create_table_if_not_exists: bool,
122
123    #[serde(default = "default_with_s3")]
124    #[serde(rename = "with_s3")]
125    #[serde_as(as = "DisplayFromStr")]
126    pub with_s3: bool,
127
128    #[serde(flatten)]
129    pub s3_inner: Option<S3Common>,
130
131    #[serde(rename = "stage")]
132    pub stage: Option<String>,
133}
134
135fn default_schedule() -> u64 {
136    3600 // Default to 1 hour
137}
138
139fn default_with_s3() -> bool {
140    true
141}
142
143impl SnowflakeV2Config {
    /// Build the JDBC driver properties for the Snowflake connection (no parameters are appended to the URL).
    /// Returns (`jdbc_url`, `driver_properties`).
    /// - `driver_properties` are consumed by the Java runner and passed to `DriverManager::getConnection(url, props)`
    ///
    /// Note: this method assumes the config has already been validated by `from_btreemap`.
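    ///
    /// For example (a sketch; values are placeholders), a password-authenticated config
    /// yields roughly `("jdbc:snowflake://<account>", vec![("user", ...), ("password", ...)])`;
    /// see the unit tests at the bottom of this file for exact expectations.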
149    pub fn build_jdbc_connection_properties(&self) -> Result<(String, Vec<(String, String)>)> {
150        let jdbc_url = self
151            .jdbc_url
152            .clone()
153            .ok_or(SinkError::Config(anyhow!("jdbc.url is required")))?;
154        let username = self
155            .username
156            .clone()
157            .ok_or(SinkError::Config(anyhow!("username is required")))?;
158
159        let mut connection_properties: Vec<(String, String)> = vec![("user".to_owned(), username)];
160
161        // auth_method is guaranteed to be Some after validation in from_btreemap
162        match self.auth_method.as_deref().unwrap() {
163            AUTH_METHOD_PASSWORD => {
164                // password is guaranteed to exist by from_btreemap validation
165                connection_properties.push(("password".to_owned(), self.password.clone().unwrap()));
166            }
167            AUTH_METHOD_KEY_PAIR_FILE => {
168                // private_key_file is guaranteed to exist by from_btreemap validation
169                connection_properties.push((
170                    "private_key_file".to_owned(),
171                    self.private_key_file.clone().unwrap(),
172                ));
173                if let Some(pwd) = self.private_key_file_pwd.clone() {
174                    connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
175                }
176            }
177            AUTH_METHOD_KEY_PAIR_OBJECT => {
178                connection_properties.push((
179                    PROP_AUTH_METHOD.to_owned(),
180                    AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
181                ));
182                // private_key_pem is guaranteed to exist by from_btreemap validation
183                connection_properties.push((
184                    "private_key_pem".to_owned(),
185                    self.private_key_pem.clone().unwrap(),
186                ));
187                if let Some(pwd) = self.private_key_file_pwd.clone() {
188                    connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
189                }
190            }
191            _ => {
192                // This should never happen since from_btreemap validates auth_method
193                unreachable!(
194                    "Invalid auth_method - should have been caught during config validation"
195                )
196            }
197        }
198
199        Ok((jdbc_url, connection_properties))
200    }
201
202    pub fn from_btreemap(properties: &BTreeMap<String, String>) -> Result<Self> {
203        let mut config =
204            serde_json::from_value::<SnowflakeV2Config>(serde_json::to_value(properties).unwrap())
205                .map_err(|e| SinkError::Config(anyhow!(e)))?;
206        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
207            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
209                SINK_TYPE_OPTION,
210                SINK_TYPE_APPEND_ONLY,
211                SINK_TYPE_UPSERT
212            )));
213        }
214
215        // Normalize and validate authentication method
216        let has_password = config.password.is_some();
217        let has_file = config.private_key_file.is_some();
        let has_pem = config.private_key_pem.is_some();
219
220        let normalized_auth_method = match config
221            .auth_method
222            .as_deref()
223            .map(|s| s.trim().to_ascii_lowercase())
224        {
225            Some(method) if method == AUTH_METHOD_PASSWORD => {
226                if !has_password {
227                    return Err(SinkError::Config(anyhow!(
228                        "auth.method=password requires `password`"
229                    )));
230                }
231                if has_file || has_pem {
232                    return Err(SinkError::Config(anyhow!(
233                        "auth.method=password must not set `private_key_file`/`private_key_pem`"
234                    )));
235                }
236                AUTH_METHOD_PASSWORD.to_owned()
237            }
238            Some(method) if method == AUTH_METHOD_KEY_PAIR_FILE => {
239                if !has_file {
240                    return Err(SinkError::Config(anyhow!(
241                        "auth.method=key_pair_file requires `private_key_file`"
242                    )));
243                }
244                if has_password {
245                    return Err(SinkError::Config(anyhow!(
246                        "auth.method=key_pair_file must not set `password`"
247                    )));
248                }
249                if has_pem {
250                    return Err(SinkError::Config(anyhow!(
251                        "auth.method=key_pair_file must not set `private_key_pem`"
252                    )));
253                }
254                AUTH_METHOD_KEY_PAIR_FILE.to_owned()
255            }
256            Some(method) if method == AUTH_METHOD_KEY_PAIR_OBJECT => {
257                if !has_pem {
258                    return Err(SinkError::Config(anyhow!(
259                        "auth.method=key_pair_object requires `private_key_pem`"
260                    )));
261                }
262                if has_password {
263                    return Err(SinkError::Config(anyhow!(
264                        "auth.method=key_pair_object must not set `password`"
265                    )));
266                }
267                AUTH_METHOD_KEY_PAIR_OBJECT.to_owned()
268            }
269            Some(other) => {
270                return Err(SinkError::Config(anyhow!(
271                    "invalid auth.method: {} (allowed: password | key_pair_file | key_pair_object)",
272                    other
273                )));
274            }
275            None => {
276                // Infer auth method from supplied fields
277                match (has_password, has_file, has_pem) {
278                    (true, false, false) => AUTH_METHOD_PASSWORD.to_owned(),
279                    (false, true, false) => AUTH_METHOD_KEY_PAIR_FILE.to_owned(),
280                    (false, false, true) => AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
281                    (true, true, _) | (true, _, true) | (false, true, true) => {
282                        return Err(SinkError::Config(anyhow!(
283                            "ambiguous auth: multiple auth options provided; remove one or set `auth.method`"
284                        )));
285                    }
286                    _ => {
287                        return Err(SinkError::Config(anyhow!(
                            "no authentication configured: set one of `password`, `private_key_file`, or `private_key_pem` (or specify `auth.method`)"
289                        )));
290                    }
291                }
292            }
293        };
294        config.auth_method = Some(normalized_auth_method);
295        Ok(config)
296    }
297
298    pub fn build_snowflake_task_ctx_jdbc_client(
299        &self,
300        is_append_only: bool,
301        schema: &Schema,
302        pk_indices: &Vec<usize>,
303    ) -> Result<Option<(SnowflakeTaskContext, JdbcJniClient)>> {
304        if !self.auto_schema_change && is_append_only && !self.create_table_if_not_exists {
            // Append-only without auto schema change or table creation does not need a client.
306            return Ok(None);
307        }
308        let target_table_name = self
309            .snowflake_target_table_name
310            .clone()
311            .ok_or(SinkError::Config(anyhow!("table.name is required")))?;
312        let database = self
313            .snowflake_database
314            .clone()
315            .ok_or(SinkError::Config(anyhow!("database is required")))?;
316        let schema_name = self
317            .snowflake_schema
318            .clone()
319            .ok_or(SinkError::Config(anyhow!("schema is required")))?;
320        let mut snowflake_task_ctx = SnowflakeTaskContext {
321            target_table_name: target_table_name.clone(),
322            database,
323            schema_name,
324            schema: schema.clone(),
325            ..Default::default()
326        };
327
328        let (jdbc_url, connection_properties) = self.build_jdbc_connection_properties()?;
329        let client = JdbcJniClient::new_with_props(jdbc_url, connection_properties)?;
330
331        if self.with_s3 {
332            let stage = self
333                .stage
334                .clone()
335                .ok_or(SinkError::Config(anyhow!("stage is required")))?;
336            snowflake_task_ctx.stage = Some(stage);
337            snowflake_task_ctx.pipe_name = Some(format!("{}_pipe", target_table_name));
338        }
339        if !is_append_only {
340            let cdc_table_name = self
341                .snowflake_cdc_table_name
342                .clone()
343                .ok_or(SinkError::Config(anyhow!(
344                    "intermediate.table.name is required"
345                )))?;
346            snowflake_task_ctx.cdc_table_name = Some(cdc_table_name.clone());
347            snowflake_task_ctx.schedule_seconds = self.snowflake_schedule_seconds;
348            snowflake_task_ctx.warehouse = Some(
349                self.snowflake_warehouse
350                    .clone()
351                    .ok_or(SinkError::Config(anyhow!("warehouse is required")))?,
352            );
353            let pk_column_names: Vec<_> = schema
354                .fields
355                .iter()
356                .enumerate()
357                .filter(|(index, _)| pk_indices.contains(index))
358                .map(|(_, field)| field.name.clone())
359                .collect();
360            if pk_column_names.is_empty() {
361                return Err(SinkError::Config(anyhow!(
362                    "Primary key columns not found. Please set the `primary_key` column in the sink properties, or ensure that the sink contains the primary key columns from the upstream."
363                )));
364            }
365            snowflake_task_ctx.pk_column_names = Some(pk_column_names);
366            snowflake_task_ctx.all_column_names = Some(
367                schema
368                    .fields
369                    .iter()
370                    .map(|field| field.name.clone())
371                    .collect(),
372            );
373            snowflake_task_ctx.task_name = Some(format!(
374                "rw_snowflake_sink_from_{cdc_table_name}_to_{target_table_name}"
375            ));
376        }
377        Ok(Some((snowflake_task_ctx, client)))
378    }
379}
380
381impl EnforceSecret for SnowflakeV2Config {
382    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
383        "username",
384        "password",
385        "jdbc.url",
386        // Key-pair authentication secrets
387        "private_key_file_pwd",
388        "private_key_pem",
389    };
390}
391
392#[derive(Clone, Debug)]
393pub struct SnowflakeV2Sink {
394    config: SnowflakeV2Config,
395    schema: Schema,
396    pk_indices: Vec<usize>,
397    is_append_only: bool,
398    param: SinkParam,
399}
400
401impl EnforceSecret for SnowflakeV2Sink {
402    fn enforce_secret<'a>(
403        prop_iter: impl Iterator<Item = &'a str>,
404    ) -> crate::sink::ConnectorResult<()> {
405        for prop in prop_iter {
406            SnowflakeV2Config::enforce_one(prop)?;
407        }
408        Ok(())
409    }
410}
411
412impl TryFrom<SinkParam> for SnowflakeV2Sink {
413    type Error = SinkError;
414
415    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
416        let schema = param.schema();
417        let config = SnowflakeV2Config::from_btreemap(&param.properties)?;
418        let is_append_only = param.sink_type.is_append_only();
419        let pk_indices = param.downstream_pk_or_empty();
420        Ok(Self {
421            config,
422            schema,
423            pk_indices,
424            is_append_only,
425            param,
426        })
427    }
428}
429
430impl Sink for SnowflakeV2Sink {
431    type LogSinker = CoordinatedLogSinker<SnowflakeSinkWriter>;
432
433    const SINK_NAME: &'static str = SNOWFLAKE_SINK_V2;
434
435    async fn validate(&self) -> Result<()> {
436        risingwave_common::license::Feature::SnowflakeSink
437            .check_available()
438            .map_err(|e| anyhow::anyhow!(e))?;
439        if let Some((snowflake_task_ctx, client)) =
440            self.config.build_snowflake_task_ctx_jdbc_client(
441                self.is_append_only,
442                &self.schema,
443                &self.pk_indices,
444            )?
445        {
446            let client = SnowflakeJniClient::new(client, snowflake_task_ctx);
447            client.execute_create_table().await?;
448        }
449
450        Ok(())
451    }
452
453    fn support_schema_change() -> bool {
454        true
455    }
456
457    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
458        SnowflakeV2Config::from_btreemap(config)?;
459        Ok(())
460    }
461
462    async fn new_log_sinker(
463        &self,
464        writer_param: crate::sink::SinkWriterParam,
465    ) -> Result<Self::LogSinker> {
466        let writer = SnowflakeSinkWriter::new(
467            self.config.clone(),
468            self.is_append_only,
469            writer_param.clone(),
470            self.param.clone(),
471        )
472        .await?;
473
474        let commit_checkpoint_interval =
475            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
476                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
477            );
478
479        CoordinatedLogSinker::new(
480            &writer_param,
481            self.param.clone(),
482            writer,
483            commit_checkpoint_interval,
484        )
485        .await
486    }
487
488    fn is_coordinated_sink(&self) -> bool {
489        true
490    }
491
492    async fn new_coordinator(
493        &self,
494        _iceberg_compact_stat_sender: Option<UnboundedSender<IcebergSinkCompactionUpdate>>,
495    ) -> Result<SinkCommitCoordinator> {
496        let coordinator = SnowflakeSinkCommitter::new(
497            self.config.clone(),
498            &self.schema,
499            &self.pk_indices,
500            self.is_append_only,
501        )?;
502        Ok(SinkCommitCoordinator::SinglePhase(Box::new(coordinator)))
503    }
504}
505
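/// Writer for the Snowflake sink: dispatches to the S3 staging writer or the direct JDBC
/// writer depending on the `with_s3` option.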
506pub enum SnowflakeSinkWriter {
507    S3(SnowflakeRedshiftSinkS3Writer),
508    Jdbc(SnowflakeSinkJdbcWriter),
509}
510
511impl SnowflakeSinkWriter {
512    pub async fn new(
513        config: SnowflakeV2Config,
514        is_append_only: bool,
515        writer_param: SinkWriterParam,
516        param: SinkParam,
517    ) -> Result<Self> {
518        let schema = param.schema();
519        if config.with_s3 {
520            let executor_id = writer_param.executor_id;
521            let s3_writer = SnowflakeRedshiftSinkS3Writer::new(
522                config.s3_inner.ok_or_else(|| {
523                    SinkError::Config(anyhow!(
524                        "S3 configuration is required for Snowflake S3 sink"
525                    ))
526                })?,
527                schema,
528                is_append_only,
529                executor_id,
530                config.snowflake_target_table_name,
531            )?;
532            Ok(Self::S3(s3_writer))
533        } else {
534            let jdbc_writer =
535                SnowflakeSinkJdbcWriter::new(config, is_append_only, writer_param, param).await?;
536            Ok(Self::Jdbc(jdbc_writer))
537        }
538    }
539}
540
541#[async_trait]
542impl SinkWriter for SnowflakeSinkWriter {
543    type CommitMetadata = Option<SinkMetadata>;
544
545    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
546        match self {
547            Self::S3(writer) => writer.begin_epoch(epoch),
548            Self::Jdbc(writer) => writer.begin_epoch(epoch).await,
549        }
550    }
551
552    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
553        match self {
554            Self::S3(writer) => writer.write_batch(chunk).await,
555            Self::Jdbc(writer) => writer.write_batch(chunk).await,
556        }
557    }
558
559    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Option<SinkMetadata>> {
560        match self {
561            Self::S3(writer) => {
562                writer.barrier(is_checkpoint).await?;
563            }
564            Self::Jdbc(writer) => {
565                writer.barrier(is_checkpoint).await?;
566            }
567        }
568        Ok(Some(SinkMetadata {
569            metadata: Some(sink_metadata::Metadata::Serialized(
570                risingwave_pb::connector_service::sink_metadata::SerializedMetadata {
571                    metadata: vec![],
572                },
573            )),
574        }))
575    }
576
577    async fn abort(&mut self) -> Result<()> {
578        if let Self::Jdbc(writer) = self {
579            writer.abort().await
580        } else {
581            Ok(())
582        }
583    }
584}
585
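/// JDBC path of the Snowflake sink writer: rows are augmented with `__row_id`/`__op`
/// (for upsert) and forwarded to the remote JDBC sink writer, which writes them to the
/// target table (append-only) or the intermediate CDC table (upsert).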
586pub struct SnowflakeSinkJdbcWriter {
587    augmented_row: AugmentedChunk,
588    jdbc_sink_writer: CoordinatedRemoteSinkWriter,
589}
590
591impl SnowflakeSinkJdbcWriter {
592    pub async fn new(
593        config: SnowflakeV2Config,
594        is_append_only: bool,
595        writer_param: SinkWriterParam,
596        mut param: SinkParam,
597    ) -> Result<Self> {
598        let metrics = SinkWriterMetrics::new(&writer_param);
599        let properties = &param.properties;
600        let column_descs = &mut param.columns;
601        let full_table_name = if is_append_only {
602            format!(
603                r#""{}"."{}"."{}""#,
604                config.snowflake_database.clone().unwrap_or_default(),
605                config.snowflake_schema.clone().unwrap_or_default(),
606                config
607                    .snowflake_target_table_name
608                    .clone()
609                    .unwrap_or_default()
610            )
611        } else {
612            let max_column_id = column_descs
613                .iter()
614                .map(|column| column.column_id.get_id())
615                .max()
616                .unwrap_or(0);
617            (*column_descs).push(ColumnDesc::named(
618                SNOWFLAKE_SINK_ROW_ID,
619                ColumnId::new(max_column_id + 1),
620                DataType::Varchar,
621            ));
622            (*column_descs).push(ColumnDesc::named(
623                SNOWFLAKE_SINK_OP,
624                ColumnId::new(max_column_id + 2),
625                DataType::Int32,
626            ));
627            format!(
628                r#""{}"."{}"."{}""#,
629                config.snowflake_database.clone().unwrap_or_default(),
630                config.snowflake_schema.clone().unwrap_or_default(),
631                config.snowflake_cdc_table_name.clone().unwrap_or_default()
632            )
633        };
634        let mut new_properties = BTreeMap::from([
635            ("table.name".to_owned(), full_table_name),
636            ("connector".to_owned(), "snowflake_v2".to_owned()),
637            (
638                "jdbc.url".to_owned(),
639                config.jdbc_url.clone().unwrap_or_default(),
640            ),
641            ("type".to_owned(), "append-only".to_owned()),
642            (
643                "primary_key".to_owned(),
644                properties.get("primary_key").cloned().unwrap_or_default(),
645            ),
646            (
647                "schema.name".to_owned(),
648                config.snowflake_schema.clone().unwrap_or_default(),
649            ),
650            (
651                "database.name".to_owned(),
652                config.snowflake_database.clone().unwrap_or_default(),
653            ),
654        ]);
655
656        // Reuse build_jdbc_connection_properties to get driver properties (auth, user, etc.)
657        let (_jdbc_url, connection_properties) = config.build_jdbc_connection_properties()?;
658        for (key, value) in connection_properties {
659            new_properties.insert(key, value);
660        }
661
662        param.properties = new_properties;
663
664        let jdbc_sink_writer =
665            CoordinatedRemoteSinkWriter::new(param.clone(), metrics.clone()).await?;
666        Ok(Self {
667            augmented_row: AugmentedChunk::new(0, is_append_only),
668            jdbc_sink_writer,
669        })
670    }
671}
672
673impl SnowflakeSinkJdbcWriter {
674    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
675        self.augmented_row.reset_epoch(epoch);
676        self.jdbc_sink_writer.begin_epoch(epoch).await?;
677        Ok(())
678    }
679
680    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
681        let chunk = self.augmented_row.augmented_chunk(chunk)?;
682        self.jdbc_sink_writer.write_batch(chunk).await?;
683        Ok(())
684    }
685
686    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
687        self.jdbc_sink_writer.barrier(is_checkpoint).await?;
688        Ok(())
689    }
690
691    async fn abort(&mut self) -> Result<()> {
692        // TODO: abort should clean up all the data written in this epoch.
693        self.jdbc_sink_writer.abort().await?;
694        Ok(())
695    }
696}
697
698#[derive(Default)]
699pub struct SnowflakeTaskContext {
700    // required for task creation
701    pub target_table_name: String,
702    pub database: String,
703    pub schema_name: String,
704    pub schema: Schema,
705
706    // only upsert
707    pub task_name: Option<String>,
708    pub cdc_table_name: Option<String>,
709    pub schedule_seconds: u64,
710    pub warehouse: Option<String>,
711    pub pk_column_names: Option<Vec<String>>,
712    pub all_column_names: Option<Vec<String>>,
713
714    // only s3 writer
715    pub stage: Option<String>,
716    pub pipe_name: Option<String>,
717}
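
/// Single-phase commit coordinator for the Snowflake sink: `init` creates the Snowpipe and
/// the periodic `MERGE INTO` task, `commit` refreshes the pipe and propagates added columns,
/// and dropping the coordinator drops the task.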
718pub struct SnowflakeSinkCommitter {
719    client: Option<SnowflakeJniClient>,
720}
721
722impl SnowflakeSinkCommitter {
723    pub fn new(
724        config: SnowflakeV2Config,
725        schema: &Schema,
726        pk_indices: &Vec<usize>,
727        is_append_only: bool,
728    ) -> Result<Self> {
729        let client = if let Some((snowflake_task_ctx, client)) =
730            config.build_snowflake_task_ctx_jdbc_client(is_append_only, schema, pk_indices)?
731        {
732            Some(SnowflakeJniClient::new(client, snowflake_task_ctx))
733        } else {
734            None
735        };
736        Ok(Self { client })
737    }
738}
739
740#[async_trait]
741impl SinglePhaseCommitCoordinator for SnowflakeSinkCommitter {
742    async fn init(&mut self) -> Result<()> {
743        if let Some(client) = &self.client {
            // TODO: move this to `validate`.
745            client.execute_create_pipe().await?;
746            client.execute_create_merge_into_task().await?;
747        }
748        Ok(())
749    }
750
751    async fn commit(
752        &mut self,
753        _epoch: u64,
754        _metadata: Vec<SinkMetadata>,
755        add_columns: Option<Vec<Field>>,
756    ) -> Result<()> {
757        let client = self.client.as_mut().ok_or_else(|| {
758            SinkError::Config(anyhow!("Snowflake sink committer is not initialized."))
759        })?;
760        client.execute_flush_pipe().await?;
761
762        if let Some(add_columns) = add_columns {
763            client
764                .execute_alter_add_columns(
765                    &add_columns
766                        .iter()
767                        .map(|f| (f.name.clone(), f.data_type.to_string()))
768                        .collect::<Vec<_>>(),
769                )
770                .await?;
771        }
772        Ok(())
773    }
774}
775
776impl Drop for SnowflakeSinkCommitter {
777    fn drop(&mut self) {
778        if let Some(client) = self.client.take() {
779            tokio::spawn(async move {
780                client.execute_drop_task().await.ok();
781            });
782        }
783    }
784}
785
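/// Thin wrapper around [`JdbcJniClient`] that renders and executes the Snowflake DDL/DML
/// (table creation, pipe, and merge task) described by a [`SnowflakeTaskContext`].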
786pub struct SnowflakeJniClient {
787    jdbc_client: JdbcJniClient,
788    snowflake_task_context: SnowflakeTaskContext,
789}
790
791impl SnowflakeJniClient {
792    pub fn new(jdbc_client: JdbcJniClient, snowflake_task_context: SnowflakeTaskContext) -> Self {
793        Self {
794            jdbc_client,
795            snowflake_task_context,
796        }
797    }
798
799    pub async fn execute_alter_add_columns(
800        &mut self,
801        columns: &Vec<(String, String)>,
802    ) -> Result<()> {
803        self.execute_drop_task().await?;
804        if let Some(names) = self.snowflake_task_context.all_column_names.as_mut() {
805            names.extend(columns.iter().map(|(name, _)| name.clone()));
806        }
807        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
808            let alter_add_column_cdc_table_sql = build_alter_add_column_sql(
809                cdc_table_name,
810                &self.snowflake_task_context.database,
811                &self.snowflake_task_context.schema_name,
812                columns,
813            );
814            self.jdbc_client
815                .execute_sql_sync(vec![alter_add_column_cdc_table_sql])
816                .await?;
817        }
818
819        let alter_add_column_target_table_sql = build_alter_add_column_sql(
820            &self.snowflake_task_context.target_table_name,
821            &self.snowflake_task_context.database,
822            &self.snowflake_task_context.schema_name,
823            columns,
824        );
825        self.jdbc_client
826            .execute_sql_sync(vec![alter_add_column_target_table_sql])
827            .await?;
828
829        self.execute_create_merge_into_task().await?;
830        Ok(())
831    }
832
833    pub async fn execute_create_merge_into_task(&self) -> Result<()> {
834        if self.snowflake_task_context.task_name.is_some() {
835            let create_task_sql = build_create_merge_into_task_sql(&self.snowflake_task_context);
836            let start_task_sql = build_start_task_sql(&self.snowflake_task_context);
837            self.jdbc_client
838                .execute_sql_sync(vec![create_task_sql])
839                .await?;
840            self.jdbc_client
841                .execute_sql_sync(vec![start_task_sql])
842                .await?;
843        }
844        Ok(())
845    }
846
847    pub async fn execute_drop_task(&self) -> Result<()> {
848        if self.snowflake_task_context.task_name.is_some() {
849            let sql = build_drop_task_sql(&self.snowflake_task_context);
850            if let Err(e) = self.jdbc_client.execute_sql_sync(vec![sql]).await {
851                tracing::error!(
852                    "Failed to drop Snowflake sink task {:?}: {:?}",
853                    self.snowflake_task_context.task_name,
854                    e.as_report()
855                );
856            } else {
857                tracing::info!(
858                    "Snowflake sink task {:?} dropped",
859                    self.snowflake_task_context.task_name
860                );
861            }
862        }
863        Ok(())
864    }
865
866    pub async fn execute_create_table(&self) -> Result<()> {
867        // create target table
868        let create_target_table_sql = build_create_table_sql(
869            &self.snowflake_task_context.target_table_name,
870            &self.snowflake_task_context.database,
871            &self.snowflake_task_context.schema_name,
872            &self.snowflake_task_context.schema,
873            false,
874        )?;
875        self.jdbc_client
876            .execute_sql_sync(vec![create_target_table_sql])
877            .await?;
878        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
879            let create_cdc_table_sql = build_create_table_sql(
880                cdc_table_name,
881                &self.snowflake_task_context.database,
882                &self.snowflake_task_context.schema_name,
883                &self.snowflake_task_context.schema,
884                true,
885            )?;
886            self.jdbc_client
887                .execute_sql_sync(vec![create_cdc_table_sql])
888                .await?;
889        }
890        Ok(())
891    }
892
893    pub async fn execute_create_pipe(&self) -> Result<()> {
894        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
895            let table_name =
896                if let Some(table_name) = self.snowflake_task_context.cdc_table_name.as_ref() {
897                    table_name
898                } else {
899                    &self.snowflake_task_context.target_table_name
900                };
901            let create_pipe_sql = build_create_pipe_sql(
902                table_name,
903                &self.snowflake_task_context.database,
904                &self.snowflake_task_context.schema_name,
905                self.snowflake_task_context.stage.as_ref().ok_or_else(|| {
                    SinkError::Config(anyhow!("stage is required for the S3 writer"))
907                })?,
908                pipe_name,
909                &self.snowflake_task_context.target_table_name,
910            );
911            self.jdbc_client
912                .execute_sql_sync(vec![create_pipe_sql])
913                .await?;
914        }
915        Ok(())
916    }
917
918    pub async fn execute_flush_pipe(&self) -> Result<()> {
919        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
920            let flush_pipe_sql = build_flush_pipe_sql(
921                &self.snowflake_task_context.database,
922                &self.snowflake_task_context.schema_name,
923                pipe_name,
924            );
925            self.jdbc_client
926                .execute_sql_sync(vec![flush_pipe_sql])
927                .await?;
928        }
929        Ok(())
930    }
931}
932
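/// Illustrative output (a sketch; identifiers and column types are placeholders):
///
/// ```text
/// CREATE TABLE IF NOT EXISTS "db"."sch"."t" ("id" INTEGER, "v" STRING, "__row_id" STRING, "__op" INT) ENABLE_SCHEMA_EVOLUTION = true
/// ```
///
/// The trailing `__row_id`/`__op` columns are appended only when `need_op_and_row_id` is set.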
933fn build_create_table_sql(
934    table_name: &str,
935    database: &str,
936    schema_name: &str,
937    schema: &Schema,
938    need_op_and_row_id: bool,
939) -> Result<String> {
940    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema_name, table_name);
941    let mut columns: Vec<String> = schema
942        .fields
943        .iter()
944        .map(|field| {
945            let data_type = convert_snowflake_data_type(&field.data_type)?;
946            Ok(format!(r#""{}" {}"#, field.name, data_type))
947        })
948        .collect::<Result<Vec<String>>>()?;
949    if need_op_and_row_id {
950        columns.push(format!(r#""{}" STRING"#, SNOWFLAKE_SINK_ROW_ID));
951        columns.push(format!(r#""{}" INT"#, SNOWFLAKE_SINK_OP));
952    }
953    let columns_str = columns.join(", ");
954    Ok(format!(
        "CREATE TABLE IF NOT EXISTS {} ({}) ENABLE_SCHEMA_EVOLUTION = true",
956        full_table_name, columns_str
957    ))
958}
959
960fn convert_snowflake_data_type(data_type: &DataType) -> Result<String> {
961    let data_type = match data_type {
962        DataType::Int16 => "SMALLINT".to_owned(),
963        DataType::Int32 => "INTEGER".to_owned(),
964        DataType::Int64 => "BIGINT".to_owned(),
965        DataType::Float32 => "FLOAT4".to_owned(),
966        DataType::Float64 => "FLOAT8".to_owned(),
967        DataType::Boolean => "BOOLEAN".to_owned(),
968        DataType::Varchar => "STRING".to_owned(),
969        DataType::Date => "DATE".to_owned(),
970        DataType::Timestamp => "TIMESTAMP".to_owned(),
971        DataType::Timestamptz => "TIMESTAMP_TZ".to_owned(),
972        DataType::Jsonb => "STRING".to_owned(),
973        DataType::Decimal => "DECIMAL".to_owned(),
974        DataType::Bytea => "BINARY".to_owned(),
975        DataType::Time => "TIME".to_owned(),
976        _ => {
977            return Err(SinkError::Config(anyhow!(
                "auto table creation is not supported for data type: {}",
979                data_type
980            )));
981        }
982    };
983    Ok(data_type)
984}
985
986fn build_create_pipe_sql(
987    table_name: &str,
988    database: &str,
989    schema: &str,
990    stage: &str,
991    pipe_name: &str,
992    target_table_name: &str,
993) -> String {
994    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
995    let stage = format!(
996        r#""{}"."{}"."{}"/{}"#,
997        database, schema, stage, target_table_name
998    );
999    let table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
1000    format!(
1001        "CREATE OR REPLACE PIPE {} AUTO_INGEST = FALSE AS COPY INTO {} FROM @{} MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE FILE_FORMAT = (type = 'JSON');",
1002        pipe_name, table_name, stage
1003    )
1004}
1005
1006fn build_flush_pipe_sql(database: &str, schema: &str, pipe_name: &str) -> String {
1007    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    format!("ALTER PIPE {} REFRESH;", pipe_name)
1009}
1010
1011fn build_alter_add_column_sql(
1012    table_name: &str,
1013    database: &str,
1014    schema: &str,
1015    columns: &Vec<(String, String)>,
1016) -> String {
1017    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
1018    jdbc_jni_client::build_alter_add_column_sql(&full_table_name, columns, true)
1019}
1020
1021fn build_start_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
1022    let SnowflakeTaskContext {
1023        task_name,
1024        database,
1025        schema_name: schema,
1026        ..
1027    } = snowflake_task_context;
1028    let full_task_name = format!(
1029        r#""{}"."{}"."{}""#,
1030        database,
1031        schema,
1032        task_name.as_ref().unwrap()
1033    );
1034    format!("ALTER TASK {} RESUME", full_task_name)
1035}
1036
1037fn build_drop_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
1038    let SnowflakeTaskContext {
1039        task_name,
1040        database,
1041        schema_name: schema,
1042        ..
1043    } = snowflake_task_context;
1044    let full_task_name = format!(
1045        r#""{}"."{}"."{}""#,
1046        database,
1047        schema,
1048        task_name.as_ref().unwrap()
1049    );
1050    format!("DROP TASK IF EXISTS {}", full_task_name)
1051}
1052
1053fn build_create_merge_into_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
1054    let SnowflakeTaskContext {
1055        task_name,
1056        cdc_table_name,
1057        target_table_name,
1058        schedule_seconds,
1059        warehouse,
1060        pk_column_names,
1061        all_column_names,
1062        database,
1063        schema_name,
1064        ..
1065    } = snowflake_task_context;
1066    let full_task_name = format!(
1067        r#""{}"."{}"."{}""#,
1068        database,
1069        schema_name,
1070        task_name.as_ref().unwrap()
1071    );
1072    let full_cdc_table_name = format!(
1073        r#""{}"."{}"."{}""#,
1074        database,
1075        schema_name,
1076        cdc_table_name.as_ref().unwrap()
1077    );
1078    let full_target_table_name = format!(
1079        r#""{}"."{}"."{}""#,
1080        database, schema_name, target_table_name
1081    );
1082
1083    let pk_names_str = pk_column_names
1084        .as_ref()
1085        .unwrap()
1086        .iter()
1087        .map(|name| format!(r#""{}""#, name))
1088        .collect::<Vec<String>>()
1089        .join(", ");
1090    let pk_names_eq_str = pk_column_names
1091        .as_ref()
1092        .unwrap()
1093        .iter()
1094        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
1095        .collect::<Vec<String>>()
1096        .join(" AND ");
1097    let all_column_names_set_str = all_column_names
1098        .as_ref()
1099        .unwrap()
1100        .iter()
1101        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
1102        .collect::<Vec<String>>()
1103        .join(", ");
1104    let all_column_names_str = all_column_names
1105        .as_ref()
1106        .unwrap()
1107        .iter()
1108        .map(|name| format!(r#""{}""#, name))
1109        .collect::<Vec<String>>()
1110        .join(", ");
1111    let all_column_names_insert_str = all_column_names
1112        .as_ref()
1113        .unwrap()
1114        .iter()
1115        .map(|name| format!(r#"source."{}""#, name))
1116        .collect::<Vec<String>>()
1117        .join(", ");
1118
1119    format!(
1120        r#"CREATE OR REPLACE TASK {task_name}
1121WAREHOUSE = {warehouse}
1122SCHEDULE = '{schedule_seconds} SECONDS'
1123AS
1124BEGIN
1125    LET max_row_id STRING;
1126
1127    SELECT COALESCE(MAX("{snowflake_sink_row_id}"), '0') INTO :max_row_id
1128    FROM {cdc_table_name};
1129
1130    MERGE INTO {target_table_name} AS target
1131    USING (
1132        SELECT *
1133        FROM (
1134            SELECT *, ROW_NUMBER() OVER (PARTITION BY {pk_names_str} ORDER BY "{snowflake_sink_row_id}" DESC) AS dedupe_id
1135            FROM {cdc_table_name}
1136            WHERE "{snowflake_sink_row_id}" <= :max_row_id
1137        ) AS subquery
1138        WHERE dedupe_id = 1
1139    ) AS source
1140    ON {pk_names_eq_str}
1141    WHEN MATCHED AND source."{snowflake_sink_op}" IN (2, 4) THEN DELETE
1142    WHEN MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN UPDATE SET {all_column_names_set_str}
1143    WHEN NOT MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN INSERT ({all_column_names_str}) VALUES ({all_column_names_insert_str});
1144
1145    DELETE FROM {cdc_table_name}
1146    WHERE "{snowflake_sink_row_id}" <= :max_row_id;
1147END;"#,
1148        task_name = full_task_name,
1149        warehouse = warehouse.as_ref().unwrap(),
1150        schedule_seconds = schedule_seconds,
1151        cdc_table_name = full_cdc_table_name,
1152        target_table_name = full_target_table_name,
1153        pk_names_str = pk_names_str,
1154        pk_names_eq_str = pk_names_eq_str,
1155        all_column_names_set_str = all_column_names_set_str,
1156        all_column_names_str = all_column_names_str,
1157        all_column_names_insert_str = all_column_names_insert_str,
1158        snowflake_sink_row_id = SNOWFLAKE_SINK_ROW_ID,
1159        snowflake_sink_op = SNOWFLAKE_SINK_OP,
1160    )
1161}
1162
1163#[cfg(test)]
1164mod tests {
1165    use std::collections::BTreeMap;
1166
1167    use super::*;
1168    use crate::sink::jdbc_jni_client::normalize_sql;
1169
1170    fn base_properties() -> BTreeMap<String, String> {
1171        BTreeMap::from([
1172            ("type".to_owned(), "append-only".to_owned()),
1173            ("jdbc.url".to_owned(), "jdbc:snowflake://account".to_owned()),
1174            ("username".to_owned(), "RW_USER".to_owned()),
1175        ])
1176    }
1177
1178    #[test]
1179    fn test_build_jdbc_props_password() {
1180        let mut props = base_properties();
1181        props.insert("password".to_owned(), "secret".to_owned());
1182        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
1183        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
1184        assert_eq!(url, "jdbc:snowflake://account");
1185        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
1186        assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
1187        assert_eq!(map.get("password"), Some(&"secret".to_owned()));
1188        assert!(!map.contains_key("authenticator"));
1189    }
1190
1191    #[test]
1192    fn test_build_jdbc_props_key_pair_file() {
1193        let mut props = base_properties();
1194        props.insert(
1195            "auth.method".to_owned(),
1196            AUTH_METHOD_KEY_PAIR_FILE.to_owned(),
1197        );
1198        props.insert("private_key_file".to_owned(), "/tmp/rsa_key.p8".to_owned());
1199        props.insert("private_key_file_pwd".to_owned(), "dummy".to_owned());
1200        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
1201        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
1202        assert_eq!(url, "jdbc:snowflake://account");
1203        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
1204        assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
1205        assert_eq!(
1206            map.get("private_key_file"),
1207            Some(&"/tmp/rsa_key.p8".to_owned())
1208        );
1209        assert_eq!(map.get("private_key_file_pwd"), Some(&"dummy".to_owned()));
1210    }
1211
1212    #[test]
1213    fn test_build_jdbc_props_key_pair_object() {
1214        let mut props = base_properties();
1215        props.insert(
1216            "auth.method".to_owned(),
1217            AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
1218        );
1219        props.insert(
1220            "private_key_pem".to_owned(),
1221            "-----BEGIN PRIVATE KEY-----
1222...
1223-----END PRIVATE KEY-----"
1224                .to_owned(),
1225        );
1226        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
1227        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
1228        assert_eq!(url, "jdbc:snowflake://account");
1229        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
1230        assert_eq!(
1231            map.get("private_key_pem"),
1232            Some(
1233                &"-----BEGIN PRIVATE KEY-----
1234...
1235-----END PRIVATE KEY-----"
1236                    .to_owned()
1237            )
1238        );
1239        assert!(!map.contains_key("private_key_file"));
1240    }
1241
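    // A small sketch test for auth-method inference: when `auth.method` is omitted,
    // `from_btreemap` derives it from whichever credential field is supplied.
    #[test]
    fn test_auth_method_inferred_from_private_key_file() {
        let mut props = base_properties();
        props.insert("private_key_file".to_owned(), "/tmp/rsa_key.p8".to_owned());
        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
        assert_eq!(
            config.auth_method.as_deref(),
            Some(AUTH_METHOD_KEY_PAIR_FILE)
        );
    }
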
1242    #[test]
1243    fn test_snowflake_sink_commit_coordinator() {
1244        let snowflake_task_context = SnowflakeTaskContext {
1245            task_name: Some("test_task".to_owned()),
1246            cdc_table_name: Some("test_cdc_table".to_owned()),
1247            target_table_name: "test_target_table".to_owned(),
1248            schedule_seconds: 3600,
1249            warehouse: Some("test_warehouse".to_owned()),
1250            pk_column_names: Some(vec!["v1".to_owned()]),
1251            all_column_names: Some(vec!["v1".to_owned(), "v2".to_owned()]),
1252            database: "test_db".to_owned(),
1253            schema_name: "test_schema".to_owned(),
1254            schema: Schema { fields: vec![] },
1255            stage: None,
1256            pipe_name: None,
1257        };
1258        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
1259        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task"
1260WAREHOUSE = test_warehouse
1261SCHEDULE = '3600 SECONDS'
1262AS
1263BEGIN
1264    LET max_row_id STRING;
1265
1266    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
1267    FROM "test_db"."test_schema"."test_cdc_table";
1268
1269    MERGE INTO "test_db"."test_schema"."test_target_table" AS target
1270    USING (
1271        SELECT *
1272        FROM (
1273            SELECT *, ROW_NUMBER() OVER (PARTITION BY "v1" ORDER BY "__row_id" DESC) AS dedupe_id
1274            FROM "test_db"."test_schema"."test_cdc_table"
1275            WHERE "__row_id" <= :max_row_id
1276        ) AS subquery
1277        WHERE dedupe_id = 1
1278    ) AS source
1279    ON target."v1" = source."v1"
1280    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
1281    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."v1" = source."v1", target."v2" = source."v2"
1282    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("v1", "v2") VALUES (source."v1", source."v2");
1283
1284    DELETE FROM "test_db"."test_schema"."test_cdc_table"
1285    WHERE "__row_id" <= :max_row_id;
1286END;"#;
1287        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
1288    }
1289
1290    #[test]
1291    fn test_snowflake_sink_commit_coordinator_multi_pk() {
1292        let snowflake_task_context = SnowflakeTaskContext {
1293            task_name: Some("test_task_multi_pk".to_owned()),
1294            cdc_table_name: Some("cdc_multi_pk".to_owned()),
1295            target_table_name: "target_multi_pk".to_owned(),
1296            schedule_seconds: 300,
1297            warehouse: Some("multi_pk_warehouse".to_owned()),
1298            pk_column_names: Some(vec!["id1".to_owned(), "id2".to_owned()]),
1299            all_column_names: Some(vec!["id1".to_owned(), "id2".to_owned(), "val".to_owned()]),
1300            database: "test_db".to_owned(),
1301            schema_name: "test_schema".to_owned(),
1302            schema: Schema { fields: vec![] },
1303            stage: None,
1304            pipe_name: None,
1305        };
1306        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
1307        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task_multi_pk"
1308WAREHOUSE = multi_pk_warehouse
1309SCHEDULE = '300 SECONDS'
1310AS
1311BEGIN
1312    LET max_row_id STRING;
1313
1314    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
1315    FROM "test_db"."test_schema"."cdc_multi_pk";
1316
1317    MERGE INTO "test_db"."test_schema"."target_multi_pk" AS target
1318    USING (
1319        SELECT *
1320        FROM (
1321            SELECT *, ROW_NUMBER() OVER (PARTITION BY "id1", "id2" ORDER BY "__row_id" DESC) AS dedupe_id
1322            FROM "test_db"."test_schema"."cdc_multi_pk"
1323            WHERE "__row_id" <= :max_row_id
1324        ) AS subquery
1325        WHERE dedupe_id = 1
1326    ) AS source
1327    ON target."id1" = source."id1" AND target."id2" = source."id2"
1328    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
1329    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."id1" = source."id1", target."id2" = source."id2", target."val" = source."val"
1330    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("id1", "id2", "val") VALUES (source."id1", source."id2", source."val");
1331
1332    DELETE FROM "test_db"."test_schema"."cdc_multi_pk"
1333    WHERE "__row_id" <= :max_row_id;
1334END;"#;
1335        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
1336    }
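
    // Sketch tests for the remaining SQL helpers; identifiers are placeholders and the
    // expected strings are derived from the format strings above.
    #[test]
    fn test_build_create_pipe_and_flush_sql() {
        let sql = build_create_pipe_sql(
            "test_cdc_table",
            "test_db",
            "test_schema",
            "test_stage",
            "test_pipe",
            "test_target_table",
        );
        assert_eq!(
            sql,
            r#"CREATE OR REPLACE PIPE "test_db"."test_schema"."test_pipe" AUTO_INGEST = FALSE AS COPY INTO "test_db"."test_schema"."test_cdc_table" FROM @"test_db"."test_schema"."test_stage"/test_target_table MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE FILE_FORMAT = (type = 'JSON');"#
        );
        assert_eq!(
            build_flush_pipe_sql("test_db", "test_schema", "test_pipe"),
            r#"ALTER PIPE "test_db"."test_schema"."test_pipe" REFRESH;"#
        );
    }

    #[test]
    fn test_build_task_management_sql() {
        let ctx = SnowflakeTaskContext {
            task_name: Some("test_task".to_owned()),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            ..Default::default()
        };
        assert_eq!(
            build_start_task_sql(&ctx),
            r#"ALTER TASK "test_db"."test_schema"."test_task" RESUME"#
        );
        assert_eq!(
            build_drop_task_sql(&ctx),
            r#"DROP TASK IF EXISTS "test_db"."test_schema"."test_task""#
        );
    }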
1337}