risingwave_connector/sink/snowflake_redshift/snowflake.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use core::num::NonZeroU64;
use std::collections::BTreeMap;

use anyhow::anyhow;
use phf::{Set, phf_set};
use risingwave_common::array::StreamChunk;
use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema};
use risingwave_common::types::DataType;
use risingwave_pb::connector_service::{SinkMetadata, sink_metadata};
use sea_orm::DatabaseConnection;
use serde::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use thiserror_ext::AsReport;
use tokio::sync::mpsc::UnboundedSender;
use tonic::async_trait;
use with_options::WithOptions;

use crate::connector_common::IcebergSinkCompactionUpdate;
use crate::enforce_secret::EnforceSecret;
use crate::sink::coordinate::CoordinatedLogSinker;
use crate::sink::decouple_checkpoint_log_sink::default_commit_checkpoint_interval;
use crate::sink::file_sink::s3::S3Common;
use crate::sink::jdbc_jni_client::{self, JdbcJniClient};
use crate::sink::remote::CoordinatedRemoteSinkWriter;
use crate::sink::snowflake_redshift::{AugmentedChunk, SnowflakeRedshiftSinkS3Writer};
use crate::sink::writer::SinkWriter;
use crate::sink::{
    Result, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, Sink, SinkCommitCoordinator,
    SinkCommittedEpochSubscriber, SinkError, SinkParam, SinkWriterMetrics, SinkWriterParam,
};

pub const SNOWFLAKE_SINK_V2: &str = "snowflake_v2";
pub const SNOWFLAKE_SINK_ROW_ID: &str = "__row_id";
pub const SNOWFLAKE_SINK_OP: &str = "__op";
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct SnowflakeV2Config {
    #[serde(rename = "type")]
    pub r#type: String,

    #[serde(rename = "intermediate.table.name")]
    pub snowflake_cdc_table_name: Option<String>,

    #[serde(rename = "table.name")]
    pub snowflake_target_table_name: Option<String>,

    #[serde(rename = "database")]
    pub snowflake_database: Option<String>,

    #[serde(rename = "schema")]
    pub snowflake_schema: Option<String>,

    #[serde(default = "default_schedule")]
    #[serde(rename = "write.target.interval.seconds")]
    #[serde_as(as = "DisplayFromStr")]
    pub snowflake_schedule_seconds: u64,

    #[serde(rename = "warehouse")]
    pub snowflake_warehouse: Option<String>,

    #[serde(rename = "jdbc.url")]
    pub jdbc_url: Option<String>,

    #[serde(rename = "username")]
    pub username: Option<String>,

    #[serde(rename = "password")]
    pub password: Option<String>,

    /// Commit every n (>0) checkpoints, default is 10.
    #[serde(default = "default_commit_checkpoint_interval")]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub commit_checkpoint_interval: u64,

    /// Enable auto schema change for upsert sink.
    /// If enabled, the sink will automatically alter the target table to add new columns.
    #[serde(default)]
    #[serde(rename = "auto.schema.change")]
    #[serde_as(as = "DisplayFromStr")]
    pub auto_schema_change: bool,

    #[serde(default)]
    #[serde(rename = "create_table_if_not_exists")]
    #[serde_as(as = "DisplayFromStr")]
    pub create_table_if_not_exists: bool,

    #[serde(default = "default_with_s3")]
    #[serde(rename = "with_s3")]
    #[serde_as(as = "DisplayFromStr")]
    pub with_s3: bool,

    #[serde(flatten)]
    pub s3_inner: Option<S3Common>,

    #[serde(rename = "stage")]
    pub stage: Option<String>,
}

fn default_schedule() -> u64 {
    3600 // Default to 1 hour
}

fn default_with_s3() -> bool {
    true
}

impl SnowflakeV2Config {
    pub fn from_btreemap(properties: &BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<SnowflakeV2Config>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }

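    /// Build the [`SnowflakeTaskContext`] plus a [`JdbcJniClient`] used for DDL
    /// (table, pipe, and task management) and for schema changes.
    ///
    /// Returns `Ok(None)` for append-only sinks that neither auto-create the target
    /// table nor use auto schema change, since no JDBC connection is needed then.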
    pub fn build_snowflake_task_ctx_jdbc_client(
        &self,
        is_append_only: bool,
        schema: &Schema,
        pk_indices: &Vec<usize>,
    ) -> Result<Option<(SnowflakeTaskContext, JdbcJniClient)>> {
        if !self.auto_schema_change && is_append_only && !self.create_table_if_not_exists {
            // Append-only sinks without auto schema change or table creation do not need a client.
            return Ok(None);
        }
        let target_table_name = self
            .snowflake_target_table_name
            .clone()
            .ok_or(SinkError::Config(anyhow!("table.name is required")))?;
        let database = self
            .snowflake_database
            .clone()
            .ok_or(SinkError::Config(anyhow!("database is required")))?;
        let schema_name = self
            .snowflake_schema
            .clone()
            .ok_or(SinkError::Config(anyhow!("schema is required")))?;
        let mut snowflake_task_ctx = SnowflakeTaskContext {
            target_table_name: target_table_name.clone(),
            database,
            schema_name,
            schema: schema.clone(),
            ..Default::default()
        };

        let jdbc_url = self
            .jdbc_url
            .clone()
            .ok_or(SinkError::Config(anyhow!("jdbc.url is required")))?;
        let username = self
            .username
            .clone()
            .ok_or(SinkError::Config(anyhow!("username is required")))?;
        let password = self
            .password
            .clone()
            .ok_or(SinkError::Config(anyhow!("password is required")))?;
        let jdbc_url = format!("{}?user={}&password={}", jdbc_url, username, password);
        let client = JdbcJniClient::new(jdbc_url)?;

        if self.with_s3 {
            let stage = self
                .stage
                .clone()
                .ok_or(SinkError::Config(anyhow!("stage is required")))?;
            snowflake_task_ctx.stage = Some(stage);
            snowflake_task_ctx.pipe_name = Some(format!("{}_pipe", target_table_name));
        }
        if !is_append_only {
            let cdc_table_name = self
                .snowflake_cdc_table_name
                .clone()
                .ok_or(SinkError::Config(anyhow!(
                    "intermediate.table.name is required"
                )))?;
            snowflake_task_ctx.cdc_table_name = Some(cdc_table_name.clone());
            snowflake_task_ctx.schedule_seconds = self.snowflake_schedule_seconds;
            snowflake_task_ctx.warehouse = Some(
                self.snowflake_warehouse
                    .clone()
                    .ok_or(SinkError::Config(anyhow!("warehouse is required")))?,
            );
            let pk_column_names: Vec<_> = schema
                .fields
                .iter()
                .enumerate()
                .filter(|(index, _)| pk_indices.contains(index))
                .map(|(_, field)| field.name.clone())
                .collect();
            if pk_column_names.is_empty() {
                return Err(SinkError::Config(anyhow!(
                    "Primary key columns not found. Please set the `primary_key` column in the sink properties, or ensure that the sink contains the primary key columns from the upstream."
                )));
            }
            snowflake_task_ctx.pk_column_names = Some(pk_column_names);
            snowflake_task_ctx.all_column_names = Some(
                schema
                    .fields
                    .iter()
                    .map(|field| field.name.clone())
                    .collect(),
            );
            snowflake_task_ctx.task_name = Some(format!(
                "rw_snowflake_sink_from_{cdc_table_name}_to_{target_table_name}"
            ));
        }
        Ok(Some((snowflake_task_ctx, client)))
    }
}

impl EnforceSecret for SnowflakeV2Config {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "username",
        "password",
        "jdbc.url",
    };
}

#[derive(Clone, Debug)]
pub struct SnowflakeV2Sink {
    config: SnowflakeV2Config,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
    param: SinkParam,
}

impl EnforceSecret for SnowflakeV2Sink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::sink::ConnectorResult<()> {
        for prop in prop_iter {
            SnowflakeV2Config::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl TryFrom<SinkParam> for SnowflakeV2Sink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = SnowflakeV2Config::from_btreemap(&param.properties)?;
        let is_append_only = param.sink_type.is_append_only();
        let pk_indices = param.downstream_pk.clone();
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
            param,
        })
    }
}

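// Sink wiring: `validate` (optionally) creates the target/intermediate tables up front,
// each parallel writer stages rows to S3 or writes them over JDBC, and the single commit
// coordinator creates the Snowpipe and the scheduled MERGE INTO task, refreshing the pipe
// on every committed checkpoint.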
impl Sink for SnowflakeV2Sink {
    type Coordinator = SnowflakeSinkCommitter;
    type LogSinker = CoordinatedLogSinker<SnowflakeSinkWriter>;

    const SINK_NAME: &'static str = SNOWFLAKE_SINK_V2;

    async fn validate(&self) -> Result<()> {
        risingwave_common::license::Feature::SnowflakeSink
            .check_available()
            .map_err(|e| anyhow::anyhow!(e))?;
        if let Some((snowflake_task_ctx, client)) =
            self.config.build_snowflake_task_ctx_jdbc_client(
                self.is_append_only,
                &self.schema,
                &self.pk_indices,
            )?
        {
            let client = SnowflakeJniClient::new(client, snowflake_task_ctx);
            client.execute_create_table().await?;
        }

        Ok(())
    }

    fn support_schema_change() -> bool {
        true
    }

    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
        SnowflakeV2Config::from_btreemap(config)?;
        Ok(())
    }

    async fn new_log_sinker(
        &self,
        writer_param: crate::sink::SinkWriterParam,
    ) -> Result<Self::LogSinker> {
        let writer = SnowflakeSinkWriter::new(
            self.config.clone(),
            self.is_append_only,
            writer_param.clone(),
            self.param.clone(),
        )
        .await?;

        let commit_checkpoint_interval =
            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
            );

        CoordinatedLogSinker::new(
            &writer_param,
            self.param.clone(),
            writer,
            commit_checkpoint_interval,
        )
        .await
    }

    fn is_coordinated_sink(&self) -> bool {
        true
    }

    async fn new_coordinator(
        &self,
        _db: DatabaseConnection,
        _iceberg_compact_stat_sender: Option<UnboundedSender<IcebergSinkCompactionUpdate>>,
    ) -> Result<Self::Coordinator> {
        let coordinator = SnowflakeSinkCommitter::new(
            self.config.clone(),
            &self.schema,
            &self.pk_indices,
            self.is_append_only,
        )?;
        Ok(coordinator)
    }
}

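/// Per-parallelism writer. With `with_s3 = true` (the default), rows are staged as files
/// on the configured S3 stage and later loaded through the Snowpipe created by the
/// coordinator; otherwise rows are written directly over JDBC.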
pub enum SnowflakeSinkWriter {
    S3(SnowflakeRedshiftSinkS3Writer),
    Jdbc(SnowflakeSinkJdbcWriter),
}

impl SnowflakeSinkWriter {
    pub async fn new(
        config: SnowflakeV2Config,
        is_append_only: bool,
        writer_param: SinkWriterParam,
        param: SinkParam,
    ) -> Result<Self> {
        let schema = param.schema();
        if config.with_s3 {
            let executor_id = writer_param.executor_id;
            let s3_writer = SnowflakeRedshiftSinkS3Writer::new(
                config.s3_inner.ok_or_else(|| {
                    SinkError::Config(anyhow!(
                        "S3 configuration is required for Snowflake S3 sink"
                    ))
                })?,
                schema,
                is_append_only,
                executor_id,
                config.snowflake_target_table_name,
            )?;
            Ok(Self::S3(s3_writer))
        } else {
            let jdbc_writer =
                SnowflakeSinkJdbcWriter::new(config, is_append_only, writer_param, param).await?;
            Ok(Self::Jdbc(jdbc_writer))
        }
    }
}

#[async_trait]
impl SinkWriter for SnowflakeSinkWriter {
    type CommitMetadata = Option<SinkMetadata>;

    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
        match self {
            Self::S3(writer) => writer.begin_epoch(epoch),
            Self::Jdbc(writer) => writer.begin_epoch(epoch).await,
        }
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        match self {
            Self::S3(writer) => writer.write_batch(chunk).await,
            Self::Jdbc(writer) => writer.write_batch(chunk).await,
        }
    }

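    // Both variants flush on checkpoint barriers. The returned metadata is an empty
    // placeholder: the actual commit work (pipe refresh, schema change) happens in the
    // coordinator.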
    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Option<SinkMetadata>> {
        match self {
            Self::S3(writer) => {
                writer.barrier(is_checkpoint).await?;
            }
            Self::Jdbc(writer) => {
                writer.barrier(is_checkpoint).await?;
            }
        }
        Ok(Some(SinkMetadata {
            metadata: Some(sink_metadata::Metadata::Serialized(
                risingwave_pb::connector_service::sink_metadata::SerializedMetadata {
                    metadata: vec![],
                },
            )),
        }))
    }

    async fn abort(&mut self) -> Result<()> {
        if let Self::Jdbc(writer) = self {
            writer.abort().await
        } else {
            Ok(())
        }
    }
}

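/// JDBC-based writer. For upsert sinks it appends the `__row_id` and `__op` bookkeeping
/// columns and streams the augmented chunks into the intermediate (CDC) table; for
/// append-only sinks it writes straight into the target table. The actual network I/O is
/// delegated to [`CoordinatedRemoteSinkWriter`].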
pub struct SnowflakeSinkJdbcWriter {
    augmented_row: AugmentedChunk,
    jdbc_sink_writer: CoordinatedRemoteSinkWriter,
}

impl SnowflakeSinkJdbcWriter {
    pub async fn new(
        config: SnowflakeV2Config,
        is_append_only: bool,
        writer_param: SinkWriterParam,
        mut param: SinkParam,
    ) -> Result<Self> {
        let metrics = SinkWriterMetrics::new(&writer_param);
        let properties = &param.properties;
        let column_descs = &mut param.columns;
        let full_table_name = if is_append_only {
            format!(
                r#""{}"."{}"."{}""#,
                config.snowflake_database.clone().unwrap_or_default(),
                config.snowflake_schema.clone().unwrap_or_default(),
                config
                    .snowflake_target_table_name
                    .clone()
                    .unwrap_or_default()
            )
        } else {
            let max_column_id = column_descs
                .iter()
                .map(|column| column.column_id.get_id())
                .max()
                .unwrap_or(0);
            column_descs.push(ColumnDesc::named(
                SNOWFLAKE_SINK_ROW_ID,
                ColumnId::new(max_column_id + 1),
                DataType::Varchar,
            ));
            column_descs.push(ColumnDesc::named(
                SNOWFLAKE_SINK_OP,
                ColumnId::new(max_column_id + 2),
                DataType::Int32,
            ));
            format!(
                r#""{}"."{}"."{}""#,
                config.snowflake_database.clone().unwrap_or_default(),
                config.snowflake_schema.clone().unwrap_or_default(),
                config.snowflake_cdc_table_name.clone().unwrap_or_default()
            )
        };
        let new_properties = BTreeMap::from([
            ("table.name".to_owned(), full_table_name),
            ("connector".to_owned(), "snowflake_v2".to_owned()),
            (
                "jdbc.url".to_owned(),
                config.jdbc_url.clone().unwrap_or_default(),
            ),
            ("type".to_owned(), "append-only".to_owned()),
            (
                "user".to_owned(),
                config.username.clone().unwrap_or_default(),
            ),
            (
                "password".to_owned(),
                config.password.clone().unwrap_or_default(),
            ),
            (
                "primary_key".to_owned(),
                properties.get("primary_key").cloned().unwrap_or_default(),
            ),
            (
                "schema.name".to_owned(),
                config.snowflake_schema.clone().unwrap_or_default(),
            ),
            (
                "database.name".to_owned(),
                config.snowflake_database.clone().unwrap_or_default(),
            ),
        ]);
        param.properties = new_properties;

        let jdbc_sink_writer =
            CoordinatedRemoteSinkWriter::new(param.clone(), metrics.clone()).await?;
        Ok(Self {
            augmented_row: AugmentedChunk::new(0, is_append_only),
            jdbc_sink_writer,
        })
    }
}

impl SnowflakeSinkJdbcWriter {
    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
        self.augmented_row.reset_epoch(epoch);
        self.jdbc_sink_writer.begin_epoch(epoch).await?;
        Ok(())
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        let chunk = self.augmented_row.augmented_chunk(chunk)?;
        self.jdbc_sink_writer.write_batch(chunk).await?;
        Ok(())
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
        self.jdbc_sink_writer.barrier(is_checkpoint).await?;
        Ok(())
    }

    async fn abort(&mut self) -> Result<()> {
        // TODO: abort should clean up all the data written in this epoch.
        self.jdbc_sink_writer.abort().await?;
        Ok(())
    }
}

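/// Everything needed to render the DDL issued against Snowflake: target/intermediate
/// table identifiers, the MERGE INTO task parameters (upsert only), and the stage/pipe
/// names (S3 mode only).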
#[derive(Default)]
pub struct SnowflakeTaskContext {
    // Required for table/task creation.
    pub target_table_name: String,
    pub database: String,
    pub schema_name: String,
    pub schema: Schema,

    // Only set for upsert sinks.
    pub task_name: Option<String>,
    pub cdc_table_name: Option<String>,
    pub schedule_seconds: u64,
    pub warehouse: Option<String>,
    pub pk_column_names: Option<Vec<String>>,
    pub all_column_names: Option<Vec<String>>,

    // Only set when the S3 writer is used.
    pub stage: Option<String>,
    pub pipe_name: Option<String>,
}
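
/// Commit coordinator. On `init` it creates the Snowpipe and the scheduled MERGE INTO
/// task (upsert only); on every `commit` it refreshes the pipe and applies pending
/// schema changes; on drop it removes the task.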
pub struct SnowflakeSinkCommitter {
    client: Option<SnowflakeJniClient>,
}

impl SnowflakeSinkCommitter {
    pub fn new(
        config: SnowflakeV2Config,
        schema: &Schema,
        pk_indices: &Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let client = if let Some((snowflake_task_ctx, client)) =
            config.build_snowflake_task_ctx_jdbc_client(is_append_only, schema, pk_indices)?
        {
            Some(SnowflakeJniClient::new(client, snowflake_task_ctx))
        } else {
            None
        };
        Ok(Self { client })
    }
}

#[async_trait]
impl SinkCommitCoordinator for SnowflakeSinkCommitter {
    async fn init(&mut self, _subscriber: SinkCommittedEpochSubscriber) -> Result<Option<u64>> {
        if let Some(client) = &self.client {
            // TODO: move this to validate.
            client.execute_create_pipe().await?;
            client.execute_create_merge_into_task().await?;
        }
        Ok(None)
    }

    async fn commit(
        &mut self,
        _epoch: u64,
        _metadata: Vec<SinkMetadata>,
        add_columns: Option<Vec<Field>>,
    ) -> Result<()> {
        let client = self.client.as_mut().ok_or_else(|| {
            SinkError::Config(anyhow!("Snowflake sink committer is not initialized."))
        })?;
        client.execute_flush_pipe().await?;

        if let Some(add_columns) = add_columns {
            client
                .execute_alter_add_columns(
                    &add_columns
                        .iter()
                        .map(|f| (f.name.clone(), f.data_type.to_string()))
                        .collect::<Vec<_>>(),
                )
                .await?;
        }
        Ok(())
    }
}

impl Drop for SnowflakeSinkCommitter {
    fn drop(&mut self) {
        if let Some(client) = self.client.take() {
            tokio::spawn(async move {
                client.execute_drop_task().await.ok();
            });
        }
    }
}

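/// Thin wrapper around [`JdbcJniClient`] that renders and executes the Snowflake DDL
/// (CREATE TABLE / PIPE / TASK, ALTER TABLE ... ADD COLUMN, ALTER PIPE ... REFRESH)
/// described by the attached [`SnowflakeTaskContext`].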
pub struct SnowflakeJniClient {
    jdbc_client: JdbcJniClient,
    snowflake_task_context: SnowflakeTaskContext,
}

impl SnowflakeJniClient {
    pub fn new(jdbc_client: JdbcJniClient, snowflake_task_context: SnowflakeTaskContext) -> Self {
        Self {
            jdbc_client,
            snowflake_task_context,
        }
    }

    pub async fn execute_alter_add_columns(
        &mut self,
        columns: &Vec<(String, String)>,
    ) -> Result<()> {
        self.execute_drop_task().await?;
        if let Some(names) = self.snowflake_task_context.all_column_names.as_mut() {
            names.extend(columns.iter().map(|(name, _)| name.clone()));
        }
        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
            let alter_add_column_cdc_table_sql = build_alter_add_column_sql(
                cdc_table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                columns,
            );
            self.jdbc_client
                .execute_sql_sync(vec![alter_add_column_cdc_table_sql])
                .await?;
        }

        let alter_add_column_target_table_sql = build_alter_add_column_sql(
            &self.snowflake_task_context.target_table_name,
            &self.snowflake_task_context.database,
            &self.snowflake_task_context.schema_name,
            columns,
        );
        self.jdbc_client
            .execute_sql_sync(vec![alter_add_column_target_table_sql])
            .await?;

        self.execute_create_merge_into_task().await?;
        Ok(())
    }

    pub async fn execute_create_merge_into_task(&self) -> Result<()> {
        if self.snowflake_task_context.task_name.is_some() {
            let create_task_sql = build_create_merge_into_task_sql(&self.snowflake_task_context);
            let start_task_sql = build_start_task_sql(&self.snowflake_task_context);
            self.jdbc_client
                .execute_sql_sync(vec![create_task_sql])
                .await?;
            self.jdbc_client
                .execute_sql_sync(vec![start_task_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_drop_task(&self) -> Result<()> {
        if self.snowflake_task_context.task_name.is_some() {
            let sql = build_drop_task_sql(&self.snowflake_task_context);
            if let Err(e) = self.jdbc_client.execute_sql_sync(vec![sql]).await {
                tracing::error!(
                    "Failed to drop Snowflake sink task {:?}: {:?}",
                    self.snowflake_task_context.task_name,
                    e.as_report()
                );
            } else {
                tracing::info!(
                    "Snowflake sink task {:?} dropped",
                    self.snowflake_task_context.task_name
                );
            }
        }
        Ok(())
    }

    pub async fn execute_create_table(&self) -> Result<()> {
        // Create the target table, then the intermediate (CDC) table if one is configured.
        let create_target_table_sql = build_create_table_sql(
            &self.snowflake_task_context.target_table_name,
            &self.snowflake_task_context.database,
            &self.snowflake_task_context.schema_name,
            &self.snowflake_task_context.schema,
            false,
        )?;
        self.jdbc_client
            .execute_sql_sync(vec![create_target_table_sql])
            .await?;
        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
            let create_cdc_table_sql = build_create_table_sql(
                cdc_table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                &self.snowflake_task_context.schema,
                true,
            )?;
            self.jdbc_client
                .execute_sql_sync(vec![create_cdc_table_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_create_pipe(&self) -> Result<()> {
        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
            let table_name =
                if let Some(table_name) = self.snowflake_task_context.cdc_table_name.as_ref() {
                    table_name
                } else {
                    &self.snowflake_task_context.target_table_name
                };
            let create_pipe_sql = build_create_pipe_sql(
                table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                self.snowflake_task_context.stage.as_ref().ok_or_else(|| {
                    SinkError::Config(anyhow!("stage is required for the Snowflake S3 writer"))
                })?,
                pipe_name,
                &self.snowflake_task_context.target_table_name,
            );
            self.jdbc_client
                .execute_sql_sync(vec![create_pipe_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_flush_pipe(&self) -> Result<()> {
        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
            let flush_pipe_sql = build_flush_pipe_sql(
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                pipe_name,
            );
            self.jdbc_client
                .execute_sql_sync(vec![flush_pipe_sql])
                .await?;
        }
        Ok(())
    }
}

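/// Generates `CREATE TABLE IF NOT EXISTS ... ENABLE_SCHEMA_EVOLUTION = true`. When
/// `need_op_and_row_id` is set (i.e. for the intermediate table), the `__row_id` and
/// `__op` bookkeeping columns are appended to the user schema.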
fn build_create_table_sql(
    table_name: &str,
    database: &str,
    schema_name: &str,
    schema: &Schema,
    need_op_and_row_id: bool,
) -> Result<String> {
    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema_name, table_name);
    let mut columns: Vec<String> = schema
        .fields
        .iter()
        .map(|field| {
            let data_type = convert_snowflake_data_type(&field.data_type)?;
            Ok(format!(r#""{}" {}"#, field.name, data_type))
        })
        .collect::<Result<Vec<String>>>()?;
    if need_op_and_row_id {
        columns.push(format!(r#""{}" STRING"#, SNOWFLAKE_SINK_ROW_ID));
        columns.push(format!(r#""{}" INT"#, SNOWFLAKE_SINK_OP));
    }
    let columns_str = columns.join(", ");
    Ok(format!(
        "CREATE TABLE IF NOT EXISTS {} ({}) ENABLE_SCHEMA_EVOLUTION = true",
        full_table_name, columns_str
    ))
}

fn convert_snowflake_data_type(data_type: &DataType) -> Result<String> {
    let data_type = match data_type {
        DataType::Int16 => "SMALLINT".to_owned(),
        DataType::Int32 => "INTEGER".to_owned(),
        DataType::Int64 => "BIGINT".to_owned(),
        DataType::Float32 => "FLOAT4".to_owned(),
        DataType::Float64 => "FLOAT8".to_owned(),
        DataType::Boolean => "BOOLEAN".to_owned(),
        DataType::Varchar => "STRING".to_owned(),
        DataType::Date => "DATE".to_owned(),
        DataType::Timestamp => "TIMESTAMP".to_owned(),
        DataType::Timestamptz => "TIMESTAMP_TZ".to_owned(),
        DataType::Jsonb => "STRING".to_owned(),
        DataType::Decimal => "DECIMAL".to_owned(),
        DataType::Bytea => "BINARY".to_owned(),
        DataType::Time => "TIME".to_owned(),
        _ => {
            return Err(SinkError::Config(anyhow!(
                "Auto table creation does not support data type: {}",
                data_type
            )));
        }
    };
    Ok(data_type)
}

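/// Generates the Snowpipe that copies JSON files from `@"<db>"."<schema>"."<stage>"/<target table>`
/// into the landing table (the intermediate table for upsert, otherwise the target table),
/// matching columns by name case-insensitively. `AUTO_INGEST` is off; loads are triggered
/// by the coordinator via `ALTER PIPE ... REFRESH` on commit.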
fn build_create_pipe_sql(
    table_name: &str,
    database: &str,
    schema: &str,
    stage: &str,
    pipe_name: &str,
    target_table_name: &str,
) -> String {
    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    let stage = format!(
        r#""{}"."{}"."{}"/{}"#,
        database, schema, stage, target_table_name
    );
    let table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
    format!(
        "CREATE OR REPLACE PIPE {} AUTO_INGEST = FALSE AS COPY INTO {} FROM @{} MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE FILE_FORMAT = (type = 'JSON');",
        pipe_name, table_name, stage
    )
}

fn build_flush_pipe_sql(database: &str, schema: &str, pipe_name: &str) -> String {
    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    format!("ALTER PIPE {} REFRESH;", pipe_name)
}

fn build_alter_add_column_sql(
    table_name: &str,
    database: &str,
    schema: &str,
    columns: &Vec<(String, String)>,
) -> String {
    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
    jdbc_jni_client::build_alter_add_column_sql(&full_table_name, columns, true)
}

fn build_start_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        database,
        schema_name: schema,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema,
        task_name.as_ref().unwrap()
    );
    format!("ALTER TASK {} RESUME", full_task_name)
}

fn build_drop_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        database,
        schema_name: schema,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema,
        task_name.as_ref().unwrap()
    );
    format!("DROP TASK IF EXISTS {}", full_task_name)
}

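/// Renders the scheduled task that folds the intermediate table into the target table:
/// it snapshots the current max `__row_id`, keeps only the latest change per primary key,
/// applies it via `MERGE INTO` (ops 2/4 delete, ops 1/3 insert/update, following the `__op`
/// encoding written by the writers), and finally deletes the consumed rows from the
/// intermediate table.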
fn build_create_merge_into_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        cdc_table_name,
        target_table_name,
        schedule_seconds,
        warehouse,
        pk_column_names,
        all_column_names,
        database,
        schema_name,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema_name,
        task_name.as_ref().unwrap()
    );
    let full_cdc_table_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema_name,
        cdc_table_name.as_ref().unwrap()
    );
    let full_target_table_name = format!(
        r#""{}"."{}"."{}""#,
        database, schema_name, target_table_name
    );

    let pk_names_str = pk_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#""{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");
    let pk_names_eq_str = pk_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
        .collect::<Vec<String>>()
        .join(" AND ");
    let all_column_names_set_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
        .collect::<Vec<String>>()
        .join(", ");
    let all_column_names_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#""{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");
    let all_column_names_insert_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"source."{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");

    format!(
        r#"CREATE OR REPLACE TASK {task_name}
WAREHOUSE = {warehouse}
SCHEDULE = '{schedule_seconds} SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("{snowflake_sink_row_id}"), '0') INTO :max_row_id
    FROM {cdc_table_name};

    MERGE INTO {target_table_name} AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY {pk_names_str} ORDER BY "{snowflake_sink_row_id}" DESC) AS dedupe_id
            FROM {cdc_table_name}
            WHERE "{snowflake_sink_row_id}" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON {pk_names_eq_str}
    WHEN MATCHED AND source."{snowflake_sink_op}" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN UPDATE SET {all_column_names_set_str}
    WHEN NOT MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN INSERT ({all_column_names_str}) VALUES ({all_column_names_insert_str});

    DELETE FROM {cdc_table_name}
    WHERE "{snowflake_sink_row_id}" <= :max_row_id;
END;"#,
        task_name = full_task_name,
        warehouse = warehouse.as_ref().unwrap(),
        schedule_seconds = schedule_seconds,
        cdc_table_name = full_cdc_table_name,
        target_table_name = full_target_table_name,
        pk_names_str = pk_names_str,
        pk_names_eq_str = pk_names_eq_str,
        all_column_names_set_str = all_column_names_set_str,
        all_column_names_str = all_column_names_str,
        all_column_names_insert_str = all_column_names_insert_str,
        snowflake_sink_row_id = SNOWFLAKE_SINK_ROW_ID,
        snowflake_sink_op = SNOWFLAKE_SINK_OP,
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::jdbc_jni_client::normalize_sql;

    #[test]
    fn test_snowflake_sink_commit_coordinator() {
        let snowflake_task_context = SnowflakeTaskContext {
            task_name: Some("test_task".to_owned()),
            cdc_table_name: Some("test_cdc_table".to_owned()),
            target_table_name: "test_target_table".to_owned(),
            schedule_seconds: 3600,
            warehouse: Some("test_warehouse".to_owned()),
            pk_column_names: Some(vec!["v1".to_owned()]),
            all_column_names: Some(vec!["v1".to_owned(), "v2".to_owned()]),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            schema: Schema { fields: vec![] },
            stage: None,
            pipe_name: None,
        };
        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task"
WAREHOUSE = test_warehouse
SCHEDULE = '3600 SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
    FROM "test_db"."test_schema"."test_cdc_table";

    MERGE INTO "test_db"."test_schema"."test_target_table" AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY "v1" ORDER BY "__row_id" DESC) AS dedupe_id
            FROM "test_db"."test_schema"."test_cdc_table"
            WHERE "__row_id" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON target."v1" = source."v1"
    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."v1" = source."v1", target."v2" = source."v2"
    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("v1", "v2") VALUES (source."v1", source."v2");

    DELETE FROM "test_db"."test_schema"."test_cdc_table"
    WHERE "__row_id" <= :max_row_id;
END;"#;
        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
    }

    #[test]
    fn test_snowflake_sink_commit_coordinator_multi_pk() {
        let snowflake_task_context = SnowflakeTaskContext {
            task_name: Some("test_task_multi_pk".to_owned()),
            cdc_table_name: Some("cdc_multi_pk".to_owned()),
            target_table_name: "target_multi_pk".to_owned(),
            schedule_seconds: 300,
            warehouse: Some("multi_pk_warehouse".to_owned()),
            pk_column_names: Some(vec!["id1".to_owned(), "id2".to_owned()]),
            all_column_names: Some(vec!["id1".to_owned(), "id2".to_owned(), "val".to_owned()]),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            schema: Schema { fields: vec![] },
            stage: None,
            pipe_name: None,
        };
        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task_multi_pk"
WAREHOUSE = multi_pk_warehouse
SCHEDULE = '300 SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
    FROM "test_db"."test_schema"."cdc_multi_pk";

    MERGE INTO "test_db"."test_schema"."target_multi_pk" AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY "id1", "id2" ORDER BY "__row_id" DESC) AS dedupe_id
            FROM "test_db"."test_schema"."cdc_multi_pk"
            WHERE "__row_id" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON target."id1" = source."id1" AND target."id2" = source."id2"
    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."id1" = source."id1", target."id2" = source."id2", target."val" = source."val"
    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("id1", "id2", "val") VALUES (source."id1", source."id2", source."val");

    DELETE FROM "test_db"."test_schema"."cdc_multi_pk"
    WHERE "__row_id" <= :max_row_id;
END;"#;
        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
    }
}