use core::num::NonZeroU64;
use std::collections::BTreeMap;

use anyhow::anyhow;
use phf::{Set, phf_set};
use risingwave_common::array::StreamChunk;
use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema};
use risingwave_common::types::DataType;
use risingwave_pb::connector_service::{SinkMetadata, sink_metadata};
use sea_orm::DatabaseConnection;
use serde::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use thiserror_ext::AsReport;
use tokio::sync::mpsc::UnboundedSender;
use tonic::async_trait;
use with_options::WithOptions;

use crate::connector_common::IcebergSinkCompactionUpdate;
use crate::enforce_secret::EnforceSecret;
use crate::sink::coordinate::CoordinatedLogSinker;
use crate::sink::decouple_checkpoint_log_sink::default_commit_checkpoint_interval;
use crate::sink::file_sink::s3::S3Common;
use crate::sink::jdbc_jni_client::{self, JdbcJniClient};
use crate::sink::remote::CoordinatedRemoteSinkWriter;
use crate::sink::snowflake_redshift::{AugmentedChunk, SnowflakeRedshiftSinkS3Writer};
use crate::sink::writer::SinkWriter;
use crate::sink::{
    Result, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, Sink, SinkCommitCoordinator,
    SinkCommittedEpochSubscriber, SinkError, SinkParam, SinkWriterMetrics, SinkWriterParam,
};

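/// Connector name under which the Snowflake V2 sink is registered.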
pub const SNOWFLAKE_SINK_V2: &str = "snowflake_v2";
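/// Name of the extra column appended to rows in the intermediate (CDC) table. The merge task
/// orders by it to keep only the latest row per primary key and uses its maximum value to
/// bound each merge run.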
pub const SNOWFLAKE_SINK_ROW_ID: &str = "__row_id";
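/// Name of the extra column recording the change operation in the intermediate (CDC) table;
/// the merge task treats ops 1 and 3 as insert/update and ops 2 and 4 as delete.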
pub const SNOWFLAKE_SINK_OP: &str = "__op";

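/// User-facing `WITH` options of the Snowflake V2 sink.
///
/// Data is either staged to S3 and loaded through a Snowflake pipe (`with_s3 = true`, the
/// default) or written directly over JDBC. For upsert sinks, rows first land in an
/// intermediate (CDC) table and a scheduled Snowflake task merges them into the target table.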
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct SnowflakeV2Config {
    #[serde(rename = "type")]
    pub r#type: String,

    #[serde(rename = "intermediate.table.name")]
    pub snowflake_cdc_table_name: Option<String>,

    #[serde(rename = "table.name")]
    pub snowflake_target_table_name: Option<String>,

    #[serde(rename = "database")]
    pub snowflake_database: Option<String>,

    #[serde(rename = "schema")]
    pub snowflake_schema: Option<String>,

    #[serde(default = "default_schedule")]
    #[serde(rename = "write.target.interval.seconds")]
    #[serde_as(as = "DisplayFromStr")]
    pub snowflake_schedule_seconds: u64,

    #[serde(rename = "warehouse")]
    pub snowflake_warehouse: Option<String>,

    #[serde(rename = "jdbc.url")]
    pub jdbc_url: Option<String>,

    #[serde(rename = "username")]
    pub username: Option<String>,

    #[serde(rename = "password")]
    pub password: Option<String>,

    #[serde(default = "default_commit_checkpoint_interval")]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub commit_checkpoint_interval: u64,

    #[serde(default)]
    #[serde(rename = "auto.schema.change")]
    #[serde_as(as = "DisplayFromStr")]
    pub auto_schema_change: bool,

    #[serde(default)]
    #[serde(rename = "create_table_if_not_exists")]
    #[serde_as(as = "DisplayFromStr")]
    pub create_table_if_not_exists: bool,

    #[serde(default = "default_with_s3")]
    #[serde(rename = "with_s3")]
    #[serde_as(as = "DisplayFromStr")]
    pub with_s3: bool,

    #[serde(flatten)]
    pub s3_inner: Option<S3Common>,

    #[serde(rename = "stage")]
    pub stage: Option<String>,
}

fn default_schedule() -> u64 {
    3600
}

fn default_with_s3() -> bool {
    true
}

impl SnowflakeV2Config {
    pub fn from_btreemap(properties: &BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<SnowflakeV2Config>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }

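    /// Builds the [`SnowflakeTaskContext`] plus a JDBC client used for DDL and coordination.
    ///
    /// Returns `Ok(None)` for append-only sinks that neither auto-create tables nor follow
    /// schema changes, since no work has to be done on the Snowflake side. For upsert sinks
    /// it also resolves the intermediate (CDC) table, warehouse, primary-key columns, and the
    /// merge task name.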
    pub fn build_snowflake_task_ctx_jdbc_client(
        &self,
        is_append_only: bool,
        schema: &Schema,
        pk_indices: &Vec<usize>,
    ) -> Result<Option<(SnowflakeTaskContext, JdbcJniClient)>> {
        if !self.auto_schema_change && is_append_only && !self.create_table_if_not_exists {
            return Ok(None);
        }
        let target_table_name = self
            .snowflake_target_table_name
            .clone()
            .ok_or(SinkError::Config(anyhow!("table.name is required")))?;
        let database = self
            .snowflake_database
            .clone()
            .ok_or(SinkError::Config(anyhow!("database is required")))?;
        let schema_name = self
            .snowflake_schema
            .clone()
            .ok_or(SinkError::Config(anyhow!("schema is required")))?;
        let mut snowflake_task_ctx = SnowflakeTaskContext {
            target_table_name: target_table_name.clone(),
            database,
            schema_name,
            schema: schema.clone(),
            ..Default::default()
        };

        let jdbc_url = self
            .jdbc_url
            .clone()
            .ok_or(SinkError::Config(anyhow!("jdbc.url is required")))?;
        let username = self
            .username
            .clone()
            .ok_or(SinkError::Config(anyhow!("username is required")))?;
        let password = self
            .password
            .clone()
            .ok_or(SinkError::Config(anyhow!("password is required")))?;
        let jdbc_url = format!("{}?user={}&password={}", jdbc_url, username, password);
        let client = JdbcJniClient::new(jdbc_url)?;

        if self.with_s3 {
            let stage = self
                .stage
                .clone()
                .ok_or(SinkError::Config(anyhow!("stage is required")))?;
            snowflake_task_ctx.stage = Some(stage);
            snowflake_task_ctx.pipe_name = Some(format!("{}_pipe", target_table_name));
        }
        if !is_append_only {
            let cdc_table_name = self
                .snowflake_cdc_table_name
                .clone()
                .ok_or(SinkError::Config(anyhow!(
                    "intermediate.table.name is required"
                )))?;
            snowflake_task_ctx.cdc_table_name = Some(cdc_table_name.clone());
            snowflake_task_ctx.schedule_seconds = self.snowflake_schedule_seconds;
            snowflake_task_ctx.warehouse = Some(
                self.snowflake_warehouse
                    .clone()
                    .ok_or(SinkError::Config(anyhow!("warehouse is required")))?,
            );
            let pk_column_names: Vec<_> = schema
                .fields
                .iter()
                .enumerate()
                .filter(|(index, _)| pk_indices.contains(index))
                .map(|(_, field)| field.name.clone())
                .collect();
            if pk_column_names.is_empty() {
                return Err(SinkError::Config(anyhow!(
                    "Primary key columns not found. Please set the `primary_key` column in the sink properties, or ensure that the sink contains the primary key columns from the upstream."
                )));
            }
            snowflake_task_ctx.pk_column_names = Some(pk_column_names);
            snowflake_task_ctx.all_column_names = Some(
                schema
                    .fields
                    .iter()
                    .map(|field| field.name.clone())
                    .collect(),
            );
            snowflake_task_ctx.task_name = Some(format!(
                "rw_snowflake_sink_from_{cdc_table_name}_to_{target_table_name}"
            ));
        }
        Ok(Some((snowflake_task_ctx, client)))
    }
}

impl EnforceSecret for SnowflakeV2Config {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "username",
        "password",
        "jdbc.url",
    };
}

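/// Snowflake V2 sink. Depending on the configuration it writes through an S3 stage plus a
/// Snowflake pipe or directly over JDBC, and for upsert sinks it maintains a scheduled task
/// that merges the intermediate (CDC) table into the target table.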
#[derive(Clone, Debug)]
pub struct SnowflakeV2Sink {
    config: SnowflakeV2Config,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
    param: SinkParam,
}

impl EnforceSecret for SnowflakeV2Sink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::sink::ConnectorResult<()> {
        for prop in prop_iter {
            SnowflakeV2Config::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl TryFrom<SinkParam> for SnowflakeV2Sink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = SnowflakeV2Config::from_btreemap(&param.properties)?;
        let is_append_only = param.sink_type.is_append_only();
        let pk_indices = param.downstream_pk.clone();
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
            param,
        })
    }
}

impl Sink for SnowflakeV2Sink {
    type Coordinator = SnowflakeSinkCommitter;
    type LogSinker = CoordinatedLogSinker<SnowflakeSinkWriter>;

    const SINK_NAME: &'static str = SNOWFLAKE_SINK_V2;

    async fn validate(&self) -> Result<()> {
        risingwave_common::license::Feature::SnowflakeSink
            .check_available()
            .map_err(|e| anyhow::anyhow!(e))?;
        if let Some((snowflake_task_ctx, client)) =
            self.config.build_snowflake_task_ctx_jdbc_client(
                self.is_append_only,
                &self.schema,
                &self.pk_indices,
            )?
        {
            let client = SnowflakeJniClient::new(client, snowflake_task_ctx);
            client.execute_create_table().await?;
        }

        Ok(())
    }

    fn support_schema_change() -> bool {
        true
    }

    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
        SnowflakeV2Config::from_btreemap(config)?;
        Ok(())
    }

    async fn new_log_sinker(
        &self,
        writer_param: crate::sink::SinkWriterParam,
    ) -> Result<Self::LogSinker> {
        let writer = SnowflakeSinkWriter::new(
            self.config.clone(),
            self.is_append_only,
            writer_param.clone(),
            self.param.clone(),
        )
        .await?;

        let commit_checkpoint_interval =
            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
            );

        CoordinatedLogSinker::new(
            &writer_param,
            self.param.clone(),
            writer,
            commit_checkpoint_interval,
        )
        .await
    }

    fn is_coordinated_sink(&self) -> bool {
        true
    }

    async fn new_coordinator(
        &self,
        _db: DatabaseConnection,
        _iceberg_compact_stat_sender: Option<UnboundedSender<IcebergSinkCompactionUpdate>>,
    ) -> Result<Self::Coordinator> {
        let coordinator = SnowflakeSinkCommitter::new(
            self.config.clone(),
            &self.schema,
            &self.pk_indices,
            self.is_append_only,
        )?;
        Ok(coordinator)
    }
}

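/// Writer half of the sink: rows are either staged as files on S3 (later loaded by the
/// Snowflake pipe) or forwarded to the remote JDBC sink writer.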
pub enum SnowflakeSinkWriter {
    S3(SnowflakeRedshiftSinkS3Writer),
    Jdbc(SnowflakeSinkJdbcWriter),
}

impl SnowflakeSinkWriter {
    pub async fn new(
        config: SnowflakeV2Config,
        is_append_only: bool,
        writer_param: SinkWriterParam,
        param: SinkParam,
    ) -> Result<Self> {
        let schema = param.schema();
        if config.with_s3 {
            let executor_id = writer_param.executor_id;
            let s3_writer = SnowflakeRedshiftSinkS3Writer::new(
                config.s3_inner.ok_or_else(|| {
                    SinkError::Config(anyhow!(
                        "S3 configuration is required for Snowflake S3 sink"
                    ))
                })?,
                schema,
                is_append_only,
                executor_id,
                config.snowflake_target_table_name,
            )?;
            Ok(Self::S3(s3_writer))
        } else {
            let jdbc_writer =
                SnowflakeSinkJdbcWriter::new(config, is_append_only, writer_param, param).await?;
            Ok(Self::Jdbc(jdbc_writer))
        }
    }
}

#[async_trait]
impl SinkWriter for SnowflakeSinkWriter {
    type CommitMetadata = Option<SinkMetadata>;

    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
        match self {
            Self::S3(writer) => writer.begin_epoch(epoch),
            Self::Jdbc(writer) => writer.begin_epoch(epoch).await,
        }
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        match self {
            Self::S3(writer) => writer.write_batch(chunk).await,
            Self::Jdbc(writer) => writer.write_batch(chunk).await,
        }
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Option<SinkMetadata>> {
        match self {
            Self::S3(writer) => {
                writer.barrier(is_checkpoint).await?;
            }
            Self::Jdbc(writer) => {
                writer.barrier(is_checkpoint).await?;
            }
        }
        Ok(Some(SinkMetadata {
            metadata: Some(sink_metadata::Metadata::Serialized(
                risingwave_pb::connector_service::sink_metadata::SerializedMetadata {
                    metadata: vec![],
                },
            )),
        }))
    }

    async fn abort(&mut self) -> Result<()> {
        if let Self::Jdbc(writer) = self {
            writer.abort().await
        } else {
            Ok(())
        }
    }
}

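/// JDBC-based writer. It rewrites the sink parameters so the generic remote JDBC sink writes
/// into the target table (append-only) or the intermediate (CDC) table (upsert), where the
/// latter additionally carries the `__row_id` / `__op` columns produced by [`AugmentedChunk`].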
pub struct SnowflakeSinkJdbcWriter {
    augmented_row: AugmentedChunk,
    jdbc_sink_writer: CoordinatedRemoteSinkWriter,
}

impl SnowflakeSinkJdbcWriter {
    pub async fn new(
        config: SnowflakeV2Config,
        is_append_only: bool,
        writer_param: SinkWriterParam,
        mut param: SinkParam,
    ) -> Result<Self> {
        let metrics = SinkWriterMetrics::new(&writer_param);
        let properties = &param.properties;
        let column_descs = &mut param.columns;
        let full_table_name = if is_append_only {
            format!(
                r#""{}"."{}"."{}""#,
                config.snowflake_database.clone().unwrap_or_default(),
                config.snowflake_schema.clone().unwrap_or_default(),
                config
                    .snowflake_target_table_name
                    .clone()
                    .unwrap_or_default()
            )
        } else {
            let max_column_id = column_descs
                .iter()
                .map(|column| column.column_id.get_id())
                .max()
                .unwrap_or(0);
            column_descs.push(ColumnDesc::named(
                SNOWFLAKE_SINK_ROW_ID,
                ColumnId::new(max_column_id + 1),
                DataType::Varchar,
            ));
            column_descs.push(ColumnDesc::named(
                SNOWFLAKE_SINK_OP,
                ColumnId::new(max_column_id + 2),
                DataType::Int32,
            ));
            format!(
                r#""{}"."{}"."{}""#,
                config.snowflake_database.clone().unwrap_or_default(),
                config.snowflake_schema.clone().unwrap_or_default(),
                config.snowflake_cdc_table_name.clone().unwrap_or_default()
            )
        };
        let new_properties = BTreeMap::from([
            ("table.name".to_owned(), full_table_name),
            ("connector".to_owned(), "snowflake_v2".to_owned()),
            (
                "jdbc.url".to_owned(),
                config.jdbc_url.clone().unwrap_or_default(),
            ),
            ("type".to_owned(), "append-only".to_owned()),
            (
                "user".to_owned(),
                config.username.clone().unwrap_or_default(),
            ),
            (
                "password".to_owned(),
                config.password.clone().unwrap_or_default(),
            ),
            (
                "primary_key".to_owned(),
                properties.get("primary_key").cloned().unwrap_or_default(),
            ),
            (
                "schema.name".to_owned(),
                config.snowflake_schema.clone().unwrap_or_default(),
            ),
            (
                "database.name".to_owned(),
                config.snowflake_database.clone().unwrap_or_default(),
            ),
        ]);
        param.properties = new_properties;

        let jdbc_sink_writer =
            CoordinatedRemoteSinkWriter::new(param.clone(), metrics.clone()).await?;
        Ok(Self {
            augmented_row: AugmentedChunk::new(0, is_append_only),
            jdbc_sink_writer,
        })
    }
}

impl SnowflakeSinkJdbcWriter {
    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
        self.augmented_row.reset_epoch(epoch);
        self.jdbc_sink_writer.begin_epoch(epoch).await?;
        Ok(())
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        let chunk = self.augmented_row.augmented_chunk(chunk)?;
        self.jdbc_sink_writer.write_batch(chunk).await?;
        Ok(())
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
        self.jdbc_sink_writer.barrier(is_checkpoint).await?;
        Ok(())
    }

    async fn abort(&mut self) -> Result<()> {
        self.jdbc_sink_writer.abort().await?;
        Ok(())
    }
}

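/// Everything needed to manage the Snowflake-side objects of one sink: the target table and,
/// depending on the configuration, the intermediate (CDC) table, the scheduled merge task,
/// and the stage/pipe used for S3 ingestion.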
#[derive(Default)]
pub struct SnowflakeTaskContext {
    pub target_table_name: String,
    pub database: String,
    pub schema_name: String,
    pub schema: Schema,

    pub task_name: Option<String>,
    pub cdc_table_name: Option<String>,
    pub schedule_seconds: u64,
    pub warehouse: Option<String>,
    pub pk_column_names: Option<Vec<String>>,
    pub all_column_names: Option<Vec<String>>,

    pub stage: Option<String>,
    pub pipe_name: Option<String>,
}

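/// Commit coordinator of the sink. On initialization it creates the pipe and the merge task
/// (when configured); on each commit it refreshes the pipe and applies any newly added
/// columns.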
pub struct SnowflakeSinkCommitter {
    client: Option<SnowflakeJniClient>,
}

impl SnowflakeSinkCommitter {
    pub fn new(
        config: SnowflakeV2Config,
        schema: &Schema,
        pk_indices: &Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let client = if let Some((snowflake_task_ctx, client)) =
            config.build_snowflake_task_ctx_jdbc_client(is_append_only, schema, pk_indices)?
        {
            Some(SnowflakeJniClient::new(client, snowflake_task_ctx))
        } else {
            None
        };
        Ok(Self { client })
    }
}

#[async_trait]
impl SinkCommitCoordinator for SnowflakeSinkCommitter {
    async fn init(&mut self, _subscriber: SinkCommittedEpochSubscriber) -> Result<Option<u64>> {
        if let Some(client) = &self.client {
            client.execute_create_pipe().await?;
            client.execute_create_merge_into_task().await?;
        }
        Ok(None)
    }

    async fn commit(
        &mut self,
        _epoch: u64,
        _metadata: Vec<SinkMetadata>,
        add_columns: Option<Vec<Field>>,
    ) -> Result<()> {
        let client = self.client.as_mut().ok_or_else(|| {
            SinkError::Config(anyhow!("Snowflake sink committer is not initialized."))
        })?;
        client.execute_flush_pipe().await?;

        if let Some(add_columns) = add_columns {
            client
                .execute_alter_add_columns(
                    &add_columns
                        .iter()
                        .map(|f| (f.name.clone(), f.data_type.to_string()))
                        .collect::<Vec<_>>(),
                )
                .await?;
        }
        Ok(())
    }
}

impl Drop for SnowflakeSinkCommitter {
    fn drop(&mut self) {
        if let Some(client) = self.client.take() {
            tokio::spawn(async move {
                client.execute_drop_task().await.ok();
            });
        }
    }
}

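/// Thin wrapper around [`JdbcJniClient`] that renders and executes the Snowflake DDL/DML
/// statements (table, pipe, and task creation, pipe refresh, column additions) for one
/// [`SnowflakeTaskContext`].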
pub struct SnowflakeJniClient {
    jdbc_client: JdbcJniClient,
    snowflake_task_context: SnowflakeTaskContext,
}

impl SnowflakeJniClient {
    pub fn new(jdbc_client: JdbcJniClient, snowflake_task_context: SnowflakeTaskContext) -> Self {
        Self {
            jdbc_client,
            snowflake_task_context,
        }
    }

    pub async fn execute_alter_add_columns(
        &mut self,
        columns: &Vec<(String, String)>,
    ) -> Result<()> {
        self.execute_drop_task().await?;
        if let Some(names) = self.snowflake_task_context.all_column_names.as_mut() {
            names.extend(columns.iter().map(|(name, _)| name.clone()));
        }
        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
            let alter_add_column_cdc_table_sql = build_alter_add_column_sql(
                cdc_table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                columns,
            );
            self.jdbc_client
                .execute_sql_sync(vec![alter_add_column_cdc_table_sql])
                .await?;
        }

        let alter_add_column_target_table_sql = build_alter_add_column_sql(
            &self.snowflake_task_context.target_table_name,
            &self.snowflake_task_context.database,
            &self.snowflake_task_context.schema_name,
            columns,
        );
        self.jdbc_client
            .execute_sql_sync(vec![alter_add_column_target_table_sql])
            .await?;

        self.execute_create_merge_into_task().await?;
        Ok(())
    }

    pub async fn execute_create_merge_into_task(&self) -> Result<()> {
        if self.snowflake_task_context.task_name.is_some() {
            let create_task_sql = build_create_merge_into_task_sql(&self.snowflake_task_context);
            let start_task_sql = build_start_task_sql(&self.snowflake_task_context);
            self.jdbc_client
                .execute_sql_sync(vec![create_task_sql])
                .await?;
            self.jdbc_client
                .execute_sql_sync(vec![start_task_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_drop_task(&self) -> Result<()> {
        if self.snowflake_task_context.task_name.is_some() {
            let sql = build_drop_task_sql(&self.snowflake_task_context);
            if let Err(e) = self.jdbc_client.execute_sql_sync(vec![sql]).await {
                tracing::error!(
                    "Failed to drop Snowflake sink task {:?}: {:?}",
                    self.snowflake_task_context.task_name,
                    e.as_report()
                );
            } else {
                tracing::info!(
                    "Snowflake sink task {:?} dropped",
                    self.snowflake_task_context.task_name
                );
            }
        }
        Ok(())
    }

    pub async fn execute_create_table(&self) -> Result<()> {
        let create_target_table_sql = build_create_table_sql(
            &self.snowflake_task_context.target_table_name,
            &self.snowflake_task_context.database,
            &self.snowflake_task_context.schema_name,
            &self.snowflake_task_context.schema,
            false,
        )?;
        self.jdbc_client
            .execute_sql_sync(vec![create_target_table_sql])
            .await?;
        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
            let create_cdc_table_sql = build_create_table_sql(
                cdc_table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                &self.snowflake_task_context.schema,
                true,
            )?;
            self.jdbc_client
                .execute_sql_sync(vec![create_cdc_table_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_create_pipe(&self) -> Result<()> {
        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
            let table_name =
                if let Some(table_name) = self.snowflake_task_context.cdc_table_name.as_ref() {
                    table_name
                } else {
                    &self.snowflake_task_context.target_table_name
                };
            let create_pipe_sql = build_create_pipe_sql(
                table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                self.snowflake_task_context.stage.as_ref().ok_or_else(|| {
                    SinkError::Config(anyhow!("stage is required for the S3 writer"))
                })?,
                pipe_name,
                &self.snowflake_task_context.target_table_name,
            );
            self.jdbc_client
                .execute_sql_sync(vec![create_pipe_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_flush_pipe(&self) -> Result<()> {
        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
            let flush_pipe_sql = build_flush_pipe_sql(
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                pipe_name,
            );
            self.jdbc_client
                .execute_sql_sync(vec![flush_pipe_sql])
                .await?;
        }
        Ok(())
    }
}

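/// Renders the `CREATE TABLE IF NOT EXISTS` statement with `ENABLE_SCHEMA_EVOLUTION = true`;
/// the intermediate (CDC) table additionally gets the `__row_id` and `__op` columns.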
fn build_create_table_sql(
    table_name: &str,
    database: &str,
    schema_name: &str,
    schema: &Schema,
    need_op_and_row_id: bool,
) -> Result<String> {
    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema_name, table_name);
    let mut columns: Vec<String> = schema
        .fields
        .iter()
        .map(|field| {
            let data_type = convert_snowflake_data_type(&field.data_type)?;
            Ok(format!(r#""{}" {}"#, field.name, data_type))
        })
        .collect::<Result<Vec<String>>>()?;
    if need_op_and_row_id {
        columns.push(format!(r#""{}" STRING"#, SNOWFLAKE_SINK_ROW_ID));
        columns.push(format!(r#""{}" INT"#, SNOWFLAKE_SINK_OP));
    }
    let columns_str = columns.join(", ");
    Ok(format!(
        "CREATE TABLE IF NOT EXISTS {} ({}) ENABLE_SCHEMA_EVOLUTION = true",
        full_table_name, columns_str
    ))
}

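/// Maps a RisingWave [`DataType`] to the Snowflake column type used when auto-creating
/// tables; types without a mapping (e.g. struct or list) yield a configuration error.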
fn convert_snowflake_data_type(data_type: &DataType) -> Result<String> {
    let data_type = match data_type {
        DataType::Int16 => "SMALLINT".to_owned(),
        DataType::Int32 => "INTEGER".to_owned(),
        DataType::Int64 => "BIGINT".to_owned(),
        DataType::Float32 => "FLOAT4".to_owned(),
        DataType::Float64 => "FLOAT8".to_owned(),
        DataType::Boolean => "BOOLEAN".to_owned(),
        DataType::Varchar => "STRING".to_owned(),
        DataType::Date => "DATE".to_owned(),
        DataType::Timestamp => "TIMESTAMP".to_owned(),
        DataType::Timestamptz => "TIMESTAMP_TZ".to_owned(),
        DataType::Jsonb => "STRING".to_owned(),
        DataType::Decimal => "DECIMAL".to_owned(),
        DataType::Bytea => "BINARY".to_owned(),
        DataType::Time => "TIME".to_owned(),
        _ => {
            return Err(SinkError::Config(anyhow!(
                "auto table creation does not support data type: {}",
                data_type
            )));
        }
    };
    Ok(data_type)
}

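/// Renders the `CREATE OR REPLACE PIPE` statement that copies JSON files from the stage path
/// `@"<db>"."<schema>"."<stage>"/<target_table>` into the given table, matching columns by
/// name case-insensitively.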
fn build_create_pipe_sql(
    table_name: &str,
    database: &str,
    schema: &str,
    stage: &str,
    pipe_name: &str,
    target_table_name: &str,
) -> String {
    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    let stage = format!(
        r#""{}"."{}"."{}"/{}"#,
        database, schema, stage, target_table_name
    );
    let table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
    format!(
        "CREATE OR REPLACE PIPE {} AUTO_INGEST = FALSE AS COPY INTO {} FROM @{} MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE FILE_FORMAT = (type = 'JSON');",
        pipe_name, table_name, stage
    )
}

fn build_flush_pipe_sql(database: &str, schema: &str, pipe_name: &str) -> String {
    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    format!("ALTER PIPE {} REFRESH;", pipe_name)
}

fn build_alter_add_column_sql(
    table_name: &str,
    database: &str,
    schema: &str,
    columns: &Vec<(String, String)>,
) -> String {
    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
    jdbc_jni_client::build_alter_add_column_sql(&full_table_name, columns, true)
}

fn build_start_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        database,
        schema_name: schema,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema,
        task_name.as_ref().unwrap()
    );
    format!("ALTER TASK {} RESUME", full_task_name)
}

fn build_drop_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        database,
        schema_name: schema,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema,
        task_name.as_ref().unwrap()
    );
    format!("DROP TASK IF EXISTS {}", full_task_name)
}

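/// Renders the scheduled Snowflake task that merges the intermediate (CDC) table into the
/// target table: it snapshots the current maximum `__row_id`, keeps only the latest row per
/// primary key, applies deletes for ops 2/4 and upserts for ops 1/3, and finally purges the
/// merged rows from the intermediate table.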
fn build_create_merge_into_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        cdc_table_name,
        target_table_name,
        schedule_seconds,
        warehouse,
        pk_column_names,
        all_column_names,
        database,
        schema_name,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema_name,
        task_name.as_ref().unwrap()
    );
    let full_cdc_table_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema_name,
        cdc_table_name.as_ref().unwrap()
    );
    let full_target_table_name = format!(
        r#""{}"."{}"."{}""#,
        database, schema_name, target_table_name
    );

    let pk_names_str = pk_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#""{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");
    let pk_names_eq_str = pk_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
        .collect::<Vec<String>>()
        .join(" AND ");
    let all_column_names_set_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
        .collect::<Vec<String>>()
        .join(", ");
    let all_column_names_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#""{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");
    let all_column_names_insert_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"source."{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");

    format!(
        r#"CREATE OR REPLACE TASK {task_name}
WAREHOUSE = {warehouse}
SCHEDULE = '{schedule_seconds} SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("{snowflake_sink_row_id}"), '0') INTO :max_row_id
    FROM {cdc_table_name};

    MERGE INTO {target_table_name} AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY {pk_names_str} ORDER BY "{snowflake_sink_row_id}" DESC) AS dedupe_id
            FROM {cdc_table_name}
            WHERE "{snowflake_sink_row_id}" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON {pk_names_eq_str}
    WHEN MATCHED AND source."{snowflake_sink_op}" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN UPDATE SET {all_column_names_set_str}
    WHEN NOT MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN INSERT ({all_column_names_str}) VALUES ({all_column_names_insert_str});

    DELETE FROM {cdc_table_name}
    WHERE "{snowflake_sink_row_id}" <= :max_row_id;
END;"#,
        task_name = full_task_name,
        warehouse = warehouse.as_ref().unwrap(),
        schedule_seconds = schedule_seconds,
        cdc_table_name = full_cdc_table_name,
        target_table_name = full_target_table_name,
        pk_names_str = pk_names_str,
        pk_names_eq_str = pk_names_eq_str,
        all_column_names_set_str = all_column_names_set_str,
        all_column_names_str = all_column_names_str,
        all_column_names_insert_str = all_column_names_insert_str,
        snowflake_sink_row_id = SNOWFLAKE_SINK_ROW_ID,
        snowflake_sink_op = SNOWFLAKE_SINK_OP,
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::jdbc_jni_client::normalize_sql;

    #[test]
    fn test_snowflake_sink_commit_coordinator() {
        let snowflake_task_context = SnowflakeTaskContext {
            task_name: Some("test_task".to_owned()),
            cdc_table_name: Some("test_cdc_table".to_owned()),
            target_table_name: "test_target_table".to_owned(),
            schedule_seconds: 3600,
            warehouse: Some("test_warehouse".to_owned()),
            pk_column_names: Some(vec!["v1".to_owned()]),
            all_column_names: Some(vec!["v1".to_owned(), "v2".to_owned()]),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            schema: Schema { fields: vec![] },
            stage: None,
            pipe_name: None,
        };
        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task"
WAREHOUSE = test_warehouse
SCHEDULE = '3600 SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
    FROM "test_db"."test_schema"."test_cdc_table";

    MERGE INTO "test_db"."test_schema"."test_target_table" AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY "v1" ORDER BY "__row_id" DESC) AS dedupe_id
            FROM "test_db"."test_schema"."test_cdc_table"
            WHERE "__row_id" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON target."v1" = source."v1"
    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."v1" = source."v1", target."v2" = source."v2"
    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("v1", "v2") VALUES (source."v1", source."v2");

    DELETE FROM "test_db"."test_schema"."test_cdc_table"
    WHERE "__row_id" <= :max_row_id;
END;"#;
        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
    }

    #[test]
    fn test_snowflake_sink_commit_coordinator_multi_pk() {
        let snowflake_task_context = SnowflakeTaskContext {
            task_name: Some("test_task_multi_pk".to_owned()),
            cdc_table_name: Some("cdc_multi_pk".to_owned()),
            target_table_name: "target_multi_pk".to_owned(),
            schedule_seconds: 300,
            warehouse: Some("multi_pk_warehouse".to_owned()),
            pk_column_names: Some(vec!["id1".to_owned(), "id2".to_owned()]),
            all_column_names: Some(vec!["id1".to_owned(), "id2".to_owned(), "val".to_owned()]),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            schema: Schema { fields: vec![] },
            stage: None,
            pipe_name: None,
        };
        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task_multi_pk"
WAREHOUSE = multi_pk_warehouse
SCHEDULE = '300 SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
    FROM "test_db"."test_schema"."cdc_multi_pk";

    MERGE INTO "test_db"."test_schema"."target_multi_pk" AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY "id1", "id2" ORDER BY "__row_id" DESC) AS dedupe_id
            FROM "test_db"."test_schema"."cdc_multi_pk"
            WHERE "__row_id" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON target."id1" = source."id1" AND target."id2" = source."id2"
    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."id1" = source."id1", target."id2" = source."id2", target."val" = source."val"
    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("id1", "id2", "val") VALUES (source."id1", source."id2", source."val");

    DELETE FROM "test_db"."test_schema"."cdc_multi_pk"
    WHERE "__row_id" <= :max_row_id;
END;"#;
        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
    }
}