1use core::num::NonZeroU64;
16use std::collections::BTreeMap;
17
18use anyhow::anyhow;
19use phf::{Set, phf_set};
20use risingwave_common::array::StreamChunk;
21use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema};
22use risingwave_common::types::DataType;
23use risingwave_pb::connector_service::{SinkMetadata, sink_metadata};
24use serde::Deserialize;
25use serde_with::{DisplayFromStr, serde_as};
26use thiserror_ext::AsReport;
27use tokio::sync::mpsc::UnboundedSender;
28use tonic::async_trait;
29use with_options::WithOptions;
30
31use crate::connector_common::IcebergSinkCompactionUpdate;
32use crate::enforce_secret::EnforceSecret;
33use crate::sink::coordinate::CoordinatedLogSinker;
34use crate::sink::decouple_checkpoint_log_sink::default_commit_checkpoint_interval;
35use crate::sink::file_sink::s3::S3Common;
36use crate::sink::jdbc_jni_client::{self, JdbcJniClient};
37use crate::sink::remote::CoordinatedRemoteSinkWriter;
38use crate::sink::snowflake_redshift::{AugmentedChunk, SnowflakeRedshiftSinkS3Writer};
39use crate::sink::writer::SinkWriter;
40use crate::sink::{
41 Result, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT,
42 SinglePhaseCommitCoordinator, Sink, SinkCommitCoordinator, SinkError, SinkParam,
43 SinkWriterMetrics, SinkWriterParam,
44};
45
46pub const SNOWFLAKE_SINK_V2: &str = "snowflake_v2";
47pub const SNOWFLAKE_SINK_ROW_ID: &str = "__row_id";
48pub const SNOWFLAKE_SINK_OP: &str = "__op";
49
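// Accepted values for the `auth.method` option; `PROP_AUTH_METHOD` is the
// connection-property key under which the chosen method is forwarded.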
50const AUTH_METHOD_PASSWORD: &str = "password";
51const AUTH_METHOD_KEY_PAIR_FILE: &str = "key_pair_file";
52const AUTH_METHOD_KEY_PAIR_OBJECT: &str = "key_pair_object";
53const PROP_AUTH_METHOD: &str = "auth.method";
54
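/// Configuration of the `snowflake_v2` sink, deserialized from the `WITH (...)`
/// options of `CREATE SINK`.
///
/// An illustrative upsert example (option values are placeholders, not an
/// authoritative or complete set; when `with_s3` is enabled, which is the default,
/// the flattened S3 options from [`S3Common`] are required as well):
///
/// ```sql
/// CREATE SINK my_snowflake_sink FROM my_mv WITH (
///     connector = 'snowflake_v2',
///     type = 'upsert',
///     jdbc.url = 'jdbc:snowflake://<account>.snowflakecomputing.com',
///     username = 'RW_USER',
///     password = '...',
///     database = 'MY_DB',
///     schema = 'MY_SCHEMA',
///     table.name = 'TARGET_TABLE',
///     intermediate.table.name = 'TARGET_TABLE_CDC',
///     warehouse = 'MY_WAREHOUSE',
///     stage = 'MY_STAGE'
/// );
/// ```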
55#[serde_as]
56#[derive(Debug, Clone, Deserialize, WithOptions)]
57pub struct SnowflakeV2Config {
58 #[serde(rename = "type")]
59 pub r#type: String,
60
61 #[serde(rename = "intermediate.table.name")]
62 pub snowflake_cdc_table_name: Option<String>,
63
64 #[serde(rename = "table.name")]
65 pub snowflake_target_table_name: Option<String>,
66
67 #[serde(rename = "database")]
68 pub snowflake_database: Option<String>,
69
70 #[serde(rename = "schema")]
71 pub snowflake_schema: Option<String>,
72
73 #[serde(default = "default_schedule")]
74 #[serde(rename = "write.target.interval.seconds")]
75 #[serde_as(as = "DisplayFromStr")]
76 pub snowflake_schedule_seconds: u64,
77
78 #[serde(rename = "warehouse")]
79 pub snowflake_warehouse: Option<String>,
80
81 #[serde(rename = "jdbc.url")]
82 pub jdbc_url: Option<String>,
83
84 #[serde(rename = "username")]
85 pub username: Option<String>,
86
87 #[serde(rename = "password")]
88 pub password: Option<String>,
89
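    /// Authentication method: `password`, `key_pair_file`, or `key_pair_object`.
    /// If omitted, it is inferred from which credential option is provided.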
90 #[serde(rename = "auth.method")]
92 pub auth_method: Option<String>,
93
94 #[serde(rename = "private_key_file")]
96 pub private_key_file: Option<String>,
97
98 #[serde(rename = "private_key_file_pwd")]
99 pub private_key_file_pwd: Option<String>,
100
101 #[serde(rename = "private_key_pem")]
103 pub private_key_pem: Option<String>,
104
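    /// Commit to Snowflake every `n` (> 0) checkpoints; defaults to
    /// `default_commit_checkpoint_interval()`. Can be altered on the fly.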
105 #[serde(default = "default_commit_checkpoint_interval")]
107 #[serde_as(as = "DisplayFromStr")]
108 #[with_option(allow_alter_on_fly)]
109 pub commit_checkpoint_interval: u64,
110
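    /// When enabled, columns added upstream are automatically added to the
    /// Snowflake tables during commit (see `execute_alter_add_columns`).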
111 #[serde(default)]
114 #[serde(rename = "auto.schema.change")]
115 #[serde_as(as = "DisplayFromStr")]
116 pub auto_schema_change: bool,
117
118 #[serde(default)]
119 #[serde(rename = "create_table_if_not_exists")]
120 #[serde_as(as = "DisplayFromStr")]
121 pub create_table_if_not_exists: bool,
122
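    /// Stage data through S3 and a Snowpipe (the default) instead of inserting
    /// rows over JDBC.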
123 #[serde(default = "default_with_s3")]
124 #[serde(rename = "with_s3")]
125 #[serde_as(as = "DisplayFromStr")]
126 pub with_s3: bool,
127
128 #[serde(flatten)]
129 pub s3_inner: Option<S3Common>,
130
131 #[serde(rename = "stage")]
132 pub stage: Option<String>,
133}
134
fn default_schedule() -> u64 {
    3600
}
138
139fn default_with_s3() -> bool {
140 true
141}
142
143impl SnowflakeV2Config {
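    /// Builds the JDBC URL and driver connection properties from the validated
    /// config. Expects `auth_method` to have been normalized by
    /// [`Self::from_btreemap`]; under that assumption the `unwrap()`s below cannot
    /// panic, since the matching credential is guaranteed to be present.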
144 pub fn build_jdbc_connection_properties(&self) -> Result<(String, Vec<(String, String)>)> {
150 let jdbc_url = self
151 .jdbc_url
152 .clone()
153 .ok_or(SinkError::Config(anyhow!("jdbc.url is required")))?;
154 let username = self
155 .username
156 .clone()
157 .ok_or(SinkError::Config(anyhow!("username is required")))?;
158
159 let mut connection_properties: Vec<(String, String)> = vec![("user".to_owned(), username)];
160
161 match self.auth_method.as_deref().unwrap() {
163 AUTH_METHOD_PASSWORD => {
164 connection_properties.push(("password".to_owned(), self.password.clone().unwrap()));
166 }
167 AUTH_METHOD_KEY_PAIR_FILE => {
168 connection_properties.push((
170 "private_key_file".to_owned(),
171 self.private_key_file.clone().unwrap(),
172 ));
173 if let Some(pwd) = self.private_key_file_pwd.clone() {
174 connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
175 }
176 }
177 AUTH_METHOD_KEY_PAIR_OBJECT => {
178 connection_properties.push((
179 PROP_AUTH_METHOD.to_owned(),
180 AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
181 ));
182 connection_properties.push((
184 "private_key_pem".to_owned(),
185 self.private_key_pem.clone().unwrap(),
186 ));
187 if let Some(pwd) = self.private_key_file_pwd.clone() {
188 connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
189 }
190 }
191 _ => {
192 unreachable!(
194 "Invalid auth_method - should have been caught during config validation"
195 )
196 }
197 }
198
199 Ok((jdbc_url, connection_properties))
200 }
201
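    /// Parses the sink options, validates the sink `type`, and checks that exactly
    /// one authentication scheme is configured, normalizing `auth.method` so that
    /// later code can rely on it being set.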
202 pub fn from_btreemap(properties: &BTreeMap<String, String>) -> Result<Self> {
203 let mut config =
204 serde_json::from_value::<SnowflakeV2Config>(serde_json::to_value(properties).unwrap())
205 .map_err(|e| SinkError::Config(anyhow!(e)))?;
206 if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
207 return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
209 SINK_TYPE_OPTION,
210 SINK_TYPE_APPEND_ONLY,
211 SINK_TYPE_UPSERT
212 )));
213 }

        if config.commit_checkpoint_interval == 0 {
            return Err(SinkError::Config(anyhow!(
                "`commit_checkpoint_interval` must be greater than 0"
            )));
        }

215 let has_password = config.password.is_some();
217 let has_file = config.private_key_file.is_some();
        let has_pem = config.private_key_pem.is_some();
219
220 let normalized_auth_method = match config
221 .auth_method
222 .as_deref()
223 .map(|s| s.trim().to_ascii_lowercase())
224 {
225 Some(method) if method == AUTH_METHOD_PASSWORD => {
226 if !has_password {
227 return Err(SinkError::Config(anyhow!(
228 "auth.method=password requires `password`"
229 )));
230 }
231 if has_file || has_pem {
232 return Err(SinkError::Config(anyhow!(
233 "auth.method=password must not set `private_key_file`/`private_key_pem`"
234 )));
235 }
236 AUTH_METHOD_PASSWORD.to_owned()
237 }
238 Some(method) if method == AUTH_METHOD_KEY_PAIR_FILE => {
239 if !has_file {
240 return Err(SinkError::Config(anyhow!(
241 "auth.method=key_pair_file requires `private_key_file`"
242 )));
243 }
244 if has_password {
245 return Err(SinkError::Config(anyhow!(
246 "auth.method=key_pair_file must not set `password`"
247 )));
248 }
249 if has_pem {
250 return Err(SinkError::Config(anyhow!(
251 "auth.method=key_pair_file must not set `private_key_pem`"
252 )));
253 }
254 AUTH_METHOD_KEY_PAIR_FILE.to_owned()
255 }
256 Some(method) if method == AUTH_METHOD_KEY_PAIR_OBJECT => {
257 if !has_pem {
258 return Err(SinkError::Config(anyhow!(
259 "auth.method=key_pair_object requires `private_key_pem`"
260 )));
261 }
262 if has_password {
263 return Err(SinkError::Config(anyhow!(
264 "auth.method=key_pair_object must not set `password`"
265 )));
266 }
267 AUTH_METHOD_KEY_PAIR_OBJECT.to_owned()
268 }
269 Some(other) => {
270 return Err(SinkError::Config(anyhow!(
271 "invalid auth.method: {} (allowed: password | key_pair_file | key_pair_object)",
272 other
273 )));
274 }
275 None => {
276 match (has_password, has_file, has_pem) {
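                    // No explicit `auth.method`: infer it from whichever single
                    // credential kind is provided.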
278 (true, false, false) => AUTH_METHOD_PASSWORD.to_owned(),
279 (false, true, false) => AUTH_METHOD_KEY_PAIR_FILE.to_owned(),
280 (false, false, true) => AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
281 (true, true, _) | (true, _, true) | (false, true, true) => {
282 return Err(SinkError::Config(anyhow!(
283 "ambiguous auth: multiple auth options provided; remove one or set `auth.method`"
284 )));
285 }
286 _ => {
287 return Err(SinkError::Config(anyhow!(
288 "no authentication configured: set either `password`, or `private_key_file`, or `private_key_pem` (or provide `auth.method`)"
289 )));
290 }
291 }
292 }
293 };
294 config.auth_method = Some(normalized_auth_method);
295 Ok(config)
296 }
297
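    /// Builds the [`SnowflakeTaskContext`] plus a JDBC client for `validate()` and
    /// the commit coordinator. Returns `Ok(None)` when nothing has to be managed on
    /// the Snowflake side (append-only sink without table creation or automatic
    /// schema change).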
298 pub fn build_snowflake_task_ctx_jdbc_client(
299 &self,
300 is_append_only: bool,
301 schema: &Schema,
302 pk_indices: &Vec<usize>,
303 ) -> Result<Option<(SnowflakeTaskContext, JdbcJniClient)>> {
304 if !self.auto_schema_change && is_append_only && !self.create_table_if_not_exists {
305 return Ok(None);
307 }
308 let target_table_name = self
309 .snowflake_target_table_name
310 .clone()
311 .ok_or(SinkError::Config(anyhow!("table.name is required")))?;
312 let database = self
313 .snowflake_database
314 .clone()
315 .ok_or(SinkError::Config(anyhow!("database is required")))?;
316 let schema_name = self
317 .snowflake_schema
318 .clone()
319 .ok_or(SinkError::Config(anyhow!("schema is required")))?;
320 let mut snowflake_task_ctx = SnowflakeTaskContext {
321 target_table_name: target_table_name.clone(),
322 database,
323 schema_name,
324 schema: schema.clone(),
325 ..Default::default()
326 };
327
328 let (jdbc_url, connection_properties) = self.build_jdbc_connection_properties()?;
329 let client = JdbcJniClient::new_with_props(jdbc_url, connection_properties)?;
330
331 if self.with_s3 {
332 let stage = self
333 .stage
334 .clone()
335 .ok_or(SinkError::Config(anyhow!("stage is required")))?;
336 snowflake_task_ctx.stage = Some(stage);
337 snowflake_task_ctx.pipe_name = Some(format!("{}_pipe", target_table_name));
338 }
339 if !is_append_only {
340 let cdc_table_name = self
341 .snowflake_cdc_table_name
342 .clone()
343 .ok_or(SinkError::Config(anyhow!(
344 "intermediate.table.name is required"
345 )))?;
346 snowflake_task_ctx.cdc_table_name = Some(cdc_table_name.clone());
347 snowflake_task_ctx.schedule_seconds = self.snowflake_schedule_seconds;
348 snowflake_task_ctx.warehouse = Some(
349 self.snowflake_warehouse
350 .clone()
351 .ok_or(SinkError::Config(anyhow!("warehouse is required")))?,
352 );
353 let pk_column_names: Vec<_> = schema
354 .fields
355 .iter()
356 .enumerate()
357 .filter(|(index, _)| pk_indices.contains(index))
358 .map(|(_, field)| field.name.clone())
359 .collect();
360 if pk_column_names.is_empty() {
361 return Err(SinkError::Config(anyhow!(
362 "Primary key columns not found. Please set the `primary_key` column in the sink properties, or ensure that the sink contains the primary key columns from the upstream."
363 )));
364 }
365 snowflake_task_ctx.pk_column_names = Some(pk_column_names);
366 snowflake_task_ctx.all_column_names = Some(
367 schema
368 .fields
369 .iter()
370 .map(|field| field.name.clone())
371 .collect(),
372 );
373 snowflake_task_ctx.task_name = Some(format!(
374 "rw_snowflake_sink_from_{cdc_table_name}_to_{target_table_name}"
375 ));
376 }
377 Ok(Some((snowflake_task_ctx, client)))
378 }
379}
380
381impl EnforceSecret for SnowflakeV2Config {
382 const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
383 "username",
384 "password",
385 "jdbc.url",
386 "private_key_file_pwd",
388 "private_key_pem",
389 };
390}
391
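/// The `snowflake_v2` sink. Append-only streams are written straight to the target
/// table, while upsert streams are staged in an intermediate CDC table and merged
/// into the target table by a scheduled Snowflake task.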
392#[derive(Clone, Debug)]
393pub struct SnowflakeV2Sink {
394 config: SnowflakeV2Config,
395 schema: Schema,
396 pk_indices: Vec<usize>,
397 is_append_only: bool,
398 param: SinkParam,
399}
400
401impl EnforceSecret for SnowflakeV2Sink {
402 fn enforce_secret<'a>(
403 prop_iter: impl Iterator<Item = &'a str>,
404 ) -> crate::sink::ConnectorResult<()> {
405 for prop in prop_iter {
406 SnowflakeV2Config::enforce_one(prop)?;
407 }
408 Ok(())
409 }
410}
411
412impl TryFrom<SinkParam> for SnowflakeV2Sink {
413 type Error = SinkError;
414
415 fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
416 let schema = param.schema();
417 let config = SnowflakeV2Config::from_btreemap(¶m.properties)?;
418 let is_append_only = param.sink_type.is_append_only();
419 let pk_indices = param.downstream_pk_or_empty();
420 Ok(Self {
421 config,
422 schema,
423 pk_indices,
424 is_append_only,
425 param,
426 })
427 }
428}
429
430impl Sink for SnowflakeV2Sink {
431 type LogSinker = CoordinatedLogSinker<SnowflakeSinkWriter>;
432
433 const SINK_NAME: &'static str = SNOWFLAKE_SINK_V2;
434
435 async fn validate(&self) -> Result<()> {
436 risingwave_common::license::Feature::SnowflakeSink
437 .check_available()
438 .map_err(|e| anyhow::anyhow!(e))?;
439 if let Some((snowflake_task_ctx, client)) =
440 self.config.build_snowflake_task_ctx_jdbc_client(
441 self.is_append_only,
442 &self.schema,
443 &self.pk_indices,
444 )?
445 {
446 let client = SnowflakeJniClient::new(client, snowflake_task_ctx);
447 client.execute_create_table().await?;
448 }
449
450 Ok(())
451 }
452
453 fn support_schema_change() -> bool {
454 true
455 }
456
457 fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
458 SnowflakeV2Config::from_btreemap(config)?;
459 Ok(())
460 }
461
462 async fn new_log_sinker(
463 &self,
464 writer_param: crate::sink::SinkWriterParam,
465 ) -> Result<Self::LogSinker> {
466 let writer = SnowflakeSinkWriter::new(
467 self.config.clone(),
468 self.is_append_only,
469 writer_param.clone(),
470 self.param.clone(),
471 )
472 .await?;
473
474 let commit_checkpoint_interval =
475 NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
476 "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
477 );
478
479 CoordinatedLogSinker::new(
480 &writer_param,
481 self.param.clone(),
482 writer,
483 commit_checkpoint_interval,
484 )
485 .await
486 }
487
488 fn is_coordinated_sink(&self) -> bool {
489 true
490 }
491
492 async fn new_coordinator(
493 &self,
494 _iceberg_compact_stat_sender: Option<UnboundedSender<IcebergSinkCompactionUpdate>>,
495 ) -> Result<SinkCommitCoordinator> {
496 let coordinator = SnowflakeSinkCommitter::new(
497 self.config.clone(),
498 &self.schema,
499 &self.pk_indices,
500 self.is_append_only,
501 )?;
502 Ok(SinkCommitCoordinator::SinglePhase(Box::new(coordinator)))
503 }
504}
505
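/// Writer half of the sink: either uploads files to S3 (later loaded through the
/// Snowpipe created by the coordinator) or writes rows directly via the
/// coordinated remote JDBC writer.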
506pub enum SnowflakeSinkWriter {
507 S3(SnowflakeRedshiftSinkS3Writer),
508 Jdbc(SnowflakeSinkJdbcWriter),
509}
510
511impl SnowflakeSinkWriter {
512 pub async fn new(
513 config: SnowflakeV2Config,
514 is_append_only: bool,
515 writer_param: SinkWriterParam,
516 param: SinkParam,
517 ) -> Result<Self> {
518 let schema = param.schema();
519 if config.with_s3 {
520 let executor_id = writer_param.executor_id;
521 let s3_writer = SnowflakeRedshiftSinkS3Writer::new(
522 config.s3_inner.ok_or_else(|| {
523 SinkError::Config(anyhow!(
524 "S3 configuration is required for Snowflake S3 sink"
525 ))
526 })?,
527 schema,
528 is_append_only,
529 executor_id,
530 config.snowflake_target_table_name,
531 )?;
532 Ok(Self::S3(s3_writer))
533 } else {
534 let jdbc_writer =
535 SnowflakeSinkJdbcWriter::new(config, is_append_only, writer_param, param).await?;
536 Ok(Self::Jdbc(jdbc_writer))
537 }
538 }
539}
540
541#[async_trait]
542impl SinkWriter for SnowflakeSinkWriter {
543 type CommitMetadata = Option<SinkMetadata>;
544
545 async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
546 match self {
547 Self::S3(writer) => writer.begin_epoch(epoch),
548 Self::Jdbc(writer) => writer.begin_epoch(epoch).await,
549 }
550 }
551
552 async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
553 match self {
554 Self::S3(writer) => writer.write_batch(chunk).await,
555 Self::Jdbc(writer) => writer.write_batch(chunk).await,
556 }
557 }
558
559 async fn barrier(&mut self, is_checkpoint: bool) -> Result<Option<SinkMetadata>> {
560 match self {
561 Self::S3(writer) => {
562 writer.barrier(is_checkpoint).await?;
563 }
564 Self::Jdbc(writer) => {
565 writer.barrier(is_checkpoint).await?;
566 }
567 }
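        // The coordinator only needs the commit signal; no per-writer metadata is
        // carried, so an empty serialized payload is returned.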
568 Ok(Some(SinkMetadata {
569 metadata: Some(sink_metadata::Metadata::Serialized(
570 risingwave_pb::connector_service::sink_metadata::SerializedMetadata {
571 metadata: vec![],
572 },
573 )),
574 }))
575 }
576
577 async fn abort(&mut self) -> Result<()> {
578 if let Self::Jdbc(writer) = self {
579 writer.abort().await
580 } else {
581 Ok(())
582 }
583 }
584}
585
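/// JDBC-based writer: augments chunks with the `__row_id`/`__op` columns for
/// upsert sinks and forwards them to a remote JDBC sink writer pointed at the
/// target (or intermediate) table.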
586pub struct SnowflakeSinkJdbcWriter {
587 augmented_row: AugmentedChunk,
588 jdbc_sink_writer: CoordinatedRemoteSinkWriter,
589}
590
591impl SnowflakeSinkJdbcWriter {
592 pub async fn new(
593 config: SnowflakeV2Config,
594 is_append_only: bool,
595 writer_param: SinkWriterParam,
596 mut param: SinkParam,
597 ) -> Result<Self> {
598 let metrics = SinkWriterMetrics::new(&writer_param);
599 let properties = ¶m.properties;
600 let column_descs = &mut param.columns;
601 let full_table_name = if is_append_only {
602 format!(
603 r#""{}"."{}"."{}""#,
604 config.snowflake_database.clone().unwrap_or_default(),
605 config.snowflake_schema.clone().unwrap_or_default(),
606 config
607 .snowflake_target_table_name
608 .clone()
609 .unwrap_or_default()
610 )
611 } else {
612 let max_column_id = column_descs
613 .iter()
614 .map(|column| column.column_id.get_id())
615 .max()
616 .unwrap_or(0);
617 (*column_descs).push(ColumnDesc::named(
618 SNOWFLAKE_SINK_ROW_ID,
619 ColumnId::new(max_column_id + 1),
620 DataType::Varchar,
621 ));
622 (*column_descs).push(ColumnDesc::named(
623 SNOWFLAKE_SINK_OP,
624 ColumnId::new(max_column_id + 2),
625 DataType::Int32,
626 ));
627 format!(
628 r#""{}"."{}"."{}""#,
629 config.snowflake_database.clone().unwrap_or_default(),
630 config.snowflake_schema.clone().unwrap_or_default(),
631 config.snowflake_cdc_table_name.clone().unwrap_or_default()
632 )
633 };
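        // Rewrite the sink parameters for the embedded remote JDBC sink: rows are
        // always appended (upserts are resolved later by the merge task), and the
        // destination is the target table (append-only) or the CDC table (upsert).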
634 let mut new_properties = BTreeMap::from([
635 ("table.name".to_owned(), full_table_name),
636 ("connector".to_owned(), "snowflake_v2".to_owned()),
637 (
638 "jdbc.url".to_owned(),
639 config.jdbc_url.clone().unwrap_or_default(),
640 ),
641 ("type".to_owned(), "append-only".to_owned()),
642 (
643 "primary_key".to_owned(),
644 properties.get("primary_key").cloned().unwrap_or_default(),
645 ),
646 (
647 "schema.name".to_owned(),
648 config.snowflake_schema.clone().unwrap_or_default(),
649 ),
650 (
651 "database.name".to_owned(),
652 config.snowflake_database.clone().unwrap_or_default(),
653 ),
654 ]);
655
656 let (_jdbc_url, connection_properties) = config.build_jdbc_connection_properties()?;
658 for (key, value) in connection_properties {
659 new_properties.insert(key, value);
660 }
661
662 param.properties = new_properties;
663
664 let jdbc_sink_writer =
665 CoordinatedRemoteSinkWriter::new(param.clone(), metrics.clone()).await?;
666 Ok(Self {
667 augmented_row: AugmentedChunk::new(0, is_append_only),
668 jdbc_sink_writer,
669 })
670 }
671}
672
673impl SnowflakeSinkJdbcWriter {
674 async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
675 self.augmented_row.reset_epoch(epoch);
676 self.jdbc_sink_writer.begin_epoch(epoch).await?;
677 Ok(())
678 }
679
680 async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
681 let chunk = self.augmented_row.augmented_chunk(chunk)?;
682 self.jdbc_sink_writer.write_batch(chunk).await?;
683 Ok(())
684 }
685
686 async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
687 self.jdbc_sink_writer.barrier(is_checkpoint).await?;
688 Ok(())
689 }
690
691 async fn abort(&mut self) -> Result<()> {
692 self.jdbc_sink_writer.abort().await?;
694 Ok(())
695 }
696}
697
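/// Everything needed to manage the Snowflake-side objects (tables, pipe, and merge
/// task) for one sink. The task-related fields are only populated for upsert
/// sinks, and `stage`/`pipe_name` only when the sink writes through S3.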
698#[derive(Default)]
699pub struct SnowflakeTaskContext {
700 pub target_table_name: String,
702 pub database: String,
703 pub schema_name: String,
704 pub schema: Schema,
705
706 pub task_name: Option<String>,
708 pub cdc_table_name: Option<String>,
709 pub schedule_seconds: u64,
710 pub warehouse: Option<String>,
711 pub pk_column_names: Option<Vec<String>>,
712 pub all_column_names: Option<Vec<String>>,
713
714 pub stage: Option<String>,
716 pub pipe_name: Option<String>,
717}
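
/// Single-phase commit coordinator for the Snowflake sink: creates the pipe and
/// merge task on startup, refreshes the pipe on each commit, and applies upstream
/// schema changes to the Snowflake tables.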
718pub struct SnowflakeSinkCommitter {
719 client: Option<SnowflakeJniClient>,
720}
721
722impl SnowflakeSinkCommitter {
723 pub fn new(
724 config: SnowflakeV2Config,
725 schema: &Schema,
726 pk_indices: &Vec<usize>,
727 is_append_only: bool,
728 ) -> Result<Self> {
729 let client = if let Some((snowflake_task_ctx, client)) =
730 config.build_snowflake_task_ctx_jdbc_client(is_append_only, schema, pk_indices)?
731 {
732 Some(SnowflakeJniClient::new(client, snowflake_task_ctx))
733 } else {
734 None
735 };
736 Ok(Self { client })
737 }
738}
739
740#[async_trait]
741impl SinglePhaseCommitCoordinator for SnowflakeSinkCommitter {
742 async fn init(&mut self) -> Result<()> {
743 if let Some(client) = &self.client {
744 client.execute_create_pipe().await?;
746 client.execute_create_merge_into_task().await?;
747 }
748 Ok(())
749 }
750
751 async fn commit(
752 &mut self,
753 _epoch: u64,
754 _metadata: Vec<SinkMetadata>,
755 add_columns: Option<Vec<Field>>,
756 ) -> Result<()> {
757 let client = self.client.as_mut().ok_or_else(|| {
758 SinkError::Config(anyhow!("Snowflake sink committer is not initialized."))
759 })?;
760 client.execute_flush_pipe().await?;
761
762 if let Some(add_columns) = add_columns {
763 client
764 .execute_alter_add_columns(
765 &add_columns
766 .iter()
767 .map(|f| (f.name.clone(), f.data_type.to_string()))
768 .collect::<Vec<_>>(),
769 )
770 .await?;
771 }
772 Ok(())
773 }
774}
775
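// Best-effort cleanup: dropping the committer spawns a task that drops the
// Snowflake merge task so it does not keep running after the sink is removed.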
776impl Drop for SnowflakeSinkCommitter {
777 fn drop(&mut self) {
778 if let Some(client) = self.client.take() {
779 tokio::spawn(async move {
780 client.execute_drop_task().await.ok();
781 });
782 }
783 }
784}
785
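/// Wrapper around [`JdbcJniClient`] that executes the DDL/DML needed to manage the
/// Snowflake tables, pipe, and merge task described by [`SnowflakeTaskContext`].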
786pub struct SnowflakeJniClient {
787 jdbc_client: JdbcJniClient,
788 snowflake_task_context: SnowflakeTaskContext,
789}
790
791impl SnowflakeJniClient {
792 pub fn new(jdbc_client: JdbcJniClient, snowflake_task_context: SnowflakeTaskContext) -> Self {
793 Self {
794 jdbc_client,
795 snowflake_task_context,
796 }
797 }
798
799 pub async fn execute_alter_add_columns(
800 &mut self,
801 columns: &Vec<(String, String)>,
802 ) -> Result<()> {
803 self.execute_drop_task().await?;
804 if let Some(names) = self.snowflake_task_context.all_column_names.as_mut() {
805 names.extend(columns.iter().map(|(name, _)| name.clone()));
806 }
807 if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
808 let alter_add_column_cdc_table_sql = build_alter_add_column_sql(
809 cdc_table_name,
810 &self.snowflake_task_context.database,
811 &self.snowflake_task_context.schema_name,
812 columns,
813 );
814 self.jdbc_client
815 .execute_sql_sync(vec![alter_add_column_cdc_table_sql])
816 .await?;
817 }
818
819 let alter_add_column_target_table_sql = build_alter_add_column_sql(
820 &self.snowflake_task_context.target_table_name,
821 &self.snowflake_task_context.database,
822 &self.snowflake_task_context.schema_name,
823 columns,
824 );
825 self.jdbc_client
826 .execute_sql_sync(vec![alter_add_column_target_table_sql])
827 .await?;
828
829 self.execute_create_merge_into_task().await?;
830 Ok(())
831 }
832
833 pub async fn execute_create_merge_into_task(&self) -> Result<()> {
834 if self.snowflake_task_context.task_name.is_some() {
835 let create_task_sql = build_create_merge_into_task_sql(&self.snowflake_task_context);
836 let start_task_sql = build_start_task_sql(&self.snowflake_task_context);
837 self.jdbc_client
838 .execute_sql_sync(vec![create_task_sql])
839 .await?;
840 self.jdbc_client
841 .execute_sql_sync(vec![start_task_sql])
842 .await?;
843 }
844 Ok(())
845 }
846
847 pub async fn execute_drop_task(&self) -> Result<()> {
848 if self.snowflake_task_context.task_name.is_some() {
849 let sql = build_drop_task_sql(&self.snowflake_task_context);
850 if let Err(e) = self.jdbc_client.execute_sql_sync(vec![sql]).await {
851 tracing::error!(
852 "Failed to drop Snowflake sink task {:?}: {:?}",
853 self.snowflake_task_context.task_name,
854 e.as_report()
855 );
856 } else {
857 tracing::info!(
858 "Snowflake sink task {:?} dropped",
859 self.snowflake_task_context.task_name
860 );
861 }
862 }
863 Ok(())
864 }
865
866 pub async fn execute_create_table(&self) -> Result<()> {
867 let create_target_table_sql = build_create_table_sql(
869 &self.snowflake_task_context.target_table_name,
870 &self.snowflake_task_context.database,
871 &self.snowflake_task_context.schema_name,
872 &self.snowflake_task_context.schema,
873 false,
874 )?;
875 self.jdbc_client
876 .execute_sql_sync(vec![create_target_table_sql])
877 .await?;
878 if let Some(cdc_table_name) = &self.snowflake_task_context.cdc_table_name {
879 let create_cdc_table_sql = build_create_table_sql(
880 cdc_table_name,
881 &self.snowflake_task_context.database,
882 &self.snowflake_task_context.schema_name,
883 &self.snowflake_task_context.schema,
884 true,
885 )?;
886 self.jdbc_client
887 .execute_sql_sync(vec![create_cdc_table_sql])
888 .await?;
889 }
890 Ok(())
891 }
892
893 pub async fn execute_create_pipe(&self) -> Result<()> {
894 if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
895 let table_name =
896 if let Some(table_name) = self.snowflake_task_context.cdc_table_name.as_ref() {
897 table_name
898 } else {
899 &self.snowflake_task_context.target_table_name
900 };
901 let create_pipe_sql = build_create_pipe_sql(
902 table_name,
903 &self.snowflake_task_context.database,
904 &self.snowflake_task_context.schema_name,
905 self.snowflake_task_context.stage.as_ref().ok_or_else(|| {
                    SinkError::Config(anyhow!("`stage` is required for the S3 writer"))
907 })?,
908 pipe_name,
909 &self.snowflake_task_context.target_table_name,
910 );
911 self.jdbc_client
912 .execute_sql_sync(vec![create_pipe_sql])
913 .await?;
914 }
915 Ok(())
916 }
917
918 pub async fn execute_flush_pipe(&self) -> Result<()> {
919 if let Some(pipe_name) = &self.snowflake_task_context.pipe_name {
920 let flush_pipe_sql = build_flush_pipe_sql(
921 &self.snowflake_task_context.database,
922 &self.snowflake_task_context.schema_name,
923 pipe_name,
924 );
925 self.jdbc_client
926 .execute_sql_sync(vec![flush_pipe_sql])
927 .await?;
928 }
929 Ok(())
930 }
931}
932
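/// Builds a `CREATE TABLE IF NOT EXISTS` statement with `ENABLE_SCHEMA_EVOLUTION`;
/// when `need_op_and_row_id` is set (for the CDC table) the `__row_id` and `__op`
/// bookkeeping columns are appended.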
933fn build_create_table_sql(
934 table_name: &str,
935 database: &str,
936 schema_name: &str,
937 schema: &Schema,
938 need_op_and_row_id: bool,
939) -> Result<String> {
940 let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema_name, table_name);
941 let mut columns: Vec<String> = schema
942 .fields
943 .iter()
944 .map(|field| {
945 let data_type = convert_snowflake_data_type(&field.data_type)?;
946 Ok(format!(r#""{}" {}"#, field.name, data_type))
947 })
948 .collect::<Result<Vec<String>>>()?;
949 if need_op_and_row_id {
950 columns.push(format!(r#""{}" STRING"#, SNOWFLAKE_SINK_ROW_ID));
951 columns.push(format!(r#""{}" INT"#, SNOWFLAKE_SINK_OP));
952 }
953 let columns_str = columns.join(", ");
954 Ok(format!(
955 "CREATE TABLE IF NOT EXISTS {} ({}) ENABLE_SCHEMA_EVOLUTION = true",
956 full_table_name, columns_str
957 ))
958}
959
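/// Maps a RisingWave [`DataType`] to the Snowflake column type used for
/// auto-created tables; unsupported types return a configuration error.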
960fn convert_snowflake_data_type(data_type: &DataType) -> Result<String> {
961 let data_type = match data_type {
962 DataType::Int16 => "SMALLINT".to_owned(),
963 DataType::Int32 => "INTEGER".to_owned(),
964 DataType::Int64 => "BIGINT".to_owned(),
965 DataType::Float32 => "FLOAT4".to_owned(),
966 DataType::Float64 => "FLOAT8".to_owned(),
967 DataType::Boolean => "BOOLEAN".to_owned(),
968 DataType::Varchar => "STRING".to_owned(),
969 DataType::Date => "DATE".to_owned(),
970 DataType::Timestamp => "TIMESTAMP".to_owned(),
971 DataType::Timestamptz => "TIMESTAMP_TZ".to_owned(),
972 DataType::Jsonb => "STRING".to_owned(),
973 DataType::Decimal => "DECIMAL".to_owned(),
974 DataType::Bytea => "BINARY".to_owned(),
975 DataType::Time => "TIME".to_owned(),
976 _ => {
977 return Err(SinkError::Config(anyhow!(
                "automatic table creation is not supported for data type: {}",
979 data_type
980 )));
981 }
982 };
983 Ok(data_type)
984}
985
986fn build_create_pipe_sql(
987 table_name: &str,
988 database: &str,
989 schema: &str,
990 stage: &str,
991 pipe_name: &str,
992 target_table_name: &str,
993) -> String {
994 let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
995 let stage = format!(
996 r#""{}"."{}"."{}"/{}"#,
997 database, schema, stage, target_table_name
998 );
999 let table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
1000 format!(
1001 "CREATE OR REPLACE PIPE {} AUTO_INGEST = FALSE AS COPY INTO {} FROM @{} MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE FILE_FORMAT = (type = 'JSON');",
1002 pipe_name, table_name, stage
1003 )
1004}
1005
1006fn build_flush_pipe_sql(database: &str, schema: &str, pipe_name: &str) -> String {
1007 let pipe_name = format!(r#""{}"."{}"."{}""#, database, schema, pipe_name);
    format!("ALTER PIPE {} REFRESH;", pipe_name)
1009}
1010
1011fn build_alter_add_column_sql(
1012 table_name: &str,
1013 database: &str,
1014 schema: &str,
1015 columns: &Vec<(String, String)>,
1016) -> String {
1017 let full_table_name = format!(r#""{}"."{}"."{}""#, database, schema, table_name);
1018 jdbc_jni_client::build_alter_add_column_sql(&full_table_name, columns, true)
1019}
1020
1021fn build_start_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
1022 let SnowflakeTaskContext {
1023 task_name,
1024 database,
1025 schema_name: schema,
1026 ..
1027 } = snowflake_task_context;
1028 let full_task_name = format!(
1029 r#""{}"."{}"."{}""#,
1030 database,
1031 schema,
1032 task_name.as_ref().unwrap()
1033 );
1034 format!("ALTER TASK {} RESUME", full_task_name)
1035}
1036
1037fn build_drop_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
1038 let SnowflakeTaskContext {
1039 task_name,
1040 database,
1041 schema_name: schema,
1042 ..
1043 } = snowflake_task_context;
1044 let full_task_name = format!(
1045 r#""{}"."{}"."{}""#,
1046 database,
1047 schema,
1048 task_name.as_ref().unwrap()
1049 );
1050 format!("DROP TASK IF EXISTS {}", full_task_name)
1051}
1052
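/// Builds the `CREATE OR REPLACE TASK` statement that periodically merges the CDC
/// table into the target table: rows are deduplicated per primary key by the
/// largest `__row_id`, applied as deletes or upserts according to `__op`, and the
/// consumed CDC rows are deleted afterwards.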
1053fn build_create_merge_into_task_sql(snowflake_task_context: &SnowflakeTaskContext) -> String {
1054 let SnowflakeTaskContext {
1055 task_name,
1056 cdc_table_name,
1057 target_table_name,
1058 schedule_seconds,
1059 warehouse,
1060 pk_column_names,
1061 all_column_names,
1062 database,
1063 schema_name,
1064 ..
1065 } = snowflake_task_context;
1066 let full_task_name = format!(
1067 r#""{}"."{}"."{}""#,
1068 database,
1069 schema_name,
1070 task_name.as_ref().unwrap()
1071 );
1072 let full_cdc_table_name = format!(
1073 r#""{}"."{}"."{}""#,
1074 database,
1075 schema_name,
1076 cdc_table_name.as_ref().unwrap()
1077 );
1078 let full_target_table_name = format!(
1079 r#""{}"."{}"."{}""#,
1080 database, schema_name, target_table_name
1081 );
1082
1083 let pk_names_str = pk_column_names
1084 .as_ref()
1085 .unwrap()
1086 .iter()
1087 .map(|name| format!(r#""{}""#, name))
1088 .collect::<Vec<String>>()
1089 .join(", ");
1090 let pk_names_eq_str = pk_column_names
1091 .as_ref()
1092 .unwrap()
1093 .iter()
1094 .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
1095 .collect::<Vec<String>>()
1096 .join(" AND ");
1097 let all_column_names_set_str = all_column_names
1098 .as_ref()
1099 .unwrap()
1100 .iter()
1101 .map(|name| format!(r#"target."{}" = source."{}""#, name, name))
1102 .collect::<Vec<String>>()
1103 .join(", ");
1104 let all_column_names_str = all_column_names
1105 .as_ref()
1106 .unwrap()
1107 .iter()
1108 .map(|name| format!(r#""{}""#, name))
1109 .collect::<Vec<String>>()
1110 .join(", ");
1111 let all_column_names_insert_str = all_column_names
1112 .as_ref()
1113 .unwrap()
1114 .iter()
1115 .map(|name| format!(r#"source."{}""#, name))
1116 .collect::<Vec<String>>()
1117 .join(", ");
1118
1119 format!(
1120 r#"CREATE OR REPLACE TASK {task_name}
1121WAREHOUSE = {warehouse}
1122SCHEDULE = '{schedule_seconds} SECONDS'
1123AS
1124BEGIN
1125 LET max_row_id STRING;
1126
1127 SELECT COALESCE(MAX("{snowflake_sink_row_id}"), '0') INTO :max_row_id
1128 FROM {cdc_table_name};
1129
1130 MERGE INTO {target_table_name} AS target
1131 USING (
1132 SELECT *
1133 FROM (
1134 SELECT *, ROW_NUMBER() OVER (PARTITION BY {pk_names_str} ORDER BY "{snowflake_sink_row_id}" DESC) AS dedupe_id
1135 FROM {cdc_table_name}
1136 WHERE "{snowflake_sink_row_id}" <= :max_row_id
1137 ) AS subquery
1138 WHERE dedupe_id = 1
1139 ) AS source
1140 ON {pk_names_eq_str}
1141 WHEN MATCHED AND source."{snowflake_sink_op}" IN (2, 4) THEN DELETE
1142 WHEN MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN UPDATE SET {all_column_names_set_str}
1143 WHEN NOT MATCHED AND source."{snowflake_sink_op}" IN (1, 3) THEN INSERT ({all_column_names_str}) VALUES ({all_column_names_insert_str});
1144
1145 DELETE FROM {cdc_table_name}
1146 WHERE "{snowflake_sink_row_id}" <= :max_row_id;
1147END;"#,
1148 task_name = full_task_name,
1149 warehouse = warehouse.as_ref().unwrap(),
1150 schedule_seconds = schedule_seconds,
1151 cdc_table_name = full_cdc_table_name,
1152 target_table_name = full_target_table_name,
1153 pk_names_str = pk_names_str,
1154 pk_names_eq_str = pk_names_eq_str,
1155 all_column_names_set_str = all_column_names_set_str,
1156 all_column_names_str = all_column_names_str,
1157 all_column_names_insert_str = all_column_names_insert_str,
1158 snowflake_sink_row_id = SNOWFLAKE_SINK_ROW_ID,
1159 snowflake_sink_op = SNOWFLAKE_SINK_OP,
1160 )
1161}
1162
1163#[cfg(test)]
1164mod tests {
1165 use std::collections::BTreeMap;
1166
1167 use super::*;
1168 use crate::sink::jdbc_jni_client::normalize_sql;
1169
1170 fn base_properties() -> BTreeMap<String, String> {
1171 BTreeMap::from([
1172 ("type".to_owned(), "append-only".to_owned()),
1173 ("jdbc.url".to_owned(), "jdbc:snowflake://account".to_owned()),
1174 ("username".to_owned(), "RW_USER".to_owned()),
1175 ])
1176 }
1177
1178 #[test]
1179 fn test_build_jdbc_props_password() {
1180 let mut props = base_properties();
1181 props.insert("password".to_owned(), "secret".to_owned());
1182 let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
1183 let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
1184 assert_eq!(url, "jdbc:snowflake://account");
1185 let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
1186 assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
1187 assert_eq!(map.get("password"), Some(&"secret".to_owned()));
1188 assert!(!map.contains_key("authenticator"));
1189 }
1190
1191 #[test]
1192 fn test_build_jdbc_props_key_pair_file() {
1193 let mut props = base_properties();
1194 props.insert(
1195 "auth.method".to_owned(),
1196 AUTH_METHOD_KEY_PAIR_FILE.to_owned(),
1197 );
1198 props.insert("private_key_file".to_owned(), "/tmp/rsa_key.p8".to_owned());
1199 props.insert("private_key_file_pwd".to_owned(), "dummy".to_owned());
1200 let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
1201 let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
1202 assert_eq!(url, "jdbc:snowflake://account");
1203 let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
1204 assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
1205 assert_eq!(
1206 map.get("private_key_file"),
1207 Some(&"/tmp/rsa_key.p8".to_owned())
1208 );
1209 assert_eq!(map.get("private_key_file_pwd"), Some(&"dummy".to_owned()));
1210 }
1211
1212 #[test]
1213 fn test_build_jdbc_props_key_pair_object() {
1214 let mut props = base_properties();
1215 props.insert(
1216 "auth.method".to_owned(),
1217 AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
1218 );
1219 props.insert(
1220 "private_key_pem".to_owned(),
1221 "-----BEGIN PRIVATE KEY-----
1222...
1223-----END PRIVATE KEY-----"
1224 .to_owned(),
1225 );
1226 let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
1227 let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
1228 assert_eq!(url, "jdbc:snowflake://account");
1229 let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
1230 assert_eq!(
1231 map.get("private_key_pem"),
1232 Some(
1233 &"-----BEGIN PRIVATE KEY-----
1234...
1235-----END PRIVATE KEY-----"
1236 .to_owned()
1237 )
1238 );
1239 assert!(!map.contains_key("private_key_file"));
1240 }
1241
1242 #[test]
1243 fn test_snowflake_sink_commit_coordinator() {
1244 let snowflake_task_context = SnowflakeTaskContext {
1245 task_name: Some("test_task".to_owned()),
1246 cdc_table_name: Some("test_cdc_table".to_owned()),
1247 target_table_name: "test_target_table".to_owned(),
1248 schedule_seconds: 3600,
1249 warehouse: Some("test_warehouse".to_owned()),
1250 pk_column_names: Some(vec!["v1".to_owned()]),
1251 all_column_names: Some(vec!["v1".to_owned(), "v2".to_owned()]),
1252 database: "test_db".to_owned(),
1253 schema_name: "test_schema".to_owned(),
1254 schema: Schema { fields: vec![] },
1255 stage: None,
1256 pipe_name: None,
1257 };
1258 let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
1259 let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task"
1260WAREHOUSE = test_warehouse
1261SCHEDULE = '3600 SECONDS'
1262AS
1263BEGIN
1264 LET max_row_id STRING;
1265
1266 SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
1267 FROM "test_db"."test_schema"."test_cdc_table";
1268
1269 MERGE INTO "test_db"."test_schema"."test_target_table" AS target
1270 USING (
1271 SELECT *
1272 FROM (
1273 SELECT *, ROW_NUMBER() OVER (PARTITION BY "v1" ORDER BY "__row_id" DESC) AS dedupe_id
1274 FROM "test_db"."test_schema"."test_cdc_table"
1275 WHERE "__row_id" <= :max_row_id
1276 ) AS subquery
1277 WHERE dedupe_id = 1
1278 ) AS source
1279 ON target."v1" = source."v1"
1280 WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
1281 WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."v1" = source."v1", target."v2" = source."v2"
1282 WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("v1", "v2") VALUES (source."v1", source."v2");
1283
1284 DELETE FROM "test_db"."test_schema"."test_cdc_table"
1285 WHERE "__row_id" <= :max_row_id;
1286END;"#;
1287 assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
1288 }
1289
1290 #[test]
1291 fn test_snowflake_sink_commit_coordinator_multi_pk() {
1292 let snowflake_task_context = SnowflakeTaskContext {
1293 task_name: Some("test_task_multi_pk".to_owned()),
1294 cdc_table_name: Some("cdc_multi_pk".to_owned()),
1295 target_table_name: "target_multi_pk".to_owned(),
1296 schedule_seconds: 300,
1297 warehouse: Some("multi_pk_warehouse".to_owned()),
1298 pk_column_names: Some(vec!["id1".to_owned(), "id2".to_owned()]),
1299 all_column_names: Some(vec!["id1".to_owned(), "id2".to_owned(), "val".to_owned()]),
1300 database: "test_db".to_owned(),
1301 schema_name: "test_schema".to_owned(),
1302 schema: Schema { fields: vec![] },
1303 stage: None,
1304 pipe_name: None,
1305 };
1306 let task_sql = build_create_merge_into_task_sql(&snowflake_task_context);
1307 let expected = r#"CREATE OR REPLACE TASK "test_db"."test_schema"."test_task_multi_pk"
1308WAREHOUSE = multi_pk_warehouse
1309SCHEDULE = '300 SECONDS'
1310AS
1311BEGIN
1312 LET max_row_id STRING;
1313
1314 SELECT COALESCE(MAX("__row_id"), '0') INTO :max_row_id
1315 FROM "test_db"."test_schema"."cdc_multi_pk";
1316
1317 MERGE INTO "test_db"."test_schema"."target_multi_pk" AS target
1318 USING (
1319 SELECT *
1320 FROM (
1321 SELECT *, ROW_NUMBER() OVER (PARTITION BY "id1", "id2" ORDER BY "__row_id" DESC) AS dedupe_id
1322 FROM "test_db"."test_schema"."cdc_multi_pk"
1323 WHERE "__row_id" <= :max_row_id
1324 ) AS subquery
1325 WHERE dedupe_id = 1
1326 ) AS source
1327 ON target."id1" = source."id1" AND target."id2" = source."id2"
1328 WHEN MATCHED AND source."__op" IN (2, 4) THEN DELETE
1329 WHEN MATCHED AND source."__op" IN (1, 3) THEN UPDATE SET target."id1" = source."id1", target."id2" = source."id2", target."val" = source."val"
1330 WHEN NOT MATCHED AND source."__op" IN (1, 3) THEN INSERT ("id1", "id2", "val") VALUES (source."id1", source."id2", source."val");
1331
1332 DELETE FROM "test_db"."test_schema"."cdc_multi_pk"
1333 WHERE "__row_id" <= :max_row_id;
1334END;"#;
1335 assert_eq!(normalize_sql(&task_sql), normalize_sql(expected));
1336 }
1337}