risingwave_connector/sink/snowflake_redshift/snowflake.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use core::num::NonZeroU64;
use std::collections::BTreeMap;

use anyhow::anyhow;
use phf::{Set, phf_set};
use risingwave_common::array::StreamChunk;
use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema};
use risingwave_common::types::DataType;
use risingwave_pb::connector_service::{SinkMetadata, sink_metadata};
use sea_orm::DatabaseConnection;
use serde::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use thiserror_ext::AsReport;
use tokio::sync::mpsc::UnboundedSender;
use tonic::async_trait;
use with_options::WithOptions;

use crate::connector_common::IcebergSinkCompactionUpdate;
use crate::enforce_secret::EnforceSecret;
use crate::sink::coordinate::CoordinatedLogSinker;
use crate::sink::decouple_checkpoint_log_sink::default_commit_checkpoint_interval;
use crate::sink::file_sink::s3::S3Common;
use crate::sink::jdbc_jni_client::{self, JdbcJniClient};
use crate::sink::remote::CoordinatedRemoteSinkWriter;
use crate::sink::snowflake_redshift::{AugmentedChunk, SnowflakeRedshiftSinkS3Writer};
use crate::sink::writer::SinkWriter;
use crate::sink::{
    Result, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, Sink, SinkCommitCoordinator,
    SinkCommittedEpochSubscriber, SinkError, SinkParam, SinkWriterMetrics, SinkWriterParam,
};

pub const SNOWFLAKE_SINK_V2: &str = "snowflake_v2";
pub const SNOWFLAKE_SINK_ROW_ID: &str = "__row_id";
pub const SNOWFLAKE_SINK_OP: &str = "__op";
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct SnowflakeV2Config {
    #[serde(rename = "type")]
    pub r#type: String,

    #[serde(rename = "intermediate.table.name")]
    pub snowflake_cdc_table_name: Option<String>,

    #[serde(rename = "table.name")]
    pub snowflake_target_table_name: Option<String>,

    #[serde(rename = "database")]
    pub snowflake_database: Option<String>,

    #[serde(rename = "schema")]
    pub snowflake_schema: Option<String>,

    #[serde(default = "default_schedule")]
    #[serde(rename = "write.target.interval.seconds")]
    #[serde_as(as = "DisplayFromStr")]
    pub snowflake_schedule_seconds: u64,

    #[serde(rename = "warehouse")]
    pub snowflake_warehouse: Option<String>,

    #[serde(rename = "jdbc.url")]
    pub jdbc_url: Option<String>,

    #[serde(rename = "username")]
    pub username: Option<String>,

    #[serde(rename = "password")]
    pub password: Option<String>,

    /// Commit every n (n > 0) checkpoints; the default is 10.
    #[serde(default = "default_commit_checkpoint_interval")]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub commit_checkpoint_interval: u64,

    /// Enable auto schema change for the upsert sink.
    /// If enabled, the sink automatically alters the target table to add new columns.
    #[serde(default)]
    #[serde(rename = "auto.schema.change")]
    #[serde_as(as = "DisplayFromStr")]
    pub auto_schema_change: bool,

    #[serde(default)]
    #[serde(rename = "create_table_if_not_exists")]
    #[serde_as(as = "DisplayFromStr")]
    pub create_table_if_not_exists: bool,

    #[serde(default = "default_with_s3")]
    #[serde(rename = "with_s3")]
    #[serde_as(as = "DisplayFromStr")]
    pub with_s3: bool,

    #[serde(flatten)]
    pub s3_inner: Option<S3Common>,

    #[serde(rename = "stage")]
    pub stage: Option<String>,
}

fn default_schedule() -> u64 {
    3600 // Default to 1 hour.
}

fn default_with_s3() -> bool {
    true
}

impl SnowflakeV2Config {
    pub fn from_btreemap(properties: &BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<SnowflakeV2Config>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }

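    /// Build the [`SnowflakeTaskContext`] and the [`JdbcJniClient`] used to issue DDL
    /// against Snowflake (table, pipe, and task creation).
    ///
    /// Returns `Ok(None)` when no client is needed, i.e. for an append-only sink with
    /// neither auto schema change nor `create_table_if_not_exists` enabled.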
    pub fn build_snowflake_task_ctx_jdbc_client(
        &self,
        is_append_only: bool,
        schema: &Schema,
        pk_indices: &Vec<usize>,
    ) -> Result<Option<(SnowflakeTaskContext, JdbcJniClient)>> {
        if !self.auto_schema_change && is_append_only && !self.create_table_if_not_exists {
            // An append-only sink with neither auto schema change nor table creation enabled needs no JDBC client.
            return Ok(None);
        }
        let target_table_name = self
            .snowflake_target_table_name
            .clone()
            .ok_or(SinkError::Config(anyhow!("table.name is required")))?;
        let database = self
            .snowflake_database
            .clone()
            .ok_or(SinkError::Config(anyhow!("database is required")))?;
        let schema_name = self
            .snowflake_schema
            .clone()
            .ok_or(SinkError::Config(anyhow!("schema is required")))?;
        let mut snowflake_task_ctx = SnowflakeTaskContext {
            target_table_name: target_table_name.clone(),
            database,
            schema_name,
            schema: schema.clone(),
            ..Default::default()
        };

        let jdbc_url = self
            .jdbc_url
            .clone()
            .ok_or(SinkError::Config(anyhow!("jdbc.url is required")))?;
        let username = self
            .username
            .clone()
            .ok_or(SinkError::Config(anyhow!("username is required")))?;
        let password = self
            .password
            .clone()
            .ok_or(SinkError::Config(anyhow!("password is required")))?;
        let jdbc_url = format!("{}?user={}&password={}", jdbc_url, username, password);
        let client = JdbcJniClient::new(jdbc_url)?;

        if self.with_s3 {
            let stage = self
                .stage
                .clone()
                .ok_or(SinkError::Config(anyhow!("stage is required")))?;
            snowflake_task_ctx.stage = Some(stage);
            snowflake_task_ctx.pipe_name = Some(format!("{}_pipe", target_table_name));
        }
        if !is_append_only {
            let cdc_table_name = self
                .snowflake_cdc_table_name
                .clone()
                .ok_or(SinkError::Config(anyhow!(
                    "intermediate.table.name is required"
                )))?;
            snowflake_task_ctx.cdc_table_name = Some(cdc_table_name.clone());
            snowflake_task_ctx.schedule_seconds = self.snowflake_schedule_seconds;
            snowflake_task_ctx.warehouse = Some(
                self.snowflake_warehouse
                    .clone()
                    .ok_or(SinkError::Config(anyhow!("warehouse is required")))?,
            );
            let pk_column_names: Vec<_> = schema
                .fields
                .iter()
                .enumerate()
                .filter(|(index, _)| pk_indices.contains(index))
                .map(|(_, field)| field.name.clone())
                .collect();
            if pk_column_names.is_empty() {
                return Err(SinkError::Config(anyhow!(
                    "Primary key columns not found. Please set the `primary_key` column in the sink properties, or ensure that the sink contains the primary key columns from the upstream."
                )));
            }
            snowflake_task_ctx.pk_column_names = Some(pk_column_names);
            snowflake_task_ctx.all_column_names = Some(
                schema
                    .fields
                    .iter()
                    .map(|field| field.name.clone())
                    .collect(),
            );
            snowflake_task_ctx.task_name = Some(format!(
                "rw_snowflake_sink_from_{cdc_table_name}_to_{target_table_name}"
            ));
        }
        Ok(Some((snowflake_task_ctx, client)))
    }
}

impl EnforceSecret for SnowflakeV2Config {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "username",
        "password",
        "jdbc.url",
    };
}

#[derive(Clone, Debug)]
pub struct SnowflakeV2Sink {
    config: SnowflakeV2Config,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
    param: SinkParam,
}

impl EnforceSecret for SnowflakeV2Sink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::sink::ConnectorResult<()> {
        for prop in prop_iter {
            SnowflakeV2Config::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl TryFrom<SinkParam> for SnowflakeV2Sink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = SnowflakeV2Config::from_btreemap(&param.properties)?;
        let is_append_only = param.sink_type.is_append_only();
        let pk_indices = param.downstream_pk_or_empty();
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
            param,
        })
    }
}

impl Sink for SnowflakeV2Sink {
    type Coordinator = SnowflakeSinkCommitter;
    type LogSinker = CoordinatedLogSinker<SnowflakeSinkWriter>;

    const SINK_NAME: &'static str = SNOWFLAKE_SINK_V2;

    async fn validate(&self) -> Result<()> {
        risingwave_common::license::Feature::SnowflakeSink
            .check_available()
            .map_err(|e| anyhow::anyhow!(e))?;
        if let Some((snowflake_task_ctx, client)) =
            self.config.build_snowflake_task_ctx_jdbc_client(
                self.is_append_only,
                &self.schema,
                &self.pk_indices,
            )?
        {
            let client = SnowflakeJniClient::new(client, snowflake_task_ctx);
            client.execute_create_table().await?;
        }

        Ok(())
    }

    fn support_schema_change() -> bool {
        true
    }

    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
        SnowflakeV2Config::from_btreemap(config)?;
        Ok(())
    }

    async fn new_log_sinker(
        &self,
        writer_param: crate::sink::SinkWriterParam,
    ) -> Result<Self::LogSinker> {
        let writer = SnowflakeSinkWriter::new(
            self.config.clone(),
            self.is_append_only,
            writer_param.clone(),
            self.param.clone(),
        )
        .await?;

        let commit_checkpoint_interval =
            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
            );

        CoordinatedLogSinker::new(
            &writer_param,
            self.param.clone(),
            writer,
            commit_checkpoint_interval,
        )
        .await
    }

    fn is_coordinated_sink(&self) -> bool {
        true
    }

    async fn new_coordinator(
        &self,
        _db: DatabaseConnection,
        _iceberg_compact_stat_sender: Option<UnboundedSender<IcebergSinkCompactionUpdate>>,
    ) -> Result<Self::Coordinator> {
        let coordinator = SnowflakeSinkCommitter::new(
            self.config.clone(),
            &self.schema,
            &self.pk_indices,
            self.is_append_only,
        )?;
        Ok(coordinator)
    }
}

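/// Data writer of the Snowflake sink with two delivery paths:
/// - `S3`: rows are written as JSON files to an S3 stage and later loaded into
///   Snowflake through a Snowpipe (`with_s3 = 'true'`, the default).
/// - `Jdbc`: rows are written directly over JDBC, either to the target table
///   (append-only) or to the intermediate CDC table (upsert).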
pub enum SnowflakeSinkWriter {
    S3(SnowflakeRedshiftSinkS3Writer),
    Jdbc(SnowflakeSinkJdbcWriter),
}

impl SnowflakeSinkWriter {
    pub async fn new(
        config: SnowflakeV2Config,
        is_append_only: bool,
        writer_param: SinkWriterParam,
        param: SinkParam,
    ) -> Result<Self> {
        let schema = param.schema();
        if config.with_s3 {
            let executor_id = writer_param.executor_id;
            let s3_writer = SnowflakeRedshiftSinkS3Writer::new(
                config.s3_inner.ok_or_else(|| {
                    SinkError::Config(anyhow!(
                        "S3 configuration is required for Snowflake S3 sink"
                    ))
                })?,
                schema,
                is_append_only,
                executor_id,
                config.snowflake_target_table_name,
            )?;
            Ok(Self::S3(s3_writer))
        } else {
            let jdbc_writer =
                SnowflakeSinkJdbcWriter::new(config, is_append_only, writer_param, param).await?;
            Ok(Self::Jdbc(jdbc_writer))
        }
    }
}

#[async_trait]
impl SinkWriter for SnowflakeSinkWriter {
    type CommitMetadata = Option<SinkMetadata>;

    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
        match self {
            Self::S3(writer) => writer.begin_epoch(epoch),
            Self::Jdbc(writer) => writer.begin_epoch(epoch).await,
        }
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        match self {
            Self::S3(writer) => writer.write_batch(chunk).await,
            Self::Jdbc(writer) => writer.write_batch(chunk).await,
        }
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Option<SinkMetadata>> {
        match self {
            Self::S3(writer) => {
                writer.barrier(is_checkpoint).await?;
            }
            Self::Jdbc(writer) => {
                writer.barrier(is_checkpoint).await?;
            }
        }
        Ok(Some(SinkMetadata {
            metadata: Some(sink_metadata::Metadata::Serialized(
                risingwave_pb::connector_service::sink_metadata::SerializedMetadata {
                    metadata: vec![],
                },
            )),
        }))
    }

    async fn abort(&mut self) -> Result<()> {
        if let Self::Jdbc(writer) = self {
            writer.abort().await
        } else {
            Ok(())
        }
    }
}

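/// JDBC-based writer. Append-only sinks write straight to the target table; upsert
/// sinks write to the intermediate CDC table with the extra `__row_id` and `__op`
/// columns appended by [`AugmentedChunk`].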
pub struct SnowflakeSinkJdbcWriter {
    augmented_row: AugmentedChunk,
    jdbc_sink_writer: CoordinatedRemoteSinkWriter,
}

impl SnowflakeSinkJdbcWriter {
    pub async fn new(
        config: SnowflakeV2Config,
        is_append_only: bool,
        writer_param: SinkWriterParam,
        mut param: SinkParam,
    ) -> Result<Self> {
        let metrics = SinkWriterMetrics::new(&writer_param);
        let properties = &param.properties;
        let column_descs = &mut param.columns;
        let full_table_name = if is_append_only {
            format!(
                r#""{}"."{}"."{}""#,
                config.snowflake_database.clone().unwrap_or_default(),
                config.snowflake_schema.clone().unwrap_or_default(),
                config
                    .snowflake_target_table_name
                    .clone()
                    .unwrap_or_default()
            )
        } else {
            let max_column_id = column_descs
                .iter()
                .map(|column| column.column_id.get_id())
                .max()
                .unwrap_or(0);
            (*column_descs).push(ColumnDesc::named(
                SNOWFLAKE_SINK_ROW_ID,
                ColumnId::new(max_column_id + 1),
                DataType::Varchar,
            ));
            (*column_descs).push(ColumnDesc::named(
                SNOWFLAKE_SINK_OP,
                ColumnId::new(max_column_id + 2),
                DataType::Int32,
            ));
            format!(
                r#""{}"."{}"."{}""#,
                config.snowflake_database.clone().unwrap_or_default(),
                config.snowflake_schema.clone().unwrap_or_default(),
                config.snowflake_cdc_table_name.clone().unwrap_or_default()
            )
        };
        let new_properties = BTreeMap::from([
            ("table.name".to_owned(), full_table_name),
            ("connector".to_owned(), "snowflake_v2".to_owned()),
            (
                "jdbc.url".to_owned(),
                config.jdbc_url.clone().unwrap_or_default(),
            ),
            ("type".to_owned(), "append-only".to_owned()),
            (
                "user".to_owned(),
                config.username.clone().unwrap_or_default(),
            ),
            (
                "password".to_owned(),
                config.password.clone().unwrap_or_default(),
            ),
            (
                "primary_key".to_owned(),
                properties.get("primary_key").cloned().unwrap_or_default(),
            ),
            (
                "schema.name".to_owned(),
                config.snowflake_schema.clone().unwrap_or_default(),
            ),
            (
                "database.name".to_owned(),
                config.snowflake_database.clone().unwrap_or_default(),
            ),
        ]);
        param.properties = new_properties;

        let jdbc_sink_writer =
            CoordinatedRemoteSinkWriter::new(param.clone(), metrics.clone()).await?;
        Ok(Self {
            augmented_row: AugmentedChunk::new(0, is_append_only),
            jdbc_sink_writer,
        })
    }
}

impl SnowflakeSinkJdbcWriter {
    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
        self.augmented_row.reset_epoch(epoch);
        self.jdbc_sink_writer.begin_epoch(epoch).await?;
        Ok(())
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        let chunk = self.augmented_row.augmented_chunk(chunk)?;
        self.jdbc_sink_writer.write_batch(chunk).await?;
        Ok(())
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
        self.jdbc_sink_writer.barrier(is_checkpoint).await?;
        Ok(())
    }

    async fn abort(&mut self) -> Result<()> {
        // TODO: abort should clean up all the data written in this epoch.
        self.jdbc_sink_writer.abort().await?;
        Ok(())
    }
}

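/// All the pieces needed to render the Snowflake DDL issued by this sink: target and
/// intermediate table creation, the Snowpipe that loads staged files, and the scheduled
/// task that merges CDC rows into the target table.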
#[derive(Default)]
pub struct SnowflakeTaskContext {
    // Required for all generated DDL.
    pub target_table_name: String,
    pub database: String,
    pub schema_name: String,
    pub schema: Schema,

    // Only set for upsert sinks.
    pub task_name: Option<String>,
    pub cdc_table_name: Option<String>,
    pub schedule_seconds: u64,
    pub warehouse: Option<String>,
    pub pk_column_names: Option<Vec<String>>,
    pub all_column_names: Option<Vec<String>>,

    // Only set when the S3 (Snowpipe) write path is enabled.
    pub stage: Option<String>,
    pub pipe_name: Option<String>,
}
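
/// Commit coordinator of the Snowflake sink. `init` creates the Snowpipe and the
/// scheduled MERGE task (when configured); `commit` refreshes the pipe and applies any
/// requested schema changes. Dropping the committer drops the task.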
pub struct SnowflakeSinkCommitter {
    client: Option<SnowflakeJniClient>,
}

impl SnowflakeSinkCommitter {
    pub fn new(
        config: SnowflakeV2Config,
        schema: &Schema,
        pk_indices: &Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let client = if let Some((snowflake_task_ctx, client)) =
            config.build_snowflake_task_ctx_jdbc_client(is_append_only, schema, pk_indices)?
        {
            Some(SnowflakeJniClient::new(client, snowflake_task_ctx))
        } else {
            None
        };
        Ok(Self { client })
    }
}

#[async_trait]
impl SinkCommitCoordinator for SnowflakeSinkCommitter {
    async fn init(&mut self, _subscriber: SinkCommittedEpochSubscriber) -> Result<Option<u64>> {
        if let Some(client) = &self.client {
            // TODO: move this to `validate`.
            client.execute_create_pipe().await?;
            client.execute_create_merge_into_task().await?;
        }
        Ok(None)
    }

    async fn commit(
        &mut self,
        _epoch: u64,
        _metadata: Vec<SinkMetadata>,
        add_columns: Option<Vec<Field>>,
    ) -> Result<()> {
        let client = self.client.as_mut().ok_or_else(|| {
            SinkError::Config(anyhow!("Snowflake sink committer is not initialized."))
        })?;
        client.execute_flush_pipe().await?;

        if let Some(add_columns) = add_columns {
            client
                .execute_alter_add_columns(
                    &add_columns
                        .iter()
                        .map(|f| (f.name.clone(), f.data_type.to_string()))
                        .collect::<Vec<_>>(),
                )
                .await?;
        }
        Ok(())
    }
}

impl Drop for SnowflakeSinkCommitter {
    fn drop(&mut self) {
        if let Some(client) = self.client.take() {
            tokio::spawn(async move {
                client.execute_drop_task().await.ok();
            });
        }
    }
}

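/// Thin wrapper around [`JdbcJniClient`] that renders and executes Snowflake DDL
/// derived from a [`SnowflakeTaskContext`].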
pub struct SnowflakeJniClient {
    jdbc_client: JdbcJniClient,
    snowflake_task_context: SnowflakeTaskContext,
}

impl SnowflakeJniClient {
    pub fn new(jdbc_client: JdbcJniClient, snowflake_task_context: SnowflakeTaskContext) -> Self {
        Self {
            jdbc_client,
            snowflake_task_context,
        }
    }

    pub async fn execute_alter_add_columns(
        &mut self,
        columns: &Vec<(String, String)>,
    ) -> Result<()> {
        self.execute_drop_task().await?;
        if let Some(names) = self.snowflake_task_context.all_column_names.as_mut() {
            names.extend(columns.iter().map(|(name, _)| name.clone()));
        }
        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
            let alter_add_column_cdc_table_sql = build_alter_add_column_sql(
                cdc_table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                columns,
            );
            self.jdbc_client
                .execute_sql_sync(vec![alter_add_column_cdc_table_sql])
                .await?;
        }

        let alter_add_column_target_table_sql = build_alter_add_column_sql(
            &self.snowflake_task_context.target_table_name,
            &self.snowflake_task_context.database,
            &self.snowflake_task_context.schema_name,
            columns,
        );
        self.jdbc_client
            .execute_sql_sync(vec![alter_add_column_target_table_sql])
            .await?;

        self.execute_create_merge_into_task().await?;
        Ok(())
    }

    pub async fn execute_create_merge_into_task(&self) -> Result<()> {
        if self.snowflake_task_context.task_name.is_some() {
            let create_task_sql = build_create_merge_into_task_sql(&self.snowflake_task_context);
            let start_task_sql = build_start_task_sql(&self.snowflake_task_context);
            self.jdbc_client
                .execute_sql_sync(vec![create_task_sql])
                .await?;
            self.jdbc_client
                .execute_sql_sync(vec![start_task_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_drop_task(&self) -> Result<()> {
        if self.snowflake_task_context.task_name.is_some() {
            let sql = build_drop_task_sql(&self.snowflake_task_context);
            if let Err(e) = self.jdbc_client.execute_sql_sync(vec![sql]).await {
                tracing::error!(
                    "Failed to drop Snowflake sink task {:?}: {:?}",
                    self.snowflake_task_context.task_name,
                    e.as_report()
                );
            } else {
                tracing::info!(
                    "Snowflake sink task {:?} dropped",
                    self.snowflake_task_context.task_name
                );
            }
        }
        Ok(())
    }

    pub async fn execute_create_table(&self) -> Result<()> {
        // Create the target table.
        let create_target_table_sql = build_create_table_sql(
            &self.snowflake_task_context.target_table_name,
            &self.snowflake_task_context.database,
            &self.snowflake_task_context.schema_name,
            &self.snowflake_task_context.schema,
            false,
        )?;
        self.jdbc_client
            .execute_sql_sync(vec![create_target_table_sql])
            .await?;
        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
            let create_cdc_table_sql = build_create_table_sql(
                cdc_table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                &self.snowflake_task_context.schema,
                true,
            )?;
            self.jdbc_client
                .execute_sql_sync(vec![create_cdc_table_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_create_pipe(&self) -> Result<()> {
        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
            let table_name =
                if let Some(table_name) = self.snowflake_task_context.cdc_table_name.as_ref() {
                    table_name
                } else {
                    &self.snowflake_task_context.target_table_name
                };
            let create_pipe_sql = build_create_pipe_sql(
                table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                self.snowflake_task_context.stage.as_ref().ok_or_else(|| {
                    SinkError::Config(anyhow!("`stage` is required for the S3 writer"))
                })?,
                pipe_name,
                &self.snowflake_task_context.target_table_name,
            );
            self.jdbc_client
                .execute_sql_sync(vec![create_pipe_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_flush_pipe(&self) -> Result<()> {
        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
            let flush_pipe_sql = build_flush_pipe_sql(
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                pipe_name,
            );
            self.jdbc_client
                .execute_sql_sync(vec![flush_pipe_sql])
                .await?;
        }
        Ok(())
    }
}

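/// Render a `CREATE TABLE IF NOT EXISTS` statement with schema evolution enabled.
/// When `need_op_and_row_id` is true, the `__row_id` and `__op` bookkeeping columns
/// used by the intermediate CDC table are appended.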
fn build_create_table_sql(
    table_name: &str,
    database: &str,
    schema_name: &str,
    schema: &Schema,
    need_op_and_row_id: bool,
) -> Result<String> {
    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema_name, table_name);
    let mut columns: Vec<String> = schema
        .fields
        .iter()
        .map(|field| {
            let data_type = convert_snowflake_data_type(&field.data_type)?;
            Ok(format!(r#""{}" {}"#, field.name, data_type))
        })
        .collect::<Result<Vec<String>>>()?;
    if need_op_and_row_id {
        columns.push(format!(r#""{}" STRING"#, SNOWFLAKE_SINK_ROW_ID));
        columns.push(format!(r#""{}" INT"#, SNOWFLAKE_SINK_OP));
    }
    let columns_str = columns.join(", ");
    Ok(format!(
        "CREATE TABLE IF NOT EXISTS {} ({}) ENABLE_SCHEMA_EVOLUTION = true",
        full_table_name, columns_str
    ))
}

fn convert_snowflake_data_type(data_type: &DataType) -> Result<String> {
    let data_type = match data_type {
        DataType::Int16 => "SMALLINT".to_owned(),
        DataType::Int32 => "INTEGER".to_owned(),
        DataType::Int64 => "BIGINT".to_owned(),
        DataType::Float32 => "FLOAT4".to_owned(),
        DataType::Float64 => "FLOAT8".to_owned(),
        DataType::Boolean => "BOOLEAN".to_owned(),
        DataType::Varchar => "STRING".to_owned(),
        DataType::Date => "DATE".to_owned(),
        DataType::Timestamp => "TIMESTAMP".to_owned(),
        DataType::Timestamptz => "TIMESTAMP_TZ".to_owned(),
        DataType::Jsonb => "STRING".to_owned(),
        DataType::Decimal => "DECIMAL".to_owned(),
        DataType::Bytea => "BINARY".to_owned(),
        DataType::Time => "TIME".to_owned(),
        _ => {
            return Err(SinkError::Config(anyhow!(
                "automatic table creation is not supported for data type: {}",
                data_type
            )));
        }
    };
    Ok(data_type)
}

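/// Render the `CREATE OR REPLACE PIPE` statement. The pipe copies JSON files from the
/// stage path `@"<db>"."<schema>"."<stage>"/<target_table>` into the landing table
/// (the intermediate CDC table for upsert sinks, otherwise the target table).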
fn build_create_pipe_sql(
    table_name: &str,
    database: &str,
    schema: &str,
    stage: &str,
    pipe_name: &str,
    target_table_name: &str,
) -> String {
    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    let stage = format!(
        r#""{}"."{}"."{}"/{}"#,
        database, schema, stage, target_table_name
    );
    let table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
    format!(
        "CREATE OR REPLACE PIPE {} AUTO_INGEST = FALSE AS COPY INTO {} FROM @{} MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE FILE_FORMAT = (type = 'JSON');",
        pipe_name, table_name, stage
    )
}

fn build_flush_pipe_sql(database: &str, schema: &str, pipe_name: &str) -> String {
    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    format!("ALTER PIPE {} REFRESH;", pipe_name)
}

fn build_alter_add_column_sql(
    table_name: &str,
    database: &str,
    schema: &str,
    columns: &Vec<(String, String)>,
) -> String {
    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
    jdbc_jni_client::build_alter_add_column_sql(&full_table_name, columns, true)
}

fn build_start_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        database,
        schema_name: schema,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema,
        task_name.as_ref().unwrap()
    );
    format!("ALTER TASK {} RESUME", full_task_name)
}

fn build_drop_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        database,
        schema_name: schema,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema,
        task_name.as_ref().unwrap()
    );
    format!("DROP TASK IF EXISTS {}", full_task_name)
}

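/// Render the scheduled Snowflake task that folds the intermediate CDC table into the
/// target table: rows are deduplicated per primary key by the highest `__row_id`,
/// delete-style ops (2, 4) remove the matching target row, insert-style ops (1, 3)
/// update or insert it, and the applied CDC rows are then purged.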
fn build_create_merge_into_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        cdc_table_name,
        target_table_name,
        schedule_seconds,
        warehouse,
        pk_column_names,
        all_column_names,
        database,
        schema_name,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema_name,
        task_name.as_ref().unwrap()
    );
    let full_cdc_table_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema_name,
        cdc_table_name.as_ref().unwrap()
    );
    let full_target_table_name = format!(
        r#""{}"."{}"."{}""#,
        database, schema_name, target_table_name
    );

    let pk_names_str = pk_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#""{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");
    let pk_names_eq_str = pk_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
        .collect::<Vec<String>>()
        .join(" AND ");
    let all_column_names_set_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
        .collect::<Vec<String>>()
        .join(", ");
    let all_column_names_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#""{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");
    let all_column_names_insert_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"source."{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");

    format!(
        r#"CREATE OR REPLACE TASK {task_name}
WAREHOUSE = {warehouse}
SCHEDULE = '{schedule_seconds} SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("{snowflake_sink_row_id}"), '0') INTO :max_row_id
    FROM {cdc_table_name};

    MERGE INTO {target_table_name} AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY {pk_names_str} ORDER BY "{snowflake_sink_row_id}" DESC) AS dedupe_id
            FROM {cdc_table_name}
            WHERE "{snowflake_sink_row_id}" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON {pk_names_eq_str}
    WHEN MATCHED AND source."{snowflake_sink_op}" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN UPDATE SET {all_column_names_set_str}
    WHEN NOT MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN INSERT ({all_column_names_str}) VALUES ({all_column_names_insert_str});

    DELETE FROM {cdc_table_name}
    WHERE "{snowflake_sink_row_id}" <= :max_row_id;
END;"#,
        task_name = full_task_name,
        warehouse = warehouse.as_ref().unwrap(),
        schedule_seconds = schedule_seconds,
        cdc_table_name = full_cdc_table_name,
        target_table_name = full_target_table_name,
        pk_names_str = pk_names_str,
        pk_names_eq_str = pk_names_eq_str,
        all_column_names_set_str = all_column_names_set_str,
        all_column_names_str = all_column_names_str,
        all_column_names_insert_str = all_column_names_insert_str,
        snowflake_sink_row_id = SNOWFLAKE_SINK_ROW_ID,
        snowflake_sink_op = SNOWFLAKE_SINK_OP,
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::jdbc_jni_client::normalize_sql;

    #[test]
    fn test_snowflake_sink_commit_coordinator() {
        let snowflake_task_context = SnowflakeTaskContext {
            task_name: Some("test_task".to_owned()),
            cdc_table_name: Some("test_cdc_table".to_owned()),
            target_table_name: "test_target_table".to_owned(),
            schedule_seconds: 3600,
            warehouse: Some("test_warehouse".to_owned()),
            pk_column_names: Some(vec!["v1".to_owned()]),
            all_column_names: Some(vec!["v1".to_owned(), "v2".to_owned()]),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            schema: Schema { fields: vec![] },
            stage: None,
            pipe_name: None,
        };
        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task"
WAREHOUSE = test_warehouse
SCHEDULE = '3600 SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
    FROM "test_db"."test_schema"."test_cdc_table";

    MERGE INTO "test_db"."test_schema"."test_target_table" AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY "v1" ORDER BY "__row_id" DESC) AS dedupe_id
            FROM "test_db"."test_schema"."test_cdc_table"
            WHERE "__row_id" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON target."v1" = source."v1"
    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."v1" = source."v1", target."v2" = source."v2"
    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("v1", "v2") VALUES (source."v1", source."v2");

    DELETE FROM "test_db"."test_schema"."test_cdc_table"
    WHERE "__row_id" <= :max_row_id;
END;"#;
        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
    }

    #[test]
    fn test_snowflake_sink_commit_coordinator_multi_pk() {
        let snowflake_task_context = SnowflakeTaskContext {
            task_name: Some("test_task_multi_pk".to_owned()),
            cdc_table_name: Some("cdc_multi_pk".to_owned()),
            target_table_name: "target_multi_pk".to_owned(),
            schedule_seconds: 300,
            warehouse: Some("multi_pk_warehouse".to_owned()),
            pk_column_names: Some(vec!["id1".to_owned(), "id2".to_owned()]),
            all_column_names: Some(vec!["id1".to_owned(), "id2".to_owned(), "val".to_owned()]),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            schema: Schema { fields: vec![] },
            stage: None,
            pipe_name: None,
        };
        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task_multi_pk"
WAREHOUSE = multi_pk_warehouse
SCHEDULE = '300 SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
    FROM "test_db"."test_schema"."cdc_multi_pk";

    MERGE INTO "test_db"."test_schema"."target_multi_pk" AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY "id1", "id2" ORDER BY "__row_id" DESC) AS dedupe_id
            FROM "test_db"."test_schema"."cdc_multi_pk"
            WHERE "__row_id" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON target."id1" = source."id1" AND target."id2" = source."id2"
    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."id1" = source."id1", target."id2" = source."id2", target."val" = source."val"
    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("id1", "id2", "val") VALUES (source."id1", source."id2", source."val");

    DELETE FROM "test_db"."test_schema"."cdc_multi_pk"
    WHERE "__row_id" <= :max_row_id;
END;"#;
        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
    }
}