use core::num::NonZeroU64;
use std::collections::BTreeMap;

use anyhow::anyhow;
use phf::{Set, phf_set};
use risingwave_common::array::StreamChunk;
use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema};
use risingwave_common::types::DataType;
use risingwave_pb::connector_service::{SinkMetadata, sink_metadata};
use sea_orm::DatabaseConnection;
use serde::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use thiserror_ext::AsReport;
use tokio::sync::mpsc::UnboundedSender;
use tonic::async_trait;
use with_options::WithOptions;

use crate::connector_common::IcebergSinkCompactionUpdate;
use crate::enforce_secret::EnforceSecret;
use crate::sink::coordinate::CoordinatedLogSinker;
use crate::sink::decouple_checkpoint_log_sink::default_commit_checkpoint_interval;
use crate::sink::file_sink::s3::S3Common;
use crate::sink::jdbc_jni_client::{self, JdbcJniClient};
use crate::sink::remote::CoordinatedRemoteSinkWriter;
use crate::sink::snowflake_redshift::{AugmentedChunk, SnowflakeRedshiftSinkS3Writer};
use crate::sink::writer::SinkWriter;
use crate::sink::{
    Result, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, Sink, SinkCommitCoordinator,
    SinkCommittedEpochSubscriber, SinkError, SinkParam, SinkWriterMetrics, SinkWriterParam,
};

pub const SNOWFLAKE_SINK_V2: &str = "snowflake_v2";
pub const SNOWFLAKE_SINK_ROW_ID: &str = "__row_id";
pub const SNOWFLAKE_SINK_OP: &str = "__op";

const AUTH_METHOD_PASSWORD: &str = "password";
const AUTH_METHOD_KEY_PAIR_FILE: &str = "key_pair_file";
const AUTH_METHOD_KEY_PAIR_OBJECT: &str = "key_pair_object";
const PROP_AUTH_METHOD: &str = "auth.method";

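/// User-facing configuration of the `snowflake_v2` sink, deserialized from the sink's
/// `WITH (...)` properties. Authentication is either password- or key-pair-based, and data
/// is delivered either through an S3 stage plus Snowpipe (`with_s3 = true`, the default)
/// or directly over JDBC.
///
/// An illustrative upsert configuration (property names follow the `serde` renames below;
/// the values are placeholders, and the S3 credentials flattened in from [`S3Common`] are
/// omitted):
///
/// ```sql
/// CREATE SINK snowflake_sink FROM mv WITH (
///     connector = 'snowflake_v2',
///     type = 'upsert',
///     jdbc.url = 'jdbc:snowflake://<account>.snowflakecomputing.com',
///     username = 'RW_USER',
///     password = '...',
///     database = 'MY_DB',
///     schema = 'PUBLIC',
///     table.name = 'TARGET_TABLE',
///     intermediate.table.name = 'TARGET_TABLE_CDC',
///     warehouse = 'MY_WAREHOUSE',
///     stage = 'MY_STAGE'
/// );
/// ```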
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct SnowflakeV2Config {
    #[serde(rename = "type")]
    pub r#type: String,

    #[serde(rename = "intermediate.table.name")]
    pub snowflake_cdc_table_name: Option<String>,

    #[serde(rename = "table.name")]
    pub snowflake_target_table_name: Option<String>,

    #[serde(rename = "database")]
    pub snowflake_database: Option<String>,

    #[serde(rename = "schema")]
    pub snowflake_schema: Option<String>,

    #[serde(default = "default_schedule")]
    #[serde(rename = "write.target.interval.seconds")]
    #[serde_as(as = "DisplayFromStr")]
    pub snowflake_schedule_seconds: u64,

    #[serde(rename = "warehouse")]
    pub snowflake_warehouse: Option<String>,

    #[serde(rename = "jdbc.url")]
    pub jdbc_url: Option<String>,

    #[serde(rename = "username")]
    pub username: Option<String>,

    #[serde(rename = "password")]
    pub password: Option<String>,

    #[serde(rename = "auth.method")]
    pub auth_method: Option<String>,

    #[serde(rename = "private_key_file")]
    pub private_key_file: Option<String>,

    #[serde(rename = "private_key_file_pwd")]
    pub private_key_file_pwd: Option<String>,

    #[serde(rename = "private_key_pem")]
    pub private_key_pem: Option<String>,

    #[serde(default = "default_commit_checkpoint_interval")]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub commit_checkpoint_interval: u64,

    #[serde(default)]
    #[serde(rename = "auto.schema.change")]
    #[serde_as(as = "DisplayFromStr")]
    pub auto_schema_change: bool,

    #[serde(default)]
    #[serde(rename = "create_table_if_not_exists")]
    #[serde_as(as = "DisplayFromStr")]
    pub create_table_if_not_exists: bool,

    #[serde(default = "default_with_s3")]
    #[serde(rename = "with_s3")]
    #[serde_as(as = "DisplayFromStr")]
    pub with_s3: bool,

    #[serde(flatten)]
    pub s3_inner: Option<S3Common>,

    #[serde(rename = "stage")]
    pub stage: Option<String>,
}

fn default_schedule() -> u64 {
    3600
}

fn default_with_s3() -> bool {
    true
}

impl SnowflakeV2Config {
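    /// Assembles the JDBC URL and driver connection properties according to the normalized
    /// `auth.method`: password, key-pair via a private key file, or key-pair via an in-line
    /// PEM key. Assumes the config has already passed [`SnowflakeV2Config::from_btreemap`].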
    pub fn build_jdbc_connection_properties(&self) -> Result<(String, Vec<(String, String)>)> {
        let jdbc_url = self
            .jdbc_url
            .clone()
            .ok_or(SinkError::Config(anyhow!("jdbc.url is required")))?;
        let username = self
            .username
            .clone()
            .ok_or(SinkError::Config(anyhow!("username is required")))?;

        let mut connection_properties: Vec<(String, String)> = vec![("user".to_owned(), username)];

        match self.auth_method.as_deref().unwrap() {
            AUTH_METHOD_PASSWORD => {
                connection_properties
                    .push(("password".to_owned(), self.password.clone().unwrap()));
            }
            AUTH_METHOD_KEY_PAIR_FILE => {
                connection_properties.push((
                    "private_key_file".to_owned(),
                    self.private_key_file.clone().unwrap(),
                ));
                if let Some(pwd) = self.private_key_file_pwd.clone() {
                    connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
                }
            }
            AUTH_METHOD_KEY_PAIR_OBJECT => {
                connection_properties.push((
                    PROP_AUTH_METHOD.to_owned(),
                    AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
                ));
                connection_properties.push((
                    "private_key_pem".to_owned(),
                    self.private_key_pem.clone().unwrap(),
                ));
                if let Some(pwd) = self.private_key_file_pwd.clone() {
                    connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
                }
            }
            _ => {
                unreachable!(
                    "Invalid auth_method - should have been caught during config validation"
                )
            }
        }

        Ok((jdbc_url, connection_properties))
    }

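    /// Parses and validates the raw sink properties: the sink `type` must be append-only or
    /// upsert, and exactly one authentication method may be configured. The resolved method
    /// is normalized and written back into `auth_method`.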
    pub fn from_btreemap(properties: &BTreeMap<String, String>) -> Result<Self> {
        let mut config =
            serde_json::from_value::<SnowflakeV2Config>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }

        let has_password = config.password.is_some();
        let has_file = config.private_key_file.is_some();
        let has_pem = config.private_key_pem.is_some();

        let normalized_auth_method = match config
            .auth_method
            .as_deref()
            .map(|s| s.trim().to_ascii_lowercase())
        {
            Some(method) if method == AUTH_METHOD_PASSWORD => {
                if !has_password {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=password requires `password`"
                    )));
                }
                if has_file || has_pem {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=password must not set `private_key_file`/`private_key_pem`"
                    )));
                }
                AUTH_METHOD_PASSWORD.to_owned()
            }
            Some(method) if method == AUTH_METHOD_KEY_PAIR_FILE => {
                if !has_file {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_file requires `private_key_file`"
                    )));
                }
                if has_password {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_file must not set `password`"
                    )));
                }
                if has_pem {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_file must not set `private_key_pem`"
                    )));
                }
                AUTH_METHOD_KEY_PAIR_FILE.to_owned()
            }
            Some(method) if method == AUTH_METHOD_KEY_PAIR_OBJECT => {
                if !has_pem {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_object requires `private_key_pem`"
                    )));
                }
                if has_password {
                    return Err(SinkError::Config(anyhow!(
                        "auth.method=key_pair_object must not set `password`"
                    )));
                }
                AUTH_METHOD_KEY_PAIR_OBJECT.to_owned()
            }
            Some(other) => {
                return Err(SinkError::Config(anyhow!(
                    "invalid auth.method: {} (allowed: password | key_pair_file | key_pair_object)",
                    other
                )));
            }
            None => {
                match (has_password, has_file, has_pem) {
                    (true, false, false) => AUTH_METHOD_PASSWORD.to_owned(),
                    (false, true, false) => AUTH_METHOD_KEY_PAIR_FILE.to_owned(),
                    (false, false, true) => AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
                    (true, true, _) | (true, _, true) | (false, true, true) => {
                        return Err(SinkError::Config(anyhow!(
                            "ambiguous auth: multiple auth options provided; remove one or set `auth.method`"
                        )));
                    }
                    _ => {
                        return Err(SinkError::Config(anyhow!(
                            "no authentication configured: set one of `password`, `private_key_file`, or `private_key_pem` (or provide `auth.method`)"
                        )));
                    }
                }
            }
        };
        config.auth_method = Some(normalized_auth_method);
        Ok(config)
    }

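    /// Builds the [`SnowflakeTaskContext`] plus a JDBC client used for DDL and maintenance
    /// (table creation, pipe, merge task). Returns `Ok(None)` for plain append-only sinks
    /// with neither auto schema change nor table auto-creation enabled, since no
    /// server-side objects need to be managed in that case.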
    pub fn build_snowflake_task_ctx_jdbc_client(
        &self,
        is_append_only: bool,
        schema: &Schema,
        pk_indices: &Vec<usize>,
    ) -> Result<Option<(SnowflakeTaskContext, JdbcJniClient)>> {
        if !self.auto_schema_change && is_append_only && !self.create_table_if_not_exists {
            return Ok(None);
        }
        let target_table_name = self
            .snowflake_target_table_name
            .clone()
            .ok_or(SinkError::Config(anyhow!("table.name is required")))?;
        let database = self
            .snowflake_database
            .clone()
            .ok_or(SinkError::Config(anyhow!("database is required")))?;
        let schema_name = self
            .snowflake_schema
            .clone()
            .ok_or(SinkError::Config(anyhow!("schema is required")))?;
        let mut snowflake_task_ctx = SnowflakeTaskContext {
            target_table_name: target_table_name.clone(),
            database,
            schema_name,
            schema: schema.clone(),
            ..Default::default()
        };

        let (jdbc_url, connection_properties) = self.build_jdbc_connection_properties()?;
        let client = JdbcJniClient::new_with_props(jdbc_url, connection_properties)?;

        if self.with_s3 {
            let stage = self
                .stage
                .clone()
                .ok_or(SinkError::Config(anyhow!("stage is required")))?;
            snowflake_task_ctx.stage = Some(stage);
            snowflake_task_ctx.pipe_name = Some(format!("{}_pipe", target_table_name));
        }
        if !is_append_only {
            let cdc_table_name = self
                .snowflake_cdc_table_name
                .clone()
                .ok_or(SinkError::Config(anyhow!(
                    "intermediate.table.name is required"
                )))?;
            snowflake_task_ctx.cdc_table_name = Some(cdc_table_name.clone());
            snowflake_task_ctx.schedule_seconds = self.snowflake_schedule_seconds;
            snowflake_task_ctx.warehouse = Some(
                self.snowflake_warehouse
                    .clone()
                    .ok_or(SinkError::Config(anyhow!("warehouse is required")))?,
            );
            let pk_column_names: Vec<_> = schema
                .fields
                .iter()
                .enumerate()
                .filter(|(index, _)| pk_indices.contains(index))
                .map(|(_, field)| field.name.clone())
                .collect();
            if pk_column_names.is_empty() {
                return Err(SinkError::Config(anyhow!(
                    "Primary key columns not found. Please set the `primary_key` column in the sink properties, or ensure that the sink contains the primary key columns from the upstream."
                )));
            }
            snowflake_task_ctx.pk_column_names = Some(pk_column_names);
            snowflake_task_ctx.all_column_names = Some(
                schema
                    .fields
                    .iter()
                    .map(|field| field.name.clone())
                    .collect(),
            );
            snowflake_task_ctx.task_name = Some(format!(
                "rw_snowflake_sink_from_{cdc_table_name}_to_{target_table_name}"
            ));
        }
        Ok(Some((snowflake_task_ctx, client)))
    }
}

impl EnforceSecret for SnowflakeV2Config {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "username",
        "password",
        "jdbc.url",
        "private_key_file_pwd",
        "private_key_pem",
    };
}

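/// The `snowflake_v2` sink: validates the configuration, optionally creates the target and
/// intermediate tables, and streams data either through an S3 stage plus Snowpipe or
/// directly over JDBC.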
#[derive(Clone, Debug)]
pub struct SnowflakeV2Sink {
    config: SnowflakeV2Config,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
    param: SinkParam,
}

impl EnforceSecret for SnowflakeV2Sink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::sink::ConnectorResult<()> {
        for prop in prop_iter {
            SnowflakeV2Config::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl TryFrom<SinkParam> for SnowflakeV2Sink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = SnowflakeV2Config::from_btreemap(&param.properties)?;
        let is_append_only = param.sink_type.is_append_only();
        let pk_indices = param.downstream_pk_or_empty();
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
            param,
        })
    }
}

impl Sink for SnowflakeV2Sink {
    type Coordinator = SnowflakeSinkCommitter;
    type LogSinker = CoordinatedLogSinker<SnowflakeSinkWriter>;

    const SINK_NAME: &'static str = SNOWFLAKE_SINK_V2;

    async fn validate(&self) -> Result<()> {
        risingwave_common::license::Feature::SnowflakeSink
            .check_available()
            .map_err(|e| anyhow::anyhow!(e))?;
        if let Some((snowflake_task_ctx, client)) =
            self.config.build_snowflake_task_ctx_jdbc_client(
                self.is_append_only,
                &self.schema,
                &self.pk_indices,
            )?
        {
            let client = SnowflakeJniClient::new(client, snowflake_task_ctx);
            client.execute_create_table().await?;
        }

        Ok(())
    }

    fn support_schema_change() -> bool {
        true
    }

    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
        SnowflakeV2Config::from_btreemap(config)?;
        Ok(())
    }

    async fn new_log_sinker(
        &self,
        writer_param: crate::sink::SinkWriterParam,
    ) -> Result<Self::LogSinker> {
        let writer = SnowflakeSinkWriter::new(
            self.config.clone(),
            self.is_append_only,
            writer_param.clone(),
            self.param.clone(),
        )
        .await?;

        let commit_checkpoint_interval =
            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
            );

        CoordinatedLogSinker::new(
            &writer_param,
            self.param.clone(),
            writer,
            commit_checkpoint_interval,
        )
        .await
    }

    fn is_coordinated_sink(&self) -> bool {
        true
    }

    async fn new_coordinator(
        &self,
        _db: DatabaseConnection,
        _iceberg_compact_stat_sender: Option<UnboundedSender<IcebergSinkCompactionUpdate>>,
    ) -> Result<Self::Coordinator> {
        let coordinator = SnowflakeSinkCommitter::new(
            self.config.clone(),
            &self.schema,
            &self.pk_indices,
            self.is_append_only,
        )?;
        Ok(coordinator)
    }
}

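/// Per-writer data path. With `with_s3 = true` (the default) rows are staged as files on S3
/// and later loaded by the Snowpipe created by the coordinator; otherwise rows are written
/// directly over JDBC.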
pub enum SnowflakeSinkWriter {
    S3(SnowflakeRedshiftSinkS3Writer),
    Jdbc(SnowflakeSinkJdbcWriter),
}

impl SnowflakeSinkWriter {
    pub async fn new(
        config: SnowflakeV2Config,
        is_append_only: bool,
        writer_param: SinkWriterParam,
        param: SinkParam,
    ) -> Result<Self> {
        let schema = param.schema();
        if config.with_s3 {
            let executor_id = writer_param.executor_id;
            let s3_writer = SnowflakeRedshiftSinkS3Writer::new(
                config.s3_inner.ok_or_else(|| {
                    SinkError::Config(anyhow!(
                        "S3 configuration is required for Snowflake S3 sink"
                    ))
                })?,
                schema,
                is_append_only,
                executor_id,
                config.snowflake_target_table_name,
            )?;
            Ok(Self::S3(s3_writer))
        } else {
            let jdbc_writer =
                SnowflakeSinkJdbcWriter::new(config, is_append_only, writer_param, param).await?;
            Ok(Self::Jdbc(jdbc_writer))
        }
    }
}

#[async_trait]
impl SinkWriter for SnowflakeSinkWriter {
    type CommitMetadata = Option<SinkMetadata>;

    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
        match self {
            Self::S3(writer) => writer.begin_epoch(epoch),
            Self::Jdbc(writer) => writer.begin_epoch(epoch).await,
        }
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        match self {
            Self::S3(writer) => writer.write_batch(chunk).await,
            Self::Jdbc(writer) => writer.write_batch(chunk).await,
        }
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Option<SinkMetadata>> {
        match self {
            Self::S3(writer) => {
                writer.barrier(is_checkpoint).await?;
            }
            Self::Jdbc(writer) => {
                writer.barrier(is_checkpoint).await?;
            }
        }
        Ok(Some(SinkMetadata {
            metadata: Some(sink_metadata::Metadata::Serialized(
                risingwave_pb::connector_service::sink_metadata::SerializedMetadata {
                    metadata: vec![],
                },
            )),
        }))
    }

    async fn abort(&mut self) -> Result<()> {
        if let Self::Jdbc(writer) = self {
            writer.abort().await
        } else {
            Ok(())
        }
    }
}

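/// JDBC write path. For upsert sinks the chunk is augmented with the `__row_id` and `__op`
/// bookkeeping columns and appended to the intermediate (CDC) table; append-only sinks
/// write straight to the target table. Actual writes go through a coordinated remote JDBC
/// sink writer.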
pub struct SnowflakeSinkJdbcWriter {
    augmented_row: AugmentedChunk,
    jdbc_sink_writer: CoordinatedRemoteSinkWriter,
}

impl SnowflakeSinkJdbcWriter {
    pub async fn new(
        config: SnowflakeV2Config,
        is_append_only: bool,
        writer_param: SinkWriterParam,
        mut param: SinkParam,
    ) -> Result<Self> {
        let metrics = SinkWriterMetrics::new(&writer_param);
        let properties = &param.properties;
        let column_descs = &mut param.columns;
        let full_table_name = if is_append_only {
            format!(
                r#""{}"."{}"."{}""#,
                config.snowflake_database.clone().unwrap_or_default(),
                config.snowflake_schema.clone().unwrap_or_default(),
                config
                    .snowflake_target_table_name
                    .clone()
                    .unwrap_or_default()
            )
        } else {
            let max_column_id = column_descs
                .iter()
                .map(|column| column.column_id.get_id())
                .max()
                .unwrap_or(0);
            column_descs.push(ColumnDesc::named(
                SNOWFLAKE_SINK_ROW_ID,
                ColumnId::new(max_column_id + 1),
                DataType::Varchar,
            ));
            column_descs.push(ColumnDesc::named(
                SNOWFLAKE_SINK_OP,
                ColumnId::new(max_column_id + 2),
                DataType::Int32,
            ));
            format!(
                r#""{}"."{}"."{}""#,
                config.snowflake_database.clone().unwrap_or_default(),
                config.snowflake_schema.clone().unwrap_or_default(),
                config.snowflake_cdc_table_name.clone().unwrap_or_default()
            )
        };
        let mut new_properties = BTreeMap::from([
            ("table.name".to_owned(), full_table_name),
            ("connector".to_owned(), "snowflake_v2".to_owned()),
            (
                "jdbc.url".to_owned(),
                config.jdbc_url.clone().unwrap_or_default(),
            ),
            ("type".to_owned(), "append-only".to_owned()),
            (
                "primary_key".to_owned(),
                properties.get("primary_key").cloned().unwrap_or_default(),
            ),
            (
                "schema.name".to_owned(),
                config.snowflake_schema.clone().unwrap_or_default(),
            ),
            (
                "database.name".to_owned(),
                config.snowflake_database.clone().unwrap_or_default(),
            ),
        ]);

        let (_jdbc_url, connection_properties) = config.build_jdbc_connection_properties()?;
        for (key, value) in connection_properties {
            new_properties.insert(key, value);
        }

        param.properties = new_properties;

        let jdbc_sink_writer =
            CoordinatedRemoteSinkWriter::new(param.clone(), metrics.clone()).await?;
        Ok(Self {
            augmented_row: AugmentedChunk::new(0, is_append_only),
            jdbc_sink_writer,
        })
    }
}

impl SnowflakeSinkJdbcWriter {
    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
        self.augmented_row.reset_epoch(epoch);
        self.jdbc_sink_writer.begin_epoch(epoch).await?;
        Ok(())
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        let chunk = self.augmented_row.augmented_chunk(chunk)?;
        self.jdbc_sink_writer.write_batch(chunk).await?;
        Ok(())
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
        self.jdbc_sink_writer.barrier(is_checkpoint).await?;
        Ok(())
    }

    async fn abort(&mut self) -> Result<()> {
        self.jdbc_sink_writer.abort().await?;
        Ok(())
    }
}

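/// All the information needed to render the DDL executed on Snowflake: fully qualified
/// table names, the optional scheduled merge task (upsert mode), and the optional
/// stage/pipe (S3 mode).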
#[derive(Default)]
pub struct SnowflakeTaskContext {
    pub target_table_name: String,
    pub database: String,
    pub schema_name: String,
    pub schema: Schema,

    pub task_name: Option<String>,
    pub cdc_table_name: Option<String>,
    pub schedule_seconds: u64,
    pub warehouse: Option<String>,
    pub pk_column_names: Option<Vec<String>>,
    pub all_column_names: Option<Vec<String>>,

    pub stage: Option<String>,
    pub pipe_name: Option<String>,
}
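
/// Commit coordinator. On `init` it creates the Snowpipe and the scheduled merge task
/// (when configured); on every checkpoint commit it refreshes the pipe and propagates any
/// newly added columns to both tables. The merge task is dropped when the committer is
/// dropped.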
pub struct SnowflakeSinkCommitter {
    client: Option<SnowflakeJniClient>,
}

impl SnowflakeSinkCommitter {
    pub fn new(
        config: SnowflakeV2Config,
        schema: &Schema,
        pk_indices: &Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let client = if let Some((snowflake_task_ctx, client)) =
            config.build_snowflake_task_ctx_jdbc_client(is_append_only, schema, pk_indices)?
        {
            Some(SnowflakeJniClient::new(client, snowflake_task_ctx))
        } else {
            None
        };
        Ok(Self { client })
    }
}

#[async_trait]
impl SinkCommitCoordinator for SnowflakeSinkCommitter {
    async fn init(&mut self, _subscriber: SinkCommittedEpochSubscriber) -> Result<Option<u64>> {
        if let Some(client) = &self.client {
            client.execute_create_pipe().await?;
            client.execute_create_merge_into_task().await?;
        }
        Ok(None)
    }

    async fn commit(
        &mut self,
        _epoch: u64,
        _metadata: Vec<SinkMetadata>,
        add_columns: Option<Vec<Field>>,
    ) -> Result<()> {
        let client = self.client.as_mut().ok_or_else(|| {
            SinkError::Config(anyhow!("Snowflake sink committer is not initialized."))
        })?;
        client.execute_flush_pipe().await?;

        if let Some(add_columns) = add_columns {
            client
                .execute_alter_add_columns(
                    &add_columns
                        .iter()
                        .map(|f| (f.name.clone(), f.data_type.to_string()))
                        .collect::<Vec<_>>(),
                )
                .await?;
        }
        Ok(())
    }
}

impl Drop for SnowflakeSinkCommitter {
    fn drop(&mut self) {
        if let Some(client) = self.client.take() {
            tokio::spawn(async move {
                client.execute_drop_task().await.ok();
            });
        }
    }
}

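/// Thin wrapper around the JNI-backed JDBC client that renders and executes the Snowflake
/// DDL derived from a [`SnowflakeTaskContext`].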
pub struct SnowflakeJniClient {
    jdbc_client: JdbcJniClient,
    snowflake_task_context: SnowflakeTaskContext,
}

impl SnowflakeJniClient {
    pub fn new(jdbc_client: JdbcJniClient, snowflake_task_context: SnowflakeTaskContext) -> Self {
        Self {
            jdbc_client,
            snowflake_task_context,
        }
    }

    pub async fn execute_alter_add_columns(
        &mut self,
        columns: &Vec<(String, String)>,
    ) -> Result<()> {
        self.execute_drop_task().await?;
        if let Some(names) = self.snowflake_task_context.all_column_names.as_mut() {
            names.extend(columns.iter().map(|(name, _)| name.clone()));
        }
        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
            let alter_add_column_cdc_table_sql = build_alter_add_column_sql(
                cdc_table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                columns,
            );
            self.jdbc_client
                .execute_sql_sync(vec![alter_add_column_cdc_table_sql])
                .await?;
        }

        let alter_add_column_target_table_sql = build_alter_add_column_sql(
            &self.snowflake_task_context.target_table_name,
            &self.snowflake_task_context.database,
            &self.snowflake_task_context.schema_name,
            columns,
        );
        self.jdbc_client
            .execute_sql_sync(vec![alter_add_column_target_table_sql])
            .await?;

        self.execute_create_merge_into_task().await?;
        Ok(())
    }

    pub async fn execute_create_merge_into_task(&self) -> Result<()> {
        if self.snowflake_task_context.task_name.is_some() {
            let create_task_sql = build_create_merge_into_task_sql(&self.snowflake_task_context);
            let start_task_sql = build_start_task_sql(&self.snowflake_task_context);
            self.jdbc_client
                .execute_sql_sync(vec![create_task_sql])
                .await?;
            self.jdbc_client
                .execute_sql_sync(vec![start_task_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_drop_task(&self) -> Result<()> {
        if self.snowflake_task_context.task_name.is_some() {
            let sql = build_drop_task_sql(&self.snowflake_task_context);
            if let Err(e) = self.jdbc_client.execute_sql_sync(vec![sql]).await {
                tracing::error!(
                    "Failed to drop Snowflake sink task {:?}: {:?}",
                    self.snowflake_task_context.task_name,
                    e.as_report()
                );
            } else {
                tracing::info!(
                    "Snowflake sink task {:?} dropped",
                    self.snowflake_task_context.task_name
                );
            }
        }
        Ok(())
    }

    pub async fn execute_create_table(&self) -> Result<()> {
        let create_target_table_sql = build_create_table_sql(
            &self.snowflake_task_context.target_table_name,
            &self.snowflake_task_context.database,
            &self.snowflake_task_context.schema_name,
            &self.snowflake_task_context.schema,
            false,
        )?;
        self.jdbc_client
            .execute_sql_sync(vec![create_target_table_sql])
            .await?;
        if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
            let create_cdc_table_sql = build_create_table_sql(
                cdc_table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                &self.snowflake_task_context.schema,
                true,
            )?;
            self.jdbc_client
                .execute_sql_sync(vec![create_cdc_table_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_create_pipe(&self) -> Result<()> {
        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
            let table_name =
                if let Some(table_name) = self.snowflake_task_context.cdc_table_name.as_ref() {
                    table_name
                } else {
                    &self.snowflake_task_context.target_table_name
                };
            let create_pipe_sql = build_create_pipe_sql(
                table_name,
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                self.snowflake_task_context.stage.as_ref().ok_or_else(|| {
                    SinkError::Config(anyhow!("`stage` is required for the S3 writer"))
                })?,
                pipe_name,
                &self.snowflake_task_context.target_table_name,
            );
            self.jdbc_client
                .execute_sql_sync(vec![create_pipe_sql])
                .await?;
        }
        Ok(())
    }

    pub async fn execute_flush_pipe(&self) -> Result<()> {
        if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
            let flush_pipe_sql = build_flush_pipe_sql(
                &self.snowflake_task_context.database,
                &self.snowflake_task_context.schema_name,
                pipe_name,
            );
            self.jdbc_client
                .execute_sql_sync(vec![flush_pipe_sql])
                .await?;
        }
        Ok(())
    }
}

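/// Renders `CREATE TABLE IF NOT EXISTS` for the target (or intermediate) table, mapping
/// RisingWave types to Snowflake types and appending the `__row_id`/`__op` bookkeeping
/// columns when the table backs an upsert sink.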
fn build_create_table_sql(
    table_name: &str,
    database: &str,
    schema_name: &str,
    schema: &Schema,
    need_op_and_row_id: bool,
) -> Result<String> {
    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema_name, table_name);
    let mut columns: Vec<String> = schema
        .fields
        .iter()
        .map(|field| {
            let data_type = convert_snowflake_data_type(&field.data_type)?;
            Ok(format!(r#""{}" {}"#, field.name, data_type))
        })
        .collect::<Result<Vec<String>>>()?;
    if need_op_and_row_id {
        columns.push(format!(r#""{}" STRING"#, SNOWFLAKE_SINK_ROW_ID));
        columns.push(format!(r#""{}" INT"#, SNOWFLAKE_SINK_OP));
    }
    let columns_str = columns.join(", ");
    Ok(format!(
        "CREATE TABLE IF NOT EXISTS {} ({}) ENABLE_SCHEMA_EVOLUTION = true",
        full_table_name, columns_str
    ))
}

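/// Maps a RisingWave [`DataType`] to the Snowflake column type used in the generated DDL;
/// types without a listed mapping (e.g. structs and lists) are rejected rather than
/// silently widened.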
fn convert_snowflake_data_type(data_type: &DataType) -> Result<String> {
    let data_type = match data_type {
        DataType::Int16 => "SMALLINT".to_owned(),
        DataType::Int32 => "INTEGER".to_owned(),
        DataType::Int64 => "BIGINT".to_owned(),
        DataType::Float32 => "FLOAT4".to_owned(),
        DataType::Float64 => "FLOAT8".to_owned(),
        DataType::Boolean => "BOOLEAN".to_owned(),
        DataType::Varchar => "STRING".to_owned(),
        DataType::Date => "DATE".to_owned(),
        DataType::Timestamp => "TIMESTAMP".to_owned(),
        DataType::Timestamptz => "TIMESTAMP_TZ".to_owned(),
        DataType::Jsonb => "STRING".to_owned(),
        DataType::Decimal => "DECIMAL".to_owned(),
        DataType::Bytea => "BINARY".to_owned(),
        DataType::Time => "TIME".to_owned(),
        _ => {
            return Err(SinkError::Config(anyhow!(
                "Auto table creation is not supported for data type: {}",
                data_type
            )));
        }
    };
    Ok(data_type)
}

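/// Renders the `CREATE OR REPLACE PIPE ... AS COPY INTO ...` statement that loads JSON
/// files from the stage sub-path named after the target table into the landing table.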
fn build_create_pipe_sql(
    table_name: &str,
    database: &str,
    schema: &str,
    stage: &str,
    pipe_name: &str,
    target_table_name: &str,
) -> String {
    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    let stage = format!(
        r#""{}"."{}"."{}"/{}"#,
        database, schema, stage, target_table_name
    );
    let table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
    format!(
        "CREATE OR REPLACE PIPE {} AUTO_INGEST = FALSE AS COPY INTO {} FROM @{} MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE FILE_FORMAT = (type = 'JSON');",
        pipe_name, table_name, stage
    )
}

fn build_flush_pipe_sql(database: &str, schema: &str, pipe_name: &str) -> String {
    let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    format!("ALTER PIPE {} REFRESH;", pipe_name)
}

fn build_alter_add_column_sql(
    table_name: &str,
    database: &str,
    schema: &str,
    columns: &Vec<(String, String)>,
) -> String {
    let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
    jdbc_jni_client::build_alter_add_column_sql(&full_table_name, columns, true)
}

fn build_start_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        database,
        schema_name: schema,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema,
        task_name.as_ref().unwrap()
    );
    format!("ALTER TASK {} RESUME", full_task_name)
}

fn build_drop_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        database,
        schema_name: schema,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema,
        task_name.as_ref().unwrap()
    );
    format!("DROP TASK IF EXISTS {}", full_task_name)
}

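/// Renders the scheduled Snowflake TASK that deduplicates rows in the intermediate (CDC)
/// table by primary key (keeping the latest `__row_id`), merges them into the target table
/// according to `__op`, and finally deletes the consumed rows.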
fn build_create_merge_into_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
    let SnowflakeTaskContext {
        task_name,
        cdc_table_name,
        target_table_name,
        schedule_seconds,
        warehouse,
        pk_column_names,
        all_column_names,
        database,
        schema_name,
        ..
    } = snowflake_task_context;
    let full_task_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema_name,
        task_name.as_ref().unwrap()
    );
    let full_cdc_table_name = format!(
        r#""{}"."{}"."{}""#,
        database,
        schema_name,
        cdc_table_name.as_ref().unwrap()
    );
    let full_target_table_name = format!(
        r#""{}"."{}"."{}""#,
        database, schema_name, target_table_name
    );

    let pk_names_str = pk_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#""{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");
    let pk_names_eq_str = pk_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
        .collect::<Vec<String>>()
        .join(" AND ");
    let all_column_names_set_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
        .collect::<Vec<String>>()
        .join(", ");
    let all_column_names_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#""{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");
    let all_column_names_insert_str = all_column_names
        .as_ref()
        .unwrap()
        .iter()
        .map(|name| format!(r#"source."{}""#, name))
        .collect::<Vec<String>>()
        .join(", ");

    format!(
        r#"CREATE OR REPLACE TASK {task_name}
WAREHOUSE = {warehouse}
SCHEDULE = '{schedule_seconds} SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("{snowflake_sink_row_id}"), '0') INTO :max_row_id
    FROM {cdc_table_name};

    MERGE INTO {target_table_name} AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY {pk_names_str} ORDER BY "{snowflake_sink_row_id}" DESC) AS dedupe_id
            FROM {cdc_table_name}
            WHERE "{snowflake_sink_row_id}" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON {pk_names_eq_str}
    WHEN MATCHED AND source."{snowflake_sink_op}" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN UPDATE SET {all_column_names_set_str}
    WHEN NOT MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN INSERT ({all_column_names_str}) VALUES ({all_column_names_insert_str});

    DELETE FROM {cdc_table_name}
    WHERE "{snowflake_sink_row_id}" <= :max_row_id;
END;"#,
        task_name = full_task_name,
        warehouse = warehouse.as_ref().unwrap(),
        schedule_seconds = schedule_seconds,
        cdc_table_name = full_cdc_table_name,
        target_table_name = full_target_table_name,
        pk_names_str = pk_names_str,
        pk_names_eq_str = pk_names_eq_str,
        all_column_names_set_str = all_column_names_set_str,
        all_column_names_str = all_column_names_str,
        all_column_names_insert_str = all_column_names_insert_str,
        snowflake_sink_row_id = SNOWFLAKE_SINK_ROW_ID,
        snowflake_sink_op = SNOWFLAKE_SINK_OP,
    )
}

#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::*;
    use crate::sink::jdbc_jni_client::normalize_sql;

    fn base_properties() -> BTreeMap<String, String> {
        BTreeMap::from([
            ("type".to_owned(), "append-only".to_owned()),
            ("jdbc.url".to_owned(), "jdbc:snowflake://account".to_owned()),
            ("username".to_owned(), "RW_USER".to_owned()),
        ])
    }

    #[test]
    fn test_build_jdbc_props_password() {
        let mut props = base_properties();
        props.insert("password".to_owned(), "secret".to_owned());
        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
        assert_eq!(url, "jdbc:snowflake://account");
        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
        assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
        assert_eq!(map.get("password"), Some(&"secret".to_owned()));
        assert!(!map.contains_key("authenticator"));
    }

    #[test]
    fn test_build_jdbc_props_key_pair_file() {
        let mut props = base_properties();
        props.insert(
            "auth.method".to_owned(),
            AUTH_METHOD_KEY_PAIR_FILE.to_owned(),
        );
        props.insert("private_key_file".to_owned(), "/tmp/rsa_key.p8".to_owned());
        props.insert("private_key_file_pwd".to_owned(), "dummy".to_owned());
        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
        assert_eq!(url, "jdbc:snowflake://account");
        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
        assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
        assert_eq!(
            map.get("private_key_file"),
            Some(&"/tmp/rsa_key.p8".to_owned())
        );
        assert_eq!(map.get("private_key_file_pwd"), Some(&"dummy".to_owned()));
    }

    #[test]
    fn test_build_jdbc_props_key_pair_object() {
        let mut props = base_properties();
        props.insert(
            "auth.method".to_owned(),
            AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
        );
        props.insert(
            "private_key_pem".to_owned(),
            "-----BEGIN PRIVATE KEY-----
...
-----END PRIVATE KEY-----"
                .to_owned(),
        );
        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
        assert_eq!(url, "jdbc:snowflake://account");
        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
        assert_eq!(
            map.get("private_key_pem"),
            Some(
                &"-----BEGIN PRIVATE KEY-----
...
-----END PRIVATE KEY-----"
                    .to_owned()
            )
        );
        assert!(!map.contains_key("private_key_file"));
    }

    #[test]
    fn test_snowflake_sink_commit_coordinator() {
        let snowflake_task_context = SnowflakeTaskContext {
            task_name: Some("test_task".to_owned()),
            cdc_table_name: Some("test_cdc_table".to_owned()),
            target_table_name: "test_target_table".to_owned(),
            schedule_seconds: 3600,
            warehouse: Some("test_warehouse".to_owned()),
            pk_column_names: Some(vec!["v1".to_owned()]),
            all_column_names: Some(vec!["v1".to_owned(), "v2".to_owned()]),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            schema: Schema { fields: vec![] },
            stage: None,
            pipe_name: None,
        };
        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task"
WAREHOUSE = test_warehouse
SCHEDULE = '3600 SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
    FROM "test_db"."test_schema"."test_cdc_table";

    MERGE INTO "test_db"."test_schema"."test_target_table" AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY "v1" ORDER BY "__row_id" DESC) AS dedupe_id
            FROM "test_db"."test_schema"."test_cdc_table"
            WHERE "__row_id" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON target."v1" = source."v1"
    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."v1" = source."v1", target."v2" = source."v2"
    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("v1", "v2") VALUES (source."v1", source."v2");

    DELETE FROM "test_db"."test_schema"."test_cdc_table"
    WHERE "__row_id" <= :max_row_id;
END;"#;
        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
    }

    #[test]
    fn test_snowflake_sink_commit_coordinator_multi_pk() {
        let snowflake_task_context = SnowflakeTaskContext {
            task_name: Some("test_task_multi_pk".to_owned()),
            cdc_table_name: Some("cdc_multi_pk".to_owned()),
            target_table_name: "target_multi_pk".to_owned(),
            schedule_seconds: 300,
            warehouse: Some("multi_pk_warehouse".to_owned()),
            pk_column_names: Some(vec!["id1".to_owned(), "id2".to_owned()]),
            all_column_names: Some(vec!["id1".to_owned(), "id2".to_owned(), "val".to_owned()]),
            database: "test_db".to_owned(),
            schema_name: "test_schema".to_owned(),
            schema: Schema { fields: vec![] },
            stage: None,
            pipe_name: None,
        };
        let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
        let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task_multi_pk"
WAREHOUSE = multi_pk_warehouse
SCHEDULE = '300 SECONDS'
AS
BEGIN
    LET max_row_id STRING;

    SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
    FROM "test_db"."test_schema"."cdc_multi_pk";

    MERGE INTO "test_db"."test_schema"."target_multi_pk" AS target
    USING (
        SELECT *
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY "id1", "id2" ORDER BY "__row_id" DESC) AS dedupe_id
            FROM "test_db"."test_schema"."cdc_multi_pk"
            WHERE "__row_id" <= :max_row_id
        ) AS subquery
        WHERE dedupe_id = 1
    ) AS source
    ON target."id1" = source."id1" AND target."id2" = source."id2"
    WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
    WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."id1" = source."id1", target."id2" = source."id2", target."val" = source."val"
    WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("id1", "id2", "val") VALUES (source."id1", source."id2", source."val");

    DELETE FROM "test_db"."test_schema"."cdc_multi_pk"
    WHERE "__row_id" <= :max_row_id;
END;"#;
        assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
    }
}