use core::num::NonZeroU64;
use std::collections::BTreeMap;
use std::time::Duration;

use anyhow::anyhow;
use itertools::Itertools;
use phf::{Set, phf_set};
use risingwave_common::array::StreamChunk;
use risingwave_common::catalog::Schema;
use risingwave_common::types::DataType;
use risingwave_pb::connector_service::{SinkMetadata, sink_metadata};
use risingwave_pb::stream_plan::PbSinkSchemaChange;
use serde::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use thiserror_ext::AsReport;
use tokio::sync::mpsc::{UnboundedSender, unbounded_channel};
use tokio::time::{MissedTickBehavior, interval};
use tonic::async_trait;
use with_options::WithOptions;

use crate::connector_common::IcebergSinkCompactionUpdate;
use crate::enforce_secret::EnforceSecret;
use crate::sink::catalog::SinkId;
use crate::sink::coordinate::CoordinatedLogSinker;
use crate::sink::decouple_checkpoint_log_sink::default_commit_checkpoint_interval;
use crate::sink::file_sink::s3::S3Common;
use crate::sink::jdbc_jni_client::{self, JdbcJniClient};
use crate::sink::snowflake_redshift::{
    __OP, __ROW_ID, SnowflakeRedshiftSinkJdbcWriter, SnowflakeRedshiftSinkS3Writer,
};
use crate::sink::writer::SinkWriter;
use crate::sink::{
    Result, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT,
    SinglePhaseCommitCoordinator, Sink, SinkCommitCoordinator, SinkError, SinkParam,
    SinkWriterParam,
};

pub const SNOWFLAKE_SINK_V2: &str = "snowflake_v2";

const AUTH_METHOD_PASSWORD: &str = "password";
const AUTH_METHOD_KEY_PAIR_FILE: &str = "key_pair_file";
const AUTH_METHOD_KEY_PAIR_OBJECT: &str = "key_pair_object";
const PROP_AUTH_METHOD: &str = "auth.method";

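/// Builds the fully qualified, double-quoted Snowflake object name, e.g.
/// `build_full_table_name("MY_DB", "PUBLIC", "T")` yields `"MY_DB"."PUBLIC"."T"`.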
pub fn build_full_table_name(database: &str, schema_name: &str, table_name: &str) -> String {
    format!(r#""{}"."{}"."{}""#, database, schema_name, table_name)
}

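/// User-facing configuration for the `snowflake_v2` sink, deserialized from the
/// `WITH (...)` properties of `CREATE SINK`.
///
/// The connection section (`jdbc.url`, `username`, plus exactly one of
/// `password`, `private_key_file`, or `private_key_pem`) drives the JDBC client,
/// while `database` / `schema` / `table.name` identify the target table. When
/// `with_s3` is enabled (the default), rows are staged as files and loaded through
/// a Snowflake stage and pipe; for upsert sinks an intermediate CDC table plus a
/// scheduled MERGE task keeps the target table in sync. An illustrative, not
/// exhaustive, upsert configuration might look like the sketch below (values are
/// placeholders; the flattened S3 options from [`S3Common`] are omitted):
///
/// ```sql
/// CREATE SINK my_sink FROM mv WITH (
///     connector = 'snowflake_v2',
///     type = 'upsert',
///     jdbc.url = 'jdbc:snowflake://<account>.snowflakecomputing.com',
///     username = 'RW_USER',
///     password = '...',
///     database = 'MY_DB',
///     schema = 'PUBLIC',
///     table.name = 'TARGET_TABLE',
///     intermediate.table.name = 'TARGET_TABLE_CDC',
///     warehouse = 'MY_WH',
///     stage = 'MY_STAGE'
/// );
/// ```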
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct SnowflakeV2Config {
    #[serde(rename = "type")]
    pub r#type: String,

    #[serde(rename = "intermediate.table.name")]
    pub snowflake_cdc_table_name: Option<String>,

    #[serde(rename = "table.name")]
    pub snowflake_target_table_name: Option<String>,

    #[serde(rename = "database")]
    pub snowflake_database: Option<String>,

    #[serde(rename = "schema")]
    pub snowflake_schema: Option<String>,

    #[serde(default = "default_target_interval_schedule")]
    #[serde(rename = "write.target.interval.seconds")]
    #[serde_as(as = "DisplayFromStr")]
    pub writer_target_interval_seconds: u64,

    #[serde(default = "default_intermediate_interval_schedule")]
    #[serde(rename = "write.intermediate.interval.seconds")]
    #[serde_as(as = "DisplayFromStr")]
    pub write_intermediate_interval_seconds: u64,

    #[serde(rename = "warehouse")]
    pub snowflake_warehouse: Option<String>,

    #[serde(rename = "jdbc.url")]
    pub jdbc_url: Option<String>,

    #[serde(rename = "username")]
    pub username: Option<String>,

    #[serde(rename = "password")]
    pub password: Option<String>,

    #[serde(rename = "auth.method")]
    pub auth_method: Option<String>,

    #[serde(rename = "private_key_file")]
    pub private_key_file: Option<String>,

    #[serde(rename = "private_key_file_pwd")]
    pub private_key_file_pwd: Option<String>,

    #[serde(rename = "private_key_pem")]
    pub private_key_pem: Option<String>,

    #[serde(default = "default_commit_checkpoint_interval")]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub commit_checkpoint_interval: u64,

    #[serde(default)]
    #[serde(rename = "auto.schema.change")]
    #[serde_as(as = "DisplayFromStr")]
    pub auto_schema_change: bool,

    #[serde(default)]
    #[serde(rename = "create_table_if_not_exists")]
    #[serde_as(as = "DisplayFromStr")]
    pub create_table_if_not_exists: bool,

    #[serde(default = "default_with_s3")]
    #[serde(rename = "with_s3")]
    #[serde_as(as = "DisplayFromStr")]
    pub with_s3: bool,

    #[serde(flatten)]
    pub s3_inner: Option<S3Common>,

    #[serde(rename = "stage")]
    pub stage: Option<String>,
}

fn default_target_interval_schedule() -> u64 {
    3600
}

fn default_intermediate_interval_schedule() -> u64 {
    1800
}

fn default_with_s3() -> bool {
    true
}

impl SnowflakeV2Config {
    pub fn build_jdbc_connection_properties(&self) -> Result<(String, Vec<(String, String)>)> {
        let jdbc_url = self
            .jdbc_url
            .clone()
            .ok_or(SinkError::Config(anyhow!("jdbc.url is required")))?;
        let username = self
            .username
            .clone()
            .ok_or(SinkError::Config(anyhow!("username is required")))?;

        let mut connection_properties: Vec<(String, String)> = vec![("user".to_owned(), username)];

        match self.auth_method.as_deref().unwrap() {
            AUTH_METHOD_PASSWORD => {
                connection_properties.push(("password".to_owned(), self.password.clone().unwrap()));
            }
            AUTH_METHOD_KEY_PAIR_FILE => {
                connection_properties.push((
                    "private_key_file".to_owned(),
                    self.private_key_file.clone().unwrap(),
                ));
                if let Some(pwd) = self.private_key_file_pwd.clone() {
                    connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
                }
            }
            AUTH_METHOD_KEY_PAIR_OBJECT => {
                connection_properties.push((
                    PROP_AUTH_METHOD.to_owned(),
                    AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
                ));
                connection_properties.push((
                    "private_key_pem".to_owned(),
                    self.private_key_pem.clone().unwrap(),
                ));
                if let Some(pwd) = self.private_key_file_pwd.clone() {
                    connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
                }
            }
            _ => {
                unreachable!(
                    "Invalid auth_method - should have been caught during config validation"
                )
            }
        }

        Ok((jdbc_url, connection_properties))
    }

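    /// Parses and validates the sink properties.
    ///
    /// Besides deserializing the map, this normalizes `auth.method`: when it is
    /// given explicitly it must match the credentials that are present, and when
    /// it is omitted it is inferred from whichever single credential
    /// (`password`, `private_key_file`, or `private_key_pem`) is set.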
    pub fn from_btreemap(properties: &BTreeMap<String, String>) -> Result<Self> {
        let mut config =
            serde_json::from_value::<SnowflakeV2Config>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }

        let has_password = config.password.is_some();
        let has_file = config.private_key_file.is_some();
        let has_pem = config.private_key_pem.is_some();

        let normalized_auth_method = match config
            .auth_method
            .as_deref()
            .map(|s| s.trim().to_ascii_lowercase())
        {
            Some(method) if method == AUTH_METHOD_PASSWORD => {
                if !has_password {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=password requires `password`"
                    )));
                }
                if has_file || has_pem {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=password must not set `private_key_file`/`private_key_pem`"
                    )));
                }
                AUTH_METHOD_PASSWORD.to_owned()
            }
            Some(method) if method == AUTH_METHOD_KEY_PAIR_FILE => {
                if !has_file {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_file requires `private_key_file`"
                    )));
                }
                if has_password {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_file must not set `password`"
                    )));
                }
                if has_pem {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_file must not set `private_key_pem`"
                    )));
                }
                AUTH_METHOD_KEY_PAIR_FILE.to_owned()
            }
            Some(method) if method == AUTH_METHOD_KEY_PAIR_OBJECT => {
                if !has_pem {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_object requires `private_key_pem`"
                    )));
                }
                if has_password {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_object must not set `password`"
                    )));
                }
                AUTH_METHOD_KEY_PAIR_OBJECT.to_owned()
            }
            Some(other) => {
                return Err(SinkError::Config(anyhow!(
                    "invalid auth.method: {} (allowed: password | key_pair_file | key_pair_object)",
                    other
                )));
            }
            None => {
                match (has_password, has_file, has_pem) {
                    (true, false, false) => AUTH_METHOD_PASSWORD.to_owned(),
                    (false, true, false) => AUTH_METHOD_KEY_PAIR_FILE.to_owned(),
                    (false, false, true) => AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
                    (true, true, _) | (true, _, true) | (false, true, true) => {
                        return Err(SinkError::Config(anyhow!(
                            "ambiguous auth: multiple auth options provided; remove one or set `auth.method`"
                        )));
                    }
                    _ => {
                        return Err(SinkError::Config(anyhow!(
                            "no authentication configured: set either `password`, `private_key_file`, or `private_key_pem` (or provide `auth.method`)"
                        )));
                    }
                }
            }
        };
        config.auth_method = Some(normalized_auth_method);
        Ok(config)
    }

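    /// Builds the [`SnowflakeTaskContext`] and a JDBC client used by `validate()`
    /// and the commit coordinator.
    ///
    /// Returns `Ok(None)` when nothing server-side needs to be managed, i.e. the
    /// sink is append-only and neither auto schema change, table creation, nor the
    /// S3/pipe path is enabled. For upsert sinks the context additionally carries
    /// the intermediate (CDC) table, warehouse, primary-key columns, and the name
    /// of the scheduled MERGE task.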
    pub fn build_snowflake_task_ctx_jdbc_client(
        &self,
        is_append_only: bool,
        schema: &Schema,
        pk_indices: &Vec<usize>,
    ) -> Result<Option<(SnowflakeTaskContext, JdbcJniClient)>> {
        if !self.auto_schema_change
            && is_append_only
            && !self.create_table_if_not_exists
            && !self.with_s3
        {
            return Ok(None);
        }
        let target_table_name = self
            .snowflake_target_table_name
            .clone()
            .ok_or(SinkError::Config(anyhow!("table.name is required")))?;
        let database = self
            .snowflake_database
            .clone()
            .ok_or(SinkError::Config(anyhow!("database is required")))?;
        let schema_name = self
            .snowflake_schema
            .clone()
            .ok_or(SinkError::Config(anyhow!("schema is required")))?;
        let mut snowflake_task_ctx = SnowflakeTaskContext {
            target_table_name: target_table_name.clone(),
            database,
            schema_name,
            schema: schema.clone(),
            ..Default::default()
        };

        let (jdbc_url, connection_properties) = self.build_jdbc_connection_properties()?;
        let client = JdbcJniClient::new_with_props(jdbc_url, connection_properties)?;

        if self.with_s3 {
            let stage = self
                .stage
                .clone()
                .ok_or(SinkError::Config(anyhow!("stage is required")))?;
            snowflake_task_ctx.stage = Some(stage);
            snowflake_task_ctx.pipe_name = Some(format!("{}_pipe", target_table_name));
        }
        if !is_append_only {
            let cdc_table_name = self
                .snowflake_cdc_table_name
                .clone()
                .ok_or(SinkError::Config(anyhow!(
                    "intermediate.table.name is required"
                )))?;
            snowflake_task_ctx.cdc_table_name = Some(cdc_table_name.clone());
            snowflake_task_ctx.writer_target_interval_seconds = self.writer_target_interval_seconds;
            snowflake_task_ctx.warehouse = Some(
                self.snowflake_warehouse
                    .clone()
                    .ok_or(SinkError::Config(anyhow!("warehouse is required")))?,
            );
            let pk_column_names: Vec<_> = schema
                .fields
                .iter()
                .enumerate()
                .filter(|(index, _)| pk_indices.contains(index))
                .map(|(_, field)| field.name.clone())
                .collect();
            if pk_column_names.is_empty() {
                return Err(SinkError::Config(anyhow!(
                    "Primary key columns not found. Please set the `primary_key` column in the sink properties, or ensure that the sink contains the primary key columns from the upstream."
                )));
            }
            snowflake_task_ctx.pk_column_names = Some(pk_column_names);
            snowflake_task_ctx.all_column_names = Some(
                schema
                    .fields
                    .iter()
                    .map(|field| field.name.clone())
                    .collect(),
            );
            snowflake_task_ctx.task_name = Some(format!(
                "rw_snowflake_sink_from_{cdc_table_name}_to_{target_table_name}"
            ));
        }
        Ok(Some((snowflake_task_ctx, client)))
    }
}

impl EnforceSecret for SnowflakeV2Config {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "username",
        "password",
        "jdbc.url",
        "private_key_file_pwd",
        "private_key_pem",
    };
}

#[derive(Clone, Debug)]
pub struct SnowflakeV2Sink {
    config: SnowflakeV2Config,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
    param: SinkParam,
}

impl EnforceSecret for SnowflakeV2Sink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::sink::ConnectorResult<()> {
        for prop in prop_iter {
            SnowflakeV2Config::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl TryFrom<SinkParam> for SnowflakeV2Sink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = SnowflakeV2Config::from_btreemap(&param.properties)?;
        let is_append_only = param.sink_type.is_append_only();
        let pk_indices = param.downstream_pk_or_empty();
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
            param,
        })
    }
}

impl Sink for SnowflakeV2Sink {
    type LogSinker = CoordinatedLogSinker<SnowflakeSinkWriter>;

    const SINK_NAME: &'static str = SNOWFLAKE_SINK_V2;

    async fn validate(&self) -> Result<()> {
        risingwave_common::license::Feature::SnowflakeSink
            .check_available()
            .map_err(|e| anyhow::anyhow!(e))?;
        if let Some((snowflake_task_ctx, client)) =
            self.config.build_snowflake_task_ctx_jdbc_client(
                self.is_append_only,
                &self.schema,
                &self.pk_indices,
            )?
        {
            let client = SnowflakeJniClient::new(client, snowflake_task_ctx);
            client.execute_create_table().await?;
            client.execute_create_pipe().await?;
        }

        Ok(())
    }

    fn support_schema_change() -> bool {
        true
    }

    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
        SnowflakeV2Config::from_btreemap(config)?;
        Ok(())
    }

    async fn new_log_sinker(
        &self,
        writer_param: crate::sink::SinkWriterParam,
    ) -> Result<Self::LogSinker> {
        let writer = SnowflakeSinkWriter::new(
            self.config.clone(),
            self.is_append_only,
            writer_param.clone(),
            self.param.clone(),
        )
        .await?;

        let commit_checkpoint_interval =
            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
            );

        CoordinatedLogSinker::new(
            &writer_param,
            self.param.clone(),
            writer,
            commit_checkpoint_interval,
        )
        .await
    }

    fn is_coordinated_sink(&self) -> bool {
        true
    }

    async fn new_coordinator(
        &self,
        _iceberg_compact_stat_sender: Option<UnboundedSender<IcebergSinkCompactionUpdate>>,
    ) -> Result<SinkCommitCoordinator> {
        let coordinator = SnowflakeSinkCommitter::new(
            self.config.clone(),
            &self.schema,
            &self.pk_indices,
            self.is_append_only,
            self.param.sink_id,
        )?;
        Ok(SinkCommitCoordinator::SinglePhase(Box::new(coordinator)))
    }
}

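/// Sink writer for the `snowflake_v2` connector.
///
/// Depending on `with_s3`, rows are either written to staged S3 files (to be
/// loaded by the Snowflake pipe) or pushed directly over JDBC.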
pub enum SnowflakeSinkWriter {
    S3(SnowflakeRedshiftSinkS3Writer),
    Jdbc(SnowflakeRedshiftSinkJdbcWriter),
}

impl SnowflakeSinkWriter {
    pub async fn new(
        config: SnowflakeV2Config,
        is_append_only: bool,
        writer_param: SinkWriterParam,
        param: SinkParam,
    ) -> Result<Self> {
        let schema = param.schema();
        let database = config.snowflake_database.ok_or_else(|| {
            SinkError::Config(anyhow!("database is required for Snowflake JDBC sink"))
        })?;
        let schema_name = config.snowflake_schema.ok_or_else(|| {
            SinkError::Config(anyhow!("schema is required for Snowflake JDBC sink"))
        })?;
        let table_name = config.snowflake_target_table_name.ok_or_else(|| {
            SinkError::Config(anyhow!("table.name is required for Snowflake JDBC sink"))
        })?;
        if config.with_s3 {
            let s3_writer = SnowflakeRedshiftSinkS3Writer::new(
                config.s3_inner.ok_or_else(|| {
                    SinkError::Config(anyhow!(
                        "S3 configuration is required for Snowflake S3 sink"
                    ))
                })?,
                schema,
                is_append_only,
                table_name,
            )?;
            Ok(Self::S3(s3_writer))
        } else {
            let jdbc_writer = SnowflakeRedshiftSinkJdbcWriter::new(
                is_append_only,
                writer_param,
                param,
                build_full_table_name(&database, &schema_name, &table_name),
            )
            .await?;
            Ok(Self::Jdbc(jdbc_writer))
        }
    }
}

#[async_trait]
impl SinkWriter for SnowflakeSinkWriter {
    type CommitMetadata = Option<SinkMetadata>;

    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
        match self {
            Self::S3(writer) => writer.begin_epoch(epoch),
            Self::Jdbc(writer) => writer.begin_epoch(epoch).await,
        }
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        match self {
            Self::S3(writer) => writer.write_batch(chunk).await,
            Self::Jdbc(writer) => writer.write_batch(chunk).await,
        }
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Option<SinkMetadata>> {
        match self {
            Self::S3(writer) => {
                writer.barrier(is_checkpoint).await?;
            }
            Self::Jdbc(writer) => {
                writer.barrier(is_checkpoint).await?;
            }
        }
        Ok(Some(SinkMetadata {
            metadata: Some(sink_metadata::Metadata::Serialized(
                risingwave_pb::connector_service::sink_metadata::SerializedMetadata {
                    metadata: vec![],
                },
            )),
        }))
    }

    async fn abort(&mut self) -> Result<()> {
        if let Self::Jdbc(writer) = self {
            writer.abort().await
        } else {
            Ok(())
        }
    }
}

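/// Everything the coordinator needs to manage server-side objects in Snowflake:
/// the target table itself, plus the optional MERGE task (upsert sinks) and the
/// optional stage/pipe (S3 staging path).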
#[derive(Default, Clone)]
pub struct SnowflakeTaskContext {
    pub target_table_name: String,
    pub database: String,
    pub schema_name: String,
    pub schema: Schema,

    pub task_name: Option<String>,
    pub cdc_table_name: Option<String>,
    pub writer_target_interval_seconds: u64,
    pub warehouse: Option<String>,
    pub pk_column_names: Option<Vec<String>>,
    pub all_column_names: Option<Vec<String>>,

    pub stage: Option<String>,
    pub pipe_name: Option<String>,
}
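
/// Single-phase commit coordinator for the Snowflake sink.
///
/// On creation it spawns a background task that periodically refreshes the
/// Snowflake pipe, so staged files keep getting loaded between checkpoints; the
/// task is shut down (and the MERGE task dropped) when the committer is dropped.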
pub struct SnowflakeSinkCommitter {
    client: Option<SnowflakeJniClient>,
    _periodic_task_handle: Option<tokio::task::JoinHandle<()>>,
    shutdown_sender: Option<tokio::sync::mpsc::UnboundedSender<()>>,
}

impl SnowflakeSinkCommitter {
    pub fn new(
        config: SnowflakeV2Config,
        schema: &Schema,
        pk_indices: &Vec<usize>,
        is_append_only: bool,
        sink_id: SinkId,
    ) -> Result<Self> {
        let (client, periodic_task_handle, shutdown_sender) =
            if let Some((snowflake_task_ctx, client)) =
                config.build_snowflake_task_ctx_jdbc_client(is_append_only, schema, pk_indices)?
            {
                let (shutdown_sender, shutdown_receiver) = unbounded_channel();
                let snowflake_client =
                    SnowflakeJniClient::new(client.clone(), snowflake_task_ctx.clone());
                let periodic_task_handle = tokio::spawn(async move {
                    Self::run_periodic_query_task(
                        snowflake_client,
                        config.write_intermediate_interval_seconds,
                        sink_id,
                        shutdown_receiver,
                    )
                    .await;
                });
                (
                    Some(SnowflakeJniClient::new(client, snowflake_task_ctx)),
                    Some(periodic_task_handle),
                    Some(shutdown_sender),
                )
            } else {
                (None, None, None)
            };

        Ok(Self {
            client,
            _periodic_task_handle: periodic_task_handle,
            shutdown_sender,
        })
    }

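    /// Background loop that refreshes the Snowflake pipe every
    /// `write_intermediate_interval_seconds` until a shutdown signal is received.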
    async fn run_periodic_query_task(
        client: SnowflakeJniClient,
        write_intermediate_interval_seconds: u64,
        sink_id: SinkId,
        mut shutdown_receiver: tokio::sync::mpsc::UnboundedReceiver<()>,
    ) {
        let mut copy_timer = interval(Duration::from_secs(write_intermediate_interval_seconds));
        copy_timer.set_missed_tick_behavior(MissedTickBehavior::Skip);
        loop {
            tokio::select! {
                _ = shutdown_receiver.recv() => break,
                _ = copy_timer.tick() => {
                    if let Err(e) = async {
                        client.execute_flush_pipe().await?;
                        Ok::<(), SinkError>(())
                    }.await {
                        tracing::error!("Failed to execute copy into task for sink id {}: {}", sink_id, e.as_report());
                    }
                }
            }
        }
        tracing::info!("Periodic query task stopped for sink id {}", sink_id);
    }
}

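// Commit handling: `init` (re)creates the pipe and the scheduled MERGE task;
// `commit_data` is a no-op because rows are loaded by the pipe refresh loop and
// merged by the Snowflake task, while `commit_schema_change` only supports
// adding columns.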
#[async_trait]
impl SinglePhaseCommitCoordinator for SnowflakeSinkCommitter {
    async fn init(&mut self) -> Result<()> {
        if let Some(client) = &self.client {
            client.execute_create_pipe().await?;
            client.execute_create_merge_into_task().await?;
        }
        Ok(())
    }

    async fn commit_data(&mut self, _epoch: u64, _metadata: Vec<SinkMetadata>) -> Result<()> {
        Ok(())
    }

    async fn commit_schema_change(
        &mut self,
        _epoch: u64,
        schema_change: PbSinkSchemaChange,
    ) -> Result<()> {
        use risingwave_pb::stream_plan::sink_schema_change::PbOp as SinkSchemaChangeOp;
        let schema_change_op = schema_change
            .op
            .ok_or_else(|| SinkError::Coordinator(anyhow!("Invalid schema change operation")))?;
        let SinkSchemaChangeOp::AddColumns(add_columns) = schema_change_op else {
            return Err(SinkError::Coordinator(anyhow!(
                "Only AddColumns schema change is supported for Snowflake sink"
            )));
        };
        let client = self.client.as_mut().ok_or_else(|| {
            SinkError::Config(anyhow!("Snowflake sink committer is not initialized."))
        })?;
        client
            .execute_alter_add_columns(
                &add_columns
                    .fields
                    .into_iter()
                    .map(|f| (f.name, DataType::from(f.data_type.unwrap()).to_string()))
                    .collect_vec(),
            )
            .await
    }
}

impl Drop for SnowflakeSinkCommitter {
    fn drop(&mut self) {
        if let Some(client) = self.client.take() {
            if let Some(sender) = self.shutdown_sender.take() {
                let _ = sender.send(());
            }
            tokio::spawn(async move {
                client.execute_drop_task().await.ok();
            });
        }
    }
}

pub struct SnowflakeJniClient {
    jdbc_client: JdbcJniClient,
    snowflake_task_context: SnowflakeTaskContext,
}

impl SnowflakeJniClient {
    pub fn new(jdbc_client: JdbcJniClient, snowflake_task_context: SnowflakeTaskContext) -> Self {
        Self {
            jdbc_client,
            snowflake_task_context,
        }
    }

    pub async fn execute_alter_add_columns(
        &mut self,
        columns: &Vec<(String, String)>,
    ) -> Result<()> {
        self.execute_drop_task().await?;
        if let Some(names) = self.snowflake_task_context.all_column_names.as_mut() {
            names.extend(columns.iter().map(|(name, _)| name.clone()));
        }
        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
            let alter_add_column_cdc_table_sql = build_alter_add_column_sql(
                cdc_table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                columns,
            );
            self.jdbc_client
                .execute_sql_sync(vec![alter_add_column_cdc_table_sql])
                .await?;
        }

        let alter_add_column_target_table_sql = build_alter_add_column_sql(
            &self.snowflake_task_context.target_table_name,
            &self.snowflake_task_context.database,
            &self.snowflake_task_context.schema_name,
            columns,
        );
        self.jdbc_client
            .execute_sql_sync(vec![alter_add_column_target_table_sql])
            .await?;

        self.execute_create_merge_into_task().await?;
        Ok(())
    }

    pub async fn execute_create_merge_into_task(&self) -> Result<()> {
        if self.snowflake_task_context.task_name.is_some() {
            let create_task_sql = build_create_merge_into_task_sql(&self.snowflake_task_context);
            let start_task_sql = build_start_task_sql(&self.snowflake_task_context);
            self.jdbc_client
                .execute_sql_sync(vec![create_task_sql])
                .await?;
            self.jdbc_client
                .execute_sql_sync(vec![start_task_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_drop_task(&self) -> Result<()> {
        if self.snowflake_task_context.task_name.is_some() {
            let sql = build_drop_task_sql(&self.snowflake_task_context);
            if let Err(e) = self.jdbc_client.execute_sql_sync(vec![sql]).await {
                tracing::error!(
                    "Failed to drop Snowflake sink task {:?}: {:?}",
                    self.snowflake_task_context.task_name,
                    e.as_report()
                );
            } else {
                tracing::info!(
                    "Snowflake sink task {:?} dropped",
                    self.snowflake_task_context.task_name
                );
            }
        }
        Ok(())
    }

    pub async fn execute_create_table(&self) -> Result<()> {
        let create_target_table_sql = build_create_table_sql(
            &self.snowflake_task_context.target_table_name,
            &self.snowflake_task_context.database,
            &self.snowflake_task_context.schema_name,
            &self.snowflake_task_context.schema,
            false,
        )?;
        self.jdbc_client
            .execute_sql_sync(vec![create_target_table_sql])
            .await?;
        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
            let create_cdc_table_sql = build_create_table_sql(
                cdc_table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                &self.snowflake_task_context.schema,
                true,
            )?;
            self.jdbc_client
                .execute_sql_sync(vec![create_cdc_table_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_create_pipe(&self) -> Result<()> {
        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
            let table_name =
                if let Some(table_name) = self.snowflake_task_context.cdc_table_name.as_ref() {
                    table_name
                } else {
                    &self.snowflake_task_context.target_table_name
                };
            let create_pipe_sql = build_create_pipe_sql(
                table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                self.snowflake_task_context.stage.as_ref().ok_or_else(|| {
                    SinkError::Config(anyhow!("stage is required for the S3 writer"))
                })?,
                pipe_name,
                &self.snowflake_task_context.target_table_name,
            );
            self.jdbc_client
                .execute_sql_sync(vec![create_pipe_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_flush_pipe(&self) -> Result<()> {
        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
            let flush_pipe_sql = build_flush_pipe_sql(
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                pipe_name,
            );
            self.jdbc_client
                .execute_sql_sync(vec![flush_pipe_sql])
                .await?;
        }
        Ok(())
    }
}

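/// Builds a `CREATE TABLE IF NOT EXISTS` statement with
/// `ENABLE_SCHEMA_EVOLUTION = true`; for the intermediate CDC table
/// (`need_op_and_row_id = true`) the `__row_id` and `__op` bookkeeping columns
/// are appended.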
fn build_create_table_sql(
    table_name: &str,
    database: &str,
    schema_name: &str,
    schema: &Schema,
    need_op_and_row_id: bool,
) -> Result<String> {
    let full_table_name = build_full_table_name(database, schema_name, table_name);
    let mut columns: Vec<String> = schema
        .fields
        .iter()
        .map(|field| {
            let data_type = convert_snowflake_data_type(&field.data_type)?;
            Ok(format!(r#""{}" {}"#, field.name, data_type))
        })
        .collect::<Result<Vec<String>>>()?;
    if need_op_and_row_id {
        columns.push(format!(r#""{}" STRING"#, __ROW_ID));
        columns.push(format!(r#""{}" INT"#, __OP));
    }
    let columns_str = columns.join(", ");
    Ok(format!(
        "CREATE TABLE IF NOT EXISTS {} ({}) ENABLE_SCHEMA_EVOLUTION = true",
        full_table_name, columns_str
    ))
}

fn convert_snowflake_data_type(data_type: &DataType) -> Result<String> {
    let data_type = match data_type {
        DataType::Int16 => "SMALLINT".to_owned(),
        DataType::Int32 => "INTEGER".to_owned(),
        DataType::Int64 => "BIGINT".to_owned(),
        DataType::Float32 => "FLOAT4".to_owned(),
        DataType::Float64 => "FLOAT8".to_owned(),
        DataType::Boolean => "BOOLEAN".to_owned(),
        DataType::Varchar => "STRING".to_owned(),
        DataType::Date => "DATE".to_owned(),
        DataType::Timestamp => "TIMESTAMP".to_owned(),
        DataType::Timestamptz => "TIMESTAMP_TZ".to_owned(),
        DataType::Jsonb => "STRING".to_owned(),
        DataType::Decimal => "DECIMAL".to_owned(),
        DataType::Bytea => "BINARY".to_owned(),
        DataType::Time => "TIME".to_owned(),
        _ => {
            return Err(SinkError::Config(anyhow!(
                "auto table creation does not support data type: {}",
                data_type
            )));
        }
    };
    Ok(data_type)
}

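/// Builds the `CREATE OR REPLACE PIPE` statement that copies JSON files from the
/// stage path `@"<db>"."<schema>"."<stage>"/<target_table>` into the (CDC or
/// target) table, matching columns case-insensitively by name.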
fn build_create_pipe_sql(
    table_name: &str,
    database: &str,
    schema: &str,
    stage: &str,
    pipe_name: &str,
    target_table_name: &str,
) -> String {
    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    let stage = format!(
        r#""{}"."{}"."{}"/{}"#,
        database, schema, stage, target_table_name
    );
    let table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
    format!(
        "CREATE OR REPLACE PIPE {} AUTO_INGEST = FALSE AS COPY INTO {} FROM @{} MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE FILE_FORMAT = (type = 'JSON');",
        pipe_name, table_name, stage
    )
}

fn build_flush_pipe_sql(database: &str, schema: &str, pipe_name: &str) -> String {
    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    format!("ALTER PIPE {} REFRESH;", pipe_name)
}

fn build_alter_add_column_sql(
    table_name: &str,
    database: &str,
    schema: &str,
    columns: &Vec<(String, String)>,
) -> String {
    let full_table_name = build_full_table_name(database, schema, table_name);
    jdbc_jni_client::build_alter_add_column_sql(&full_table_name, columns, true)
}

fn build_start_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        database,
        schema_name: schema,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema,
        task_name.as_ref().unwrap()
    );
    format!("ALTER TASK {} RESUME", full_task_name)
}

fn build_drop_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        database,
        schema_name: schema,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema,
        task_name.as_ref().unwrap()
    );
    format!("DROP TASK IF EXISTS {}", full_task_name)
}

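/// Builds the `CREATE OR REPLACE TASK` statement that periodically merges the
/// intermediate CDC table into the target table:
///
/// 1. snapshot the current maximum `__row_id`,
/// 2. deduplicate CDC rows per primary key (latest `__row_id` wins),
/// 3. `MERGE` into the target, deleting on op codes 2/4 and upserting on 1/3,
/// 4. delete the consumed CDC rows up to the snapshotted `__row_id`.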
fn build_create_merge_into_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        cdc_table_name,
        target_table_name,
        writer_target_interval_seconds,
        warehouse,
        pk_column_names,
        all_column_names,
        database,
        schema_name,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema_name,
        task_name.as_ref().unwrap()
    );
    let full_cdc_table_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema_name,
        cdc_table_name.as_ref().unwrap()
    );
    let full_target_table_name = format!(
        r#""{}"."{}"."{}""#,
        database, schema_name, target_table_name
    );

    let pk_names_str = pk_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#""{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");
    let pk_names_eq_str = pk_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
        .collect::<Vec<String>>()
        .join(" AND ");
    let all_column_names_set_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
        .collect::<Vec<String>>()
        .join(", ");
    let all_column_names_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#""{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");
    let all_column_names_insert_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"source."{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");

    format!(
        r#"CREATE OR REPLACE TASK {task_name}
WAREHOUSE = {warehouse}
SCHEDULE = '{writer_target_interval_seconds} SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("{snowflake_sink_row_id}"), '0') INTO :max_row_id
    FROM {cdc_table_name};

    MERGE INTO {target_table_name} AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY {pk_names_str} ORDER BY "{snowflake_sink_row_id}" DESC) AS dedupe_id
            FROM {cdc_table_name}
            WHERE "{snowflake_sink_row_id}" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON {pk_names_eq_str}
    WHEN MATCHED AND source."{snowflake_sink_op}" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN UPDATE SET {all_column_names_set_str}
    WHEN NOT MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN INSERT ({all_column_names_str}) VALUES ({all_column_names_insert_str});

    DELETE FROM {cdc_table_name}
    WHERE "{snowflake_sink_row_id}" <= :max_row_id;
END;"#,
        task_name = full_task_name,
        warehouse = warehouse.as_ref().unwrap(),
        writer_target_interval_seconds = writer_target_interval_seconds,
        cdc_table_name = full_cdc_table_name,
        target_table_name = full_target_table_name,
        pk_names_str = pk_names_str,
        pk_names_eq_str = pk_names_eq_str,
        all_column_names_set_str = all_column_names_set_str,
        all_column_names_str = all_column_names_str,
        all_column_names_insert_str = all_column_names_insert_str,
        snowflake_sink_row_id = __ROW_ID,
        snowflake_sink_op = __OP,
    )
}

#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::*;
    use crate::sink::jdbc_jni_client::normalize_sql;

    fn base_properties() -> BTreeMap<String, String> {
        BTreeMap::from([
            ("type".to_owned(), "append-only".to_owned()),
            ("jdbc.url".to_owned(), "jdbc:snowflake://account".to_owned()),
            ("username".to_owned(), "RW_USER".to_owned()),
        ])
    }

    #[test]
    fn test_build_jdbc_props_password() {
        let mut props = base_properties();
        props.insert("password".to_owned(), "secret".to_owned());
        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
        assert_eq!(url, "jdbc:snowflake://account");
        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
        assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
        assert_eq!(map.get("password"), Some(&"secret".to_owned()));
        assert!(!map.contains_key("authenticator"));
    }

    #[test]
    fn test_build_jdbc_props_key_pair_file() {
        let mut props = base_properties();
        props.insert(
            "auth.method".to_owned(),
            AUTH_METHOD_KEY_PAIR_FILE.to_owned(),
        );
        props.insert("private_key_file".to_owned(), "/tmp/rsa_key.p8".to_owned());
        props.insert("private_key_file_pwd".to_owned(), "dummy".to_owned());
        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
        assert_eq!(url, "jdbc:snowflake://account");
        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
        assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
        assert_eq!(
            map.get("private_key_file"),
            Some(&"/tmp/rsa_key.p8".to_owned())
        );
        assert_eq!(map.get("private_key_file_pwd"), Some(&"dummy".to_owned()));
    }

    #[test]
    fn test_build_jdbc_props_key_pair_object() {
        let mut props = base_properties();
        props.insert(
            "auth.method".to_owned(),
            AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
        );
        props.insert(
            "private_key_pem".to_owned(),
            "-----BEGIN PRIVATE KEY-----
...
-----END PRIVATE KEY-----"
                .to_owned(),
        );
        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
        assert_eq!(url, "jdbc:snowflake://account");
        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
        assert_eq!(
            map.get("private_key_pem"),
            Some(
                &"-----BEGIN PRIVATE KEY-----
...
-----END PRIVATE KEY-----"
                    .to_owned()
            )
        );
        assert!(!map.contains_key("private_key_file"));
    }

    #[test]
    fn test_snowflake_sink_commit_coordinator() {
        let snowflake_task_context = SnowflakeTaskContext {
            task_name: Some("test_task".to_owned()),
            cdc_table_name: Some("test_cdc_table".to_owned()),
            target_table_name: "test_target_table".to_owned(),
            writer_target_interval_seconds: 3600,
            warehouse: Some("test_warehouse".to_owned()),
            pk_column_names: Some(vec!["v1".to_owned()]),
            all_column_names: Some(vec!["v1".to_owned(), "v2".to_owned()]),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            schema: Schema { fields: vec![] },
            stage: None,
            pipe_name: None,
        };
        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task"
WAREHOUSE = test_warehouse
SCHEDULE = '3600 SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
    FROM "test_db"."test_schema"."test_cdc_table";

    MERGE INTO "test_db"."test_schema"."test_target_table" AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY "v1" ORDER BY "__row_id" DESC) AS dedupe_id
            FROM "test_db"."test_schema"."test_cdc_table"
            WHERE "__row_id" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON target."v1" = source."v1"
    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."v1" = source."v1", target."v2" = source."v2"
    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("v1", "v2") VALUES (source."v1", source."v2");

    DELETE FROM "test_db"."test_schema"."test_cdc_table"
    WHERE "__row_id" <= :max_row_id;
END;"#;
        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
    }

    #[test]
    fn test_snowflake_sink_commit_coordinator_multi_pk() {
        let snowflake_task_context = SnowflakeTaskContext {
            task_name: Some("test_task_multi_pk".to_owned()),
            cdc_table_name: Some("cdc_multi_pk".to_owned()),
            target_table_name: "target_multi_pk".to_owned(),
            writer_target_interval_seconds: 300,
            warehouse: Some("multi_pk_warehouse".to_owned()),
            pk_column_names: Some(vec!["id1".to_owned(), "id2".to_owned()]),
            all_column_names: Some(vec!["id1".to_owned(), "id2".to_owned(), "val".to_owned()]),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            schema: Schema { fields: vec![] },
            stage: None,
            pipe_name: None,
        };
        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task_multi_pk"
WAREHOUSE = multi_pk_warehouse
SCHEDULE = '300 SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
    FROM "test_db"."test_schema"."cdc_multi_pk";

    MERGE INTO "test_db"."test_schema"."target_multi_pk" AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY "id1", "id2" ORDER BY "__row_id" DESC) AS dedupe_id
            FROM "test_db"."test_schema"."cdc_multi_pk"
            WHERE "__row_id" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON target."id1" = source."id1" AND target."id2" = source."id2"
    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."id1" = source."id1", target."id2" = source."id2", target."val" = source."val"
    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("id1", "id2", "val") VALUES (source."id1", source."id2", source."val");

    DELETE FROM "test_db"."test_schema"."cdc_multi_pk"
    WHERE "__row_id" <= :max_row_id;
END;"#;
        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
    }
}