use core::num::NonZeroU64;
use std::collections::BTreeMap;

use anyhow::anyhow;
use phf::{Set, phf_set};
use risingwave_common::array::StreamChunk;
use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema};
use risingwave_common::types::DataType;
use risingwave_pb::connector_service::{SinkMetadata, sink_metadata};
use sea_orm::DatabaseConnection;
use serde::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use thiserror_ext::AsReport;
use tokio::sync::mpsc::UnboundedSender;
use tonic::async_trait;
use with_options::WithOptions;

use crate::connector_common::IcebergSinkCompactionUpdate;
use crate::enforce_secret::EnforceSecret;
use crate::sink::coordinate::CoordinatedLogSinker;
use crate::sink::decouple_checkpoint_log_sink::default_commit_checkpoint_interval;
use crate::sink::file_sink::s3::S3Common;
use crate::sink::jdbc_jni_client::{self, JdbcJniClient};
use crate::sink::remote::CoordinatedRemoteSinkWriter;
use crate::sink::snowflake_redshift::{AugmentedChunk, SnowflakeRedshiftSinkS3Writer};
use crate::sink::writer::SinkWriter;
use crate::sink::{
    Result, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, Sink, SinkCommitCoordinator,
    SinkCommittedEpochSubscriber, SinkError, SinkParam, SinkWriterMetrics, SinkWriterParam,
};

pub const SNOWFLAKE_SINK_V2: &str = "snowflake_v2";
pub const SNOWFLAKE_SINK_ROW_ID: &str = "__row_id";
pub const SNOWFLAKE_SINK_OP: &str = "__op";

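/// User-facing configuration for the `snowflake_v2` sink, deserialized from the `WITH` options of
/// `CREATE SINK`. Depending on `with_s3`, data is either staged as files on S3 and loaded through
/// a Snowflake pipe, or written directly over JDBC. For upsert sinks, change rows first land in
/// the intermediate (CDC) table and a scheduled Snowflake task merges them into the target table.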
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct SnowflakeV2Config {
    #[serde(rename = "type")]
    pub r#type: String,

    #[serde(rename = "intermediate.table.name")]
    pub snowflake_cdc_table_name: Option<String>,

    #[serde(rename = "table.name")]
    pub snowflake_target_table_name: Option<String>,

    #[serde(rename = "database")]
    pub snowflake_database: Option<String>,

    #[serde(rename = "schema")]
    pub snowflake_schema: Option<String>,

    #[serde(default = "default_schedule")]
    #[serde(rename = "write.target.interval.seconds")]
    #[serde_as(as = "DisplayFromStr")]
    pub snowflake_schedule_seconds: u64,

    #[serde(rename = "warehouse")]
    pub snowflake_warehouse: Option<String>,

    #[serde(rename = "jdbc.url")]
    pub jdbc_url: Option<String>,

    #[serde(rename = "username")]
    pub username: Option<String>,

    #[serde(rename = "password")]
    pub password: Option<String>,

    #[serde(default = "default_commit_checkpoint_interval")]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub commit_checkpoint_interval: u64,

    #[serde(default)]
    #[serde(rename = "auto.schema.change")]
    #[serde_as(as = "DisplayFromStr")]
    pub auto_schema_change: bool,

    #[serde(default)]
    #[serde(rename = "create_table_if_not_exists")]
    #[serde_as(as = "DisplayFromStr")]
    pub create_table_if_not_exists: bool,

    #[serde(default = "default_with_s3")]
    #[serde(rename = "with_s3")]
    #[serde_as(as = "DisplayFromStr")]
    pub with_s3: bool,

    #[serde(flatten)]
    pub s3_inner: Option<S3Common>,

    #[serde(rename = "stage")]
    pub stage: Option<String>,
}

fn default_schedule() -> u64 {
    3600
}

fn default_with_s3() -> bool {
    true
}

impl SnowflakeV2Config {
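    /// Parse the sink configuration from the raw `WITH` option map and validate that `type`
    /// is either `append-only` or `upsert`.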
    pub fn from_btreemap(properties: &BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<SnowflakeV2Config>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }

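    /// Derive the coordinator-side [`SnowflakeTaskContext`] and a JDBC client from the config.
    ///
    /// Returns `Ok(None)` for plain append-only sinks that neither auto-create tables nor apply
    /// schema changes, since no Snowflake-side objects (tables, pipes, merge tasks) need to be
    /// managed in that case.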
    pub fn build_snowflake_task_ctx_jdbc_client(
        &self,
        is_append_only: bool,
        schema: &Schema,
        pk_indices: &Vec<usize>,
    ) -> Result<Option<(SnowflakeTaskContext, JdbcJniClient)>> {
        if !self.auto_schema_change && is_append_only && !self.create_table_if_not_exists {
            return Ok(None);
        }
        let target_table_name = self
            .snowflake_target_table_name
            .clone()
            .ok_or(SinkError::Config(anyhow!("table.name is required")))?;
        let database = self
            .snowflake_database
            .clone()
            .ok_or(SinkError::Config(anyhow!("database is required")))?;
        let schema_name = self
            .snowflake_schema
            .clone()
            .ok_or(SinkError::Config(anyhow!("schema is required")))?;
        let mut snowflake_task_ctx = SnowflakeTaskContext {
            target_table_name: target_table_name.clone(),
            database,
            schema_name,
            schema: schema.clone(),
            ..Default::default()
        };

        let jdbc_url = self
            .jdbc_url
            .clone()
            .ok_or(SinkError::Config(anyhow!("jdbc.url is required")))?;
        let username = self
            .username
            .clone()
            .ok_or(SinkError::Config(anyhow!("username is required")))?;
        let password = self
            .password
            .clone()
            .ok_or(SinkError::Config(anyhow!("password is required")))?;
        let jdbc_url = format!("{}?user={}&password={}", jdbc_url, username, password);
        let client = JdbcJniClient::new(jdbc_url)?;

        if self.with_s3 {
            let stage = self
                .stage
                .clone()
                .ok_or(SinkError::Config(anyhow!("stage is required")))?;
            snowflake_task_ctx.stage = Some(stage);
            snowflake_task_ctx.pipe_name = Some(format!("{}_pipe", target_table_name));
        }
        if !is_append_only {
            let cdc_table_name = self
                .snowflake_cdc_table_name
                .clone()
                .ok_or(SinkError::Config(anyhow!(
                    "intermediate.table.name is required"
                )))?;
            snowflake_task_ctx.cdc_table_name = Some(cdc_table_name.clone());
            snowflake_task_ctx.schedule_seconds = self.snowflake_schedule_seconds;
            snowflake_task_ctx.warehouse = Some(
                self.snowflake_warehouse
                    .clone()
                    .ok_or(SinkError::Config(anyhow!("warehouse is required")))?,
            );
            let pk_column_names: Vec<_> = schema
                .fields
                .iter()
                .enumerate()
                .filter(|(index, _)| pk_indices.contains(index))
                .map(|(_, field)| field.name.clone())
                .collect();
            if pk_column_names.is_empty() {
                return Err(SinkError::Config(anyhow!(
                    "Primary key columns not found. Please set the `primary_key` column in the sink properties, or ensure that the sink contains the primary key columns from the upstream."
                )));
            }
            snowflake_task_ctx.pk_column_names = Some(pk_column_names);
            snowflake_task_ctx.all_column_names = Some(
                schema
                    .fields
                    .iter()
                    .map(|field| field.name.clone())
                    .collect(),
            );
            snowflake_task_ctx.task_name = Some(format!(
                "rw_snowflake_sink_from_{cdc_table_name}_to_{target_table_name}"
            ));
        }
        Ok(Some((snowflake_task_ctx, client)))
    }
}

impl EnforceSecret for SnowflakeV2Config {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "username",
        "password",
        "jdbc.url",
    };
}

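/// The `snowflake_v2` sink. Delivery is coordinated: writers emit data (to S3 or over JDBC), and a
/// [`SnowflakeSinkCommitter`] drives the Snowflake-side pipe refresh and merge task on commit.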
#[derive(Clone, Debug)]
pub struct SnowflakeV2Sink {
    config: SnowflakeV2Config,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
    param: SinkParam,
}

impl EnforceSecret for SnowflakeV2Sink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::sink::ConnectorResult<()> {
        for prop in prop_iter {
            SnowflakeV2Config::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl TryFrom<SinkParam> for SnowflakeV2Sink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = SnowflakeV2Config::from_btreemap(&param.properties)?;
        let is_append_only = param.sink_type.is_append_only();
        let pk_indices = param.downstream_pk_or_empty();
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
            param,
        })
    }
}

impl Sink for SnowflakeV2Sink {
    type Coordinator = SnowflakeSinkCommitter;
    type LogSinker = CoordinatedLogSinker<SnowflakeSinkWriter>;

    const SINK_NAME: &'static str = SNOWFLAKE_SINK_V2;

    async fn validate(&self) -> Result<()> {
        risingwave_common::license::Feature::SnowflakeSink
            .check_available()
            .map_err(|e| anyhow::anyhow!(e))?;
        if let Some((snowflake_task_ctx, client)) =
            self.config.build_snowflake_task_ctx_jdbc_client(
                self.is_append_only,
                &self.schema,
                &self.pk_indices,
            )?
        {
            let client = SnowflakeJniClient::new(client, snowflake_task_ctx);
            client.execute_create_table().await?;
        }

        Ok(())
    }

    fn support_schema_change() -> bool {
        true
    }

    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
        SnowflakeV2Config::from_btreemap(config)?;
        Ok(())
    }

    async fn new_log_sinker(
        &self,
        writer_param: crate::sink::SinkWriterParam,
    ) -> Result<Self::LogSinker> {
        let writer = SnowflakeSinkWriter::new(
            self.config.clone(),
            self.is_append_only,
            writer_param.clone(),
            self.param.clone(),
        )
        .await?;

        let commit_checkpoint_interval =
            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
            );

        CoordinatedLogSinker::new(
            &writer_param,
            self.param.clone(),
            writer,
            commit_checkpoint_interval,
        )
        .await
    }

    fn is_coordinated_sink(&self) -> bool {
        true
    }

    async fn new_coordinator(
        &self,
        _db: DatabaseConnection,
        _iceberg_compact_stat_sender: Option<UnboundedSender<IcebergSinkCompactionUpdate>>,
    ) -> Result<Self::Coordinator> {
        let coordinator = SnowflakeSinkCommitter::new(
            self.config.clone(),
            &self.schema,
            &self.pk_indices,
            self.is_append_only,
        )?;
        Ok(coordinator)
    }
}

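/// Writer half of the sink. With `with_s3 = true` rows are written as staged files on S3 (later
/// loaded by the Snowflake pipe); otherwise they are written directly through the JDBC path.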
pub enum SnowflakeSinkWriter {
    S3(SnowflakeRedshiftSinkS3Writer),
    Jdbc(SnowflakeSinkJdbcWriter),
}

impl SnowflakeSinkWriter {
    pub async fn new(
        config: SnowflakeV2Config,
        is_append_only: bool,
        writer_param: SinkWriterParam,
        param: SinkParam,
    ) -> Result<Self> {
        let schema = param.schema();
        if config.with_s3 {
            let executor_id = writer_param.executor_id;
            let s3_writer = SnowflakeRedshiftSinkS3Writer::new(
                config.s3_inner.ok_or_else(|| {
                    SinkError::Config(anyhow!(
                        "S3 configuration is required for Snowflake S3 sink"
                    ))
                })?,
                schema,
                is_append_only,
                executor_id,
                config.snowflake_target_table_name,
            )?;
            Ok(Self::S3(s3_writer))
        } else {
            let jdbc_writer =
                SnowflakeSinkJdbcWriter::new(config, is_append_only, writer_param, param).await?;
            Ok(Self::Jdbc(jdbc_writer))
        }
    }
}

#[async_trait]
impl SinkWriter for SnowflakeSinkWriter {
    type CommitMetadata = Option<SinkMetadata>;

    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
        match self {
            Self::S3(writer) => writer.begin_epoch(epoch),
            Self::Jdbc(writer) => writer.begin_epoch(epoch).await,
        }
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        match self {
            Self::S3(writer) => writer.write_batch(chunk).await,
            Self::Jdbc(writer) => writer.write_batch(chunk).await,
        }
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Option<SinkMetadata>> {
        match self {
            Self::S3(writer) => {
                writer.barrier(is_checkpoint).await?;
            }
            Self::Jdbc(writer) => {
                writer.barrier(is_checkpoint).await?;
            }
        }
        Ok(Some(SinkMetadata {
            metadata: Some(sink_metadata::Metadata::Serialized(
                risingwave_pb::connector_service::sink_metadata::SerializedMetadata {
                    metadata: vec![],
                },
            )),
        }))
    }

    async fn abort(&mut self) -> Result<()> {
        if let Self::Jdbc(writer) = self {
            writer.abort().await
        } else {
            Ok(())
        }
    }
}

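/// JDBC-based writer. In append-only mode rows go straight to the target table; in upsert mode
/// each chunk is augmented with the `__row_id` and `__op` columns and appended to the
/// intermediate table, for the scheduled merge task to fold into the target table.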
pub struct SnowflakeSinkJdbcWriter {
    augmented_row: AugmentedChunk,
    jdbc_sink_writer: CoordinatedRemoteSinkWriter,
}

impl SnowflakeSinkJdbcWriter {
    pub async fn new(
        config: SnowflakeV2Config,
        is_append_only: bool,
        writer_param: SinkWriterParam,
        mut param: SinkParam,
    ) -> Result<Self> {
        let metrics = SinkWriterMetrics::new(&writer_param);
        let properties = &param.properties;
        let column_descs = &mut param.columns;
        let full_table_name = if is_append_only {
            format!(
                r#""{}"."{}"."{}""#,
                config.snowflake_database.clone().unwrap_or_default(),
                config.snowflake_schema.clone().unwrap_or_default(),
                config
                    .snowflake_target_table_name
                    .clone()
                    .unwrap_or_default()
            )
        } else {
            let max_column_id = column_descs
                .iter()
                .map(|column| column.column_id.get_id())
                .max()
                .unwrap_or(0);
            (*column_descs).push(ColumnDesc::named(
                SNOWFLAKE_SINK_ROW_ID,
                ColumnId::new(max_column_id + 1),
                DataType::Varchar,
            ));
            (*column_descs).push(ColumnDesc::named(
                SNOWFLAKE_SINK_OP,
                ColumnId::new(max_column_id + 2),
                DataType::Int32,
            ));
            format!(
                r#""{}"."{}"."{}""#,
                config.snowflake_database.clone().unwrap_or_default(),
                config.snowflake_schema.clone().unwrap_or_default(),
                config.snowflake_cdc_table_name.clone().unwrap_or_default()
            )
        };
        let new_properties = BTreeMap::from([
            ("table.name".to_owned(), full_table_name),
            ("connector".to_owned(), "snowflake_v2".to_owned()),
            (
                "jdbc.url".to_owned(),
                config.jdbc_url.clone().unwrap_or_default(),
            ),
            ("type".to_owned(), "append-only".to_owned()),
            (
                "user".to_owned(),
                config.username.clone().unwrap_or_default(),
            ),
            (
                "password".to_owned(),
                config.password.clone().unwrap_or_default(),
            ),
            (
                "primary_key".to_owned(),
                properties.get("primary_key").cloned().unwrap_or_default(),
            ),
            (
                "schema.name".to_owned(),
                config.snowflake_schema.clone().unwrap_or_default(),
            ),
            (
                "database.name".to_owned(),
                config.snowflake_database.clone().unwrap_or_default(),
            ),
        ]);
        param.properties = new_properties;

        let jdbc_sink_writer =
            CoordinatedRemoteSinkWriter::new(param.clone(), metrics.clone()).await?;
        Ok(Self {
            augmented_row: AugmentedChunk::new(0, is_append_only),
            jdbc_sink_writer,
        })
    }
}

impl SnowflakeSinkJdbcWriter {
    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
        self.augmented_row.reset_epoch(epoch);
        self.jdbc_sink_writer.begin_epoch(epoch).await?;
        Ok(())
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        let chunk = self.augmented_row.augmented_chunk(chunk)?;
        self.jdbc_sink_writer.write_batch(chunk).await?;
        Ok(())
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
        self.jdbc_sink_writer.barrier(is_checkpoint).await?;
        Ok(())
    }

    async fn abort(&mut self) -> Result<()> {
        self.jdbc_sink_writer.abort().await?;
        Ok(())
    }
}

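/// Everything the coordinator needs to manage the Snowflake-side objects: the target (and, for
/// upsert, intermediate) table, the optional stage/pipe used by the S3 path, and the merge task
/// (name, warehouse, schedule, primary-key and column lists) that folds CDC rows into the target
/// table.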
#[derive(Default)]
pub struct SnowflakeTaskContext {
    pub target_table_name: String,
    pub database: String,
    pub schema_name: String,
    pub schema: Schema,

    pub task_name: Option<String>,
    pub cdc_table_name: Option<String>,
    pub schedule_seconds: u64,
    pub warehouse: Option<String>,
    pub pk_column_names: Option<Vec<String>>,
    pub all_column_names: Option<Vec<String>>,

    pub stage: Option<String>,
    pub pipe_name: Option<String>,
}
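
/// Commit coordinator for the `snowflake_v2` sink. `client` is `None` when there are no
/// Snowflake-side objects to manage (plain append-only without table auto-creation or schema
/// change handling).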
pub struct SnowflakeSinkCommitter {
    client: Option<SnowflakeJniClient>,
}

impl SnowflakeSinkCommitter {
    pub fn new(
        config: SnowflakeV2Config,
        schema: &Schema,
        pk_indices: &Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let client = if let Some((snowflake_task_ctx, client)) =
            config.build_snowflake_task_ctx_jdbc_client(is_append_only, schema, pk_indices)?
        {
            Some(SnowflakeJniClient::new(client, snowflake_task_ctx))
        } else {
            None
        };
        Ok(Self { client })
    }
}

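/// On `init` the committer (re)creates the Snowflake pipe and the scheduled `MERGE INTO` task; on
/// every `commit` it refreshes the pipe so newly staged files are ingested, and applies any
/// `ADD COLUMN` schema changes to both the intermediate and the target table.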
#[async_trait]
impl SinkCommitCoordinator for SnowflakeSinkCommitter {
    async fn init(&mut self, _subscriber: SinkCommittedEpochSubscriber) -> Result<Option<u64>> {
        if let Some(client) = &self.client {
            client.execute_create_pipe().await?;
            client.execute_create_merge_into_task().await?;
        }
        Ok(None)
    }

    async fn commit(
        &mut self,
        _epoch: u64,
        _metadata: Vec<SinkMetadata>,
        add_columns: Option<Vec<Field>>,
    ) -> Result<()> {
        let client = self.client.as_mut().ok_or_else(|| {
            SinkError::Config(anyhow!("Snowflake sink committer is not initialized."))
        })?;
        client.execute_flush_pipe().await?;

        if let Some(add_columns) = add_columns {
            client
                .execute_alter_add_columns(
                    &add_columns
                        .iter()
                        .map(|f| (f.name.clone(), f.data_type.to_string()))
                        .collect::<Vec<_>>(),
                )
                .await?;
        }
        Ok(())
    }
}

impl Drop for SnowflakeSinkCommitter {
    fn drop(&mut self) {
        if let Some(client) = self.client.take() {
            tokio::spawn(async move {
                client.execute_drop_task().await.ok();
            });
        }
    }
}

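/// Thin wrapper around [`JdbcJniClient`] that renders and executes the Snowflake DDL/DML needed
/// by this sink: table and pipe creation, pipe refresh, the merge task lifecycle, and
/// `ALTER TABLE ... ADD COLUMN`.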
pub struct SnowflakeJniClient {
    jdbc_client: JdbcJniClient,
    snowflake_task_context: SnowflakeTaskContext,
}

impl SnowflakeJniClient {
    pub fn new(jdbc_client: JdbcJniClient, snowflake_task_context: SnowflakeTaskContext) -> Self {
        Self {
            jdbc_client,
            snowflake_task_context,
        }
    }

    pub async fn execute_alter_add_columns(
        &mut self,
        columns: &Vec<(String, String)>,
    ) -> Result<()> {
        self.execute_drop_task().await?;
        if let Some(names) = self.snowflake_task_context.all_column_names.as_mut() {
            names.extend(columns.iter().map(|(name, _)| name.clone()));
        }
        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
            let alter_add_column_cdc_table_sql = build_alter_add_column_sql(
                cdc_table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                columns,
            );
            self.jdbc_client
                .execute_sql_sync(vec![alter_add_column_cdc_table_sql])
                .await?;
        }

        let alter_add_column_target_table_sql = build_alter_add_column_sql(
            &self.snowflake_task_context.target_table_name,
            &self.snowflake_task_context.database,
            &self.snowflake_task_context.schema_name,
            columns,
        );
        self.jdbc_client
            .execute_sql_sync(vec![alter_add_column_target_table_sql])
            .await?;

        self.execute_create_merge_into_task().await?;
        Ok(())
    }

    pub async fn execute_create_merge_into_task(&self) -> Result<()> {
        if self.snowflake_task_context.task_name.is_some() {
            let create_task_sql = build_create_merge_into_task_sql(&self.snowflake_task_context);
            let start_task_sql = build_start_task_sql(&self.snowflake_task_context);
            self.jdbc_client
                .execute_sql_sync(vec![create_task_sql])
                .await?;
            self.jdbc_client
                .execute_sql_sync(vec![start_task_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_drop_task(&self) -> Result<()> {
        if self.snowflake_task_context.task_name.is_some() {
            let sql = build_drop_task_sql(&self.snowflake_task_context);
            if let Err(e) = self.jdbc_client.execute_sql_sync(vec![sql]).await {
                tracing::error!(
                    "Failed to drop Snowflake sink task {:?}: {:?}",
                    self.snowflake_task_context.task_name,
                    e.as_report()
                );
            } else {
                tracing::info!(
                    "Snowflake sink task {:?} dropped",
                    self.snowflake_task_context.task_name
                );
            }
        }
        Ok(())
    }

    pub async fn execute_create_table(&self) -> Result<()> {
        let create_target_table_sql = build_create_table_sql(
            &self.snowflake_task_context.target_table_name,
            &self.snowflake_task_context.database,
            &self.snowflake_task_context.schema_name,
            &self.snowflake_task_context.schema,
            false,
        )?;
        self.jdbc_client
            .execute_sql_sync(vec![create_target_table_sql])
            .await?;
        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
            let create_cdc_table_sql = build_create_table_sql(
                cdc_table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                &self.snowflake_task_context.schema,
                true,
            )?;
            self.jdbc_client
                .execute_sql_sync(vec![create_cdc_table_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_create_pipe(&self) -> Result<()> {
        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
            let table_name =
                if let Some(table_name) = self.snowflake_task_context.cdc_table_name.as_ref() {
                    table_name
                } else {
                    &self.snowflake_task_context.target_table_name
                };
            let create_pipe_sql = build_create_pipe_sql(
                table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                self.snowflake_task_context.stage.as_ref().ok_or_else(|| {
                    SinkError::Config(anyhow!("stage is required for the S3 writer"))
                })?,
                pipe_name,
                &self.snowflake_task_context.target_table_name,
            );
            self.jdbc_client
                .execute_sql_sync(vec![create_pipe_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_flush_pipe(&self) -> Result<()> {
        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
            let flush_pipe_sql = build_flush_pipe_sql(
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                pipe_name,
            );
            self.jdbc_client
                .execute_sql_sync(vec![flush_pipe_sql])
                .await?;
        }
        Ok(())
    }
}

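/// Render the `CREATE TABLE IF NOT EXISTS` statement for the target or intermediate table.
/// `need_op_and_row_id` adds the `__row_id`/`__op` bookkeeping columns used by the merge task.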
fn build_create_table_sql(
    table_name: &str,
    database: &str,
    schema_name: &str,
    schema: &Schema,
    need_op_and_row_id: bool,
) -> Result<String> {
    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema_name, table_name);
    let mut columns: Vec<String> = schema
        .fields
        .iter()
        .map(|field| {
            let data_type = convert_snowflake_data_type(&field.data_type)?;
            Ok(format!(r#""{}" {}"#, field.name, data_type))
        })
        .collect::<Result<Vec<String>>>()?;
    if need_op_and_row_id {
        columns.push(format!(r#""{}" STRING"#, SNOWFLAKE_SINK_ROW_ID));
        columns.push(format!(r#""{}" INT"#, SNOWFLAKE_SINK_OP));
    }
    let columns_str = columns.join(", ");
    Ok(format!(
        "CREATE TABLE IF NOT EXISTS {} ({}) ENABLE_SCHEMA_EVOLUTION = true",
        full_table_name, columns_str
    ))
}

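/// Map a RisingWave [`DataType`] to the Snowflake column type used when auto-creating tables.
/// Types without a mapping are rejected with a config error.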
fn convert_snowflake_data_type(data_type: &DataType) -> Result<String> {
    let data_type = match data_type {
        DataType::Int16 => "SMALLINT".to_owned(),
        DataType::Int32 => "INTEGER".to_owned(),
        DataType::Int64 => "BIGINT".to_owned(),
        DataType::Float32 => "FLOAT4".to_owned(),
        DataType::Float64 => "FLOAT8".to_owned(),
        DataType::Boolean => "BOOLEAN".to_owned(),
        DataType::Varchar => "STRING".to_owned(),
        DataType::Date => "DATE".to_owned(),
        DataType::Timestamp => "TIMESTAMP".to_owned(),
        DataType::Timestamptz => "TIMESTAMP_TZ".to_owned(),
        DataType::Jsonb => "STRING".to_owned(),
        DataType::Decimal => "DECIMAL".to_owned(),
        DataType::Bytea => "BINARY".to_owned(),
        DataType::Time => "TIME".to_owned(),
        _ => {
            return Err(SinkError::Config(anyhow!(
                "Auto table creation does not support data type: {}",
                data_type
            )));
        }
    };
    Ok(data_type)
}

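/// Render the `CREATE OR REPLACE PIPE` statement that copies JSON files staged under
/// `@"<db>"."<schema>"."<stage>"/<target_table>` into the intermediate (or target) table.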
fn build_create_pipe_sql(
    table_name: &str,
    database: &str,
    schema: &str,
    stage: &str,
    pipe_name: &str,
    target_table_name: &str,
) -> String {
    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    let stage = format!(
        r#""{}"."{}"."{}"/{}"#,
        database, schema, stage, target_table_name
    );
    let table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
    format!(
        "CREATE OR REPLACE PIPE {} AUTO_INGEST = FALSE AS COPY INTO {} FROM @{} MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE FILE_FORMAT = (type = 'JSON');",
        pipe_name, table_name, stage
    )
}

fn build_flush_pipe_sql(database: &str, schema: &str, pipe_name: &str) -> String {
    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    format!("ALTER PIPE {} REFRESH;", pipe_name)
}

fn build_alter_add_column_sql(
    table_name: &str,
    database: &str,
    schema: &str,
    columns: &Vec<(String, String)>,
) -> String {
    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
    jdbc_jni_client::build_alter_add_column_sql(&full_table_name, columns, true)
}

fn build_start_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        database,
        schema_name: schema,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema,
        task_name.as_ref().unwrap()
    );
    format!("ALTER TASK {} RESUME", full_task_name)
}

fn build_drop_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        database,
        schema_name: schema,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema,
        task_name.as_ref().unwrap()
    );
    format!("DROP TASK IF EXISTS {}", full_task_name)
}

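/// Render the scheduled task that folds rows from the intermediate table into the target table.
/// Rows are deduplicated per primary key by the highest `__row_id`, then applied according to
/// `__op` (values 1 and 3 are upserted, 2 and 4 are deleted), and finally the consumed rows are
/// removed from the intermediate table.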
fn build_create_merge_into_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        cdc_table_name,
        target_table_name,
        schedule_seconds,
        warehouse,
        pk_column_names,
        all_column_names,
        database,
        schema_name,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema_name,
        task_name.as_ref().unwrap()
    );
    let full_cdc_table_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema_name,
        cdc_table_name.as_ref().unwrap()
    );
    let full_target_table_name = format!(
        r#""{}"."{}"."{}""#,
        database, schema_name, target_table_name
    );

    let pk_names_str = pk_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#""{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");
    let pk_names_eq_str = pk_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
        .collect::<Vec<String>>()
        .join(" AND ");
    let all_column_names_set_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
        .collect::<Vec<String>>()
        .join(", ");
    let all_column_names_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#""{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");
    let all_column_names_insert_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"source."{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");

    format!(
        r#"CREATE OR REPLACE TASK {task_name}
WAREHOUSE = {warehouse}
SCHEDULE = '{schedule_seconds} SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("{snowflake_sink_row_id}"), '0') INTO :max_row_id
    FROM {cdc_table_name};

    MERGE INTO {target_table_name} AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY {pk_names_str} ORDER BY "{snowflake_sink_row_id}" DESC) AS dedupe_id
            FROM {cdc_table_name}
            WHERE "{snowflake_sink_row_id}" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON {pk_names_eq_str}
    WHEN MATCHED AND source."{snowflake_sink_op}" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN UPDATE SET {all_column_names_set_str}
    WHEN NOT MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN INSERT ({all_column_names_str}) VALUES ({all_column_names_insert_str});

    DELETE FROM {cdc_table_name}
    WHERE "{snowflake_sink_row_id}" <= :max_row_id;
END;"#,
        task_name = full_task_name,
        warehouse = warehouse.as_ref().unwrap(),
        schedule_seconds = schedule_seconds,
        cdc_table_name = full_cdc_table_name,
        target_table_name = full_target_table_name,
        pk_names_str = pk_names_str,
        pk_names_eq_str = pk_names_eq_str,
        all_column_names_set_str = all_column_names_set_str,
        all_column_names_str = all_column_names_str,
        all_column_names_insert_str = all_column_names_insert_str,
        snowflake_sink_row_id = SNOWFLAKE_SINK_ROW_ID,
        snowflake_sink_op = SNOWFLAKE_SINK_OP,
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::jdbc_jni_client::normalize_sql;

    #[test]
    fn test_snowflake_sink_commit_coordinator() {
        let snowflake_task_context = SnowflakeTaskContext {
            task_name: Some("test_task".to_owned()),
            cdc_table_name: Some("test_cdc_table".to_owned()),
            target_table_name: "test_target_table".to_owned(),
            schedule_seconds: 3600,
            warehouse: Some("test_warehouse".to_owned()),
            pk_column_names: Some(vec!["v1".to_owned()]),
            all_column_names: Some(vec!["v1".to_owned(), "v2".to_owned()]),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            schema: Schema { fields: vec![] },
            stage: None,
            pipe_name: None,
        };
        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task"
WAREHOUSE = test_warehouse
SCHEDULE = '3600 SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
    FROM "test_db"."test_schema"."test_cdc_table";

    MERGE INTO "test_db"."test_schema"."test_target_table" AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY "v1" ORDER BY "__row_id" DESC) AS dedupe_id
            FROM "test_db"."test_schema"."test_cdc_table"
            WHERE "__row_id" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON target."v1" = source."v1"
    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."v1" = source."v1", target."v2" = source."v2"
    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("v1", "v2") VALUES (source."v1", source."v2");

    DELETE FROM "test_db"."test_schema"."test_cdc_table"
    WHERE "__row_id" <= :max_row_id;
END;"#;
        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
    }

    #[test]
    fn test_snowflake_sink_commit_coordinator_multi_pk() {
        let snowflake_task_context = SnowflakeTaskContext {
            task_name: Some("test_task_multi_pk".to_owned()),
            cdc_table_name: Some("cdc_multi_pk".to_owned()),
            target_table_name: "target_multi_pk".to_owned(),
            schedule_seconds: 300,
            warehouse: Some("multi_pk_warehouse".to_owned()),
            pk_column_names: Some(vec!["id1".to_owned(), "id2".to_owned()]),
            all_column_names: Some(vec!["id1".to_owned(), "id2".to_owned(), "val".to_owned()]),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            schema: Schema { fields: vec![] },
            stage: None,
            pipe_name: None,
        };
        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task_multi_pk"
WAREHOUSE = multi_pk_warehouse
SCHEDULE = '300 SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
    FROM "test_db"."test_schema"."cdc_multi_pk";

    MERGE INTO "test_db"."test_schema"."target_multi_pk" AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY "id1", "id2" ORDER BY "__row_id" DESC) AS dedupe_id
            FROM "test_db"."test_schema"."cdc_multi_pk"
            WHERE "__row_id" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON target."id1" = source."id1" AND target."id2" = source."id2"
    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."id1" = source."id1", target."id2" = source."id2", target."val" = source."val"
    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("id1", "id2", "val") VALUES (source."id1", source."id2", source."val");

    DELETE FROM "test_db"."test_schema"."cdc_multi_pk"
    WHERE "__row_id" <= :max_row_id;
END;"#;
        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
    }
}