risingwave_connector/source/adbc_snowflake/mod.rs

use std::collections::HashMap;

use adbc_core::options::{OptionDatabase, OptionValue};
use adbc_core::{
    Connection as AdbcCoreConnection, Database as AdbcCoreDatabase, Statement as AdbcCoreStatement,
};
use adbc_snowflake::database::Builder as DatabaseBuilder;
pub use adbc_snowflake::{Connection, Database, Driver, Statement};
use anyhow::{Context, anyhow};
use async_trait::async_trait;
use futures::StreamExt;
use futures_async_stream::try_stream;
use risingwave_common::array::StreamChunk;
use risingwave_common::array::arrow::{
    Arrow56FromArrow, arrow_array_56 as arrow, arrow_schema_56 as arrow_schema,
};
use risingwave_common::types::JsonbVal;
use serde::{Deserialize, Serialize};

use crate::error::ConnectorResult;
use crate::parser::ParserConfig;
use crate::source::{
    BoxSourceChunkStream, Column, SourceContextRef, SourceEnumeratorContextRef, SourceProperties,
    SplitEnumerator, SplitId, SplitMetaData, SplitReader, UnknownFields,
};

pub const ADBC_SNOWFLAKE_CONNECTOR: &str = "adbc_snowflake";

mod schema;

/// Converter from Arrow (v56) record batches returned by the ADBC Snowflake
/// driver into RisingWave's internal representation.
#[derive(Default)]
pub struct AdbcSnowflakeArrowConvert;

impl Arrow56FromArrow for AdbcSnowflakeArrowConvert {}

impl AdbcSnowflakeArrowConvert {
    /// Convert one Arrow record batch into a RisingWave `DataChunk`.
    pub fn chunk_from_record_batch(
        &self,
        batch: &arrow::RecordBatch,
    ) -> Result<risingwave_common::array::DataChunk, risingwave_common::array::ArrayError> {
        Arrow56FromArrow::from_record_batch(self, batch)
    }

    /// Map an Arrow field to the corresponding RisingWave `DataType`.
    pub fn type_from_field(
        &self,
        field: &arrow_schema::Field,
    ) -> Result<risingwave_common::types::DataType, risingwave_common::array::ArrayError> {
        Arrow56FromArrow::from_field(self, field)
    }
}
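
/// Options accepted by the ADBC Snowflake source, read from the flat option
/// map of the source definition via the `adbc_snowflake.*` keys below.
///
/// A deserialization sketch (hypothetical values; `serde_json` is used here
/// only for illustration, not part of the original file):
///
/// ```ignore
/// // Hypothetical option map as it would arrive from the source definition.
/// let opts = serde_json::json!({
///     "adbc_snowflake.account": "my_account",
///     "adbc_snowflake.username": "my_user",
///     "adbc_snowflake.password": "my_password",
///     "adbc_snowflake.database": "MY_DB",
///     "adbc_snowflake.schema": "PUBLIC",
///     "adbc_snowflake.warehouse": "MY_WH",
///     "adbc_snowflake.table": "MY_TABLE",
/// });
/// let props: AdbcSnowflakeProperties = serde_json::from_value(opts).unwrap();
/// assert_eq!(props.table_ref(), r#""MY_DB"."PUBLIC"."MY_TABLE""#);
/// ```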
#[derive(Clone, Debug, Deserialize, with_options::WithOptions)]
pub struct AdbcSnowflakeProperties {
    /// Snowflake account identifier.
    #[serde(rename = "adbc_snowflake.account")]
    pub account: String,

    /// Login name of the Snowflake user.
    #[serde(rename = "adbc_snowflake.username")]
    pub username: String,

    /// Password for password-based authentication.
    #[serde(rename = "adbc_snowflake.password")]
    pub password: Option<String>,

    /// Database to read from.
    #[serde(rename = "adbc_snowflake.database")]
    pub database: String,

    /// Schema to read from.
    #[serde(rename = "adbc_snowflake.schema")]
    pub schema: String,

    /// Warehouse to run queries in.
    #[serde(rename = "adbc_snowflake.warehouse")]
    pub warehouse: String,

    /// Table to scan.
    #[serde(rename = "adbc_snowflake.table")]
    pub table: String,

    /// Role to assume after connecting.
    #[serde(rename = "adbc_snowflake.role")]
    pub role: Option<String>,

    /// Override for the Snowflake host.
    #[serde(rename = "adbc_snowflake.host")]
    pub host: Option<String>,

    /// Override for the Snowflake port.
    #[serde(rename = "adbc_snowflake.port")]
    pub port: Option<u16>,

    /// Override for the connection protocol.
    #[serde(rename = "adbc_snowflake.protocol")]
    pub protocol: Option<String>,

    /// Authentication type understood by the ADBC driver.
    #[serde(rename = "adbc_snowflake.auth_type")]
    pub auth_type: Option<String>,

    /// Token for token-based authentication.
    #[serde(rename = "adbc_snowflake.auth_token")]
    pub auth_token: Option<String>,

    /// Path to a JWT private key for key-pair authentication.
    #[serde(rename = "adbc_snowflake.jwt_private_key_path")]
    pub jwt_private_key_path: Option<String>,

    /// PKCS#8 value of the JWT private key.
    #[serde(rename = "adbc_snowflake.jwt_private_key_pkcs8_value")]
    pub jwt_private_key_pkcs8_value: Option<String>,

    /// Passphrase protecting the PKCS#8 private key.
    #[serde(rename = "adbc_snowflake.jwt_private_key_pkcs8_password")]
    pub jwt_private_key_pkcs8_password: Option<String>,

    /// Catch-all for options not recognized by this connector.
    #[serde(flatten)]
    pub unknown_fields: HashMap<String, String>,
}

impl crate::enforce_secret::EnforceSecret for AdbcSnowflakeProperties {
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf::phf_set! {
        "adbc_snowflake.password",
        "adbc_snowflake.auth_token",
        "adbc_snowflake.jwt_private_key_path",
        "adbc_snowflake.jwt_private_key_pkcs8_value",
        "adbc_snowflake.jwt_private_key_pkcs8_password"
    };
}

impl UnknownFields for AdbcSnowflakeProperties {
    fn unknown_fields(&self) -> HashMap<String, String> {
        self.unknown_fields.clone()
    }
}

impl SourceProperties for AdbcSnowflakeProperties {
    type Split = AdbcSnowflakeSplit;
    type SplitEnumerator = AdbcSnowflakeSplitEnumerator;
    type SplitReader = AdbcSnowflakeSplitReader;

    const SOURCE_NAME: &'static str = ADBC_SNOWFLAKE_CONNECTOR;
}

impl AdbcSnowflakeProperties {
    /// Fully qualified, quoted table reference, e.g. `"MY_DB"."PUBLIC"."MY_TABLE"`.
    pub fn table_ref(&self) -> String {
        format!(r#""{}"."{}"."{}""#, self.database, self.schema, self.table)
    }

    /// Full-table scan query used by the single split, e.g.
    /// `SELECT * FROM "MY_DB"."PUBLIC"."MY_TABLE"`.
    pub fn build_select_all_query(&self) -> String {
        format!("SELECT * FROM {}", self.table_ref())
    }

    /// Translate the user-facing options into an ADBC database builder.
    fn build_database_builder(&self) -> ConnectorResult<DatabaseBuilder> {
        let mut builder = DatabaseBuilder::default()
            .with_account(&self.account)
            .with_username(&self.username)
            .with_database(&self.database)
            .with_schema(&self.schema)
            .with_warehouse(&self.warehouse);

        if let Some(ref password) = self.password {
            builder = builder.with_password(password);
        }

        // Cap timestamp precision at microseconds, matching RisingWave's
        // timestamp types.
        builder.other.push((
            OptionDatabase::Other(
                "adbc.snowflake.sql.client_option.max_timestamp_precision".to_owned(),
            ),
            OptionValue::String("microseconds".to_owned()),
        ));

        if let Some(ref role) = self.role {
            builder = builder.with_role(role);
        }

        if let Some(ref host) = self.host {
            builder = builder
                .with_parse_host(host)
                .context("Failed to parse host")?;
        }

        if let Some(port) = self.port {
            builder = builder.with_port(port);
        }

        if let Some(ref protocol) = self.protocol {
            builder = builder
                .with_parse_protocol(protocol)
                .context("Failed to parse protocol")?;
        }

        if let Some(ref auth_type) = self.auth_type {
            builder = builder
                .with_parse_auth_type(auth_type)
                .context("Failed to parse auth type")?;
        }

        if let Some(ref auth_token) = self.auth_token {
            builder = builder.with_auth_token(auth_token);
        }

        if let Some(ref jwt_private_key_path) = self.jwt_private_key_path {
            builder = builder.with_jwt_private_key(jwt_private_key_path.into());
        }

        if let Some(ref jwt_private_key_pkcs8_value) = self.jwt_private_key_pkcs8_value {
            builder = builder.with_jwt_private_key_pkcs8_value(jwt_private_key_pkcs8_value.into());
        }

        if let Some(ref jwt_private_key_pkcs8_password) = self.jwt_private_key_pkcs8_password {
            builder =
                builder.with_jwt_private_key_pkcs8_password(jwt_private_key_pkcs8_password.into());
        }

        Ok(builder)
    }

    /// Load the ADBC Snowflake driver and build a database handle from the
    /// configured options.
    pub fn create_database(&self) -> ConnectorResult<Database> {
        let mut driver = Driver::try_load().context(
            "Failed to load ADBC Snowflake driver shared library. \
             Check the following:\n\
             1. The ADBC Snowflake driver is installed correctly\n\
             2. The shared library (libadbc_driver_snowflake.so on Linux, \
             libadbc_driver_snowflake.dylib on macOS, or \
             adbc_driver_snowflake.dll on Windows) is in your library path\n\
             3. Environment variables like LD_LIBRARY_PATH (Linux), \
             DYLD_LIBRARY_PATH (macOS), or PATH (Windows) are set correctly\n\
             4. All required dependencies of the ADBC Snowflake driver are installed",
        )?;

        let builder = self.build_database_builder()?;
        let database = builder
            .build(&mut driver)
            .context("Failed to build database")?;
        Ok(database)
    }

    /// Open a new connection on an already-built database handle.
    pub fn create_connection(&self, database: &Database) -> ConnectorResult<Connection> {
        let connection = database
            .new_connection()
            .context("Failed to create connection")?;
        Ok(connection)
    }

    /// Create a statement on `connection` with `query` already bound.
    pub fn create_statement(
        &self,
        connection: &mut Connection,
        query: &str,
    ) -> ConnectorResult<Statement> {
        let mut statement = connection
            .new_statement()
            .context("Failed to create statement")?;
        statement
            .set_sql_query(query)
            .context("Failed to set SQL query")?;
        Ok(statement)
    }

    /// Execute `query` on an existing connection and collect every record
    /// batch into memory.
    pub fn execute_query_with_connection(
        &self,
        connection: &mut Connection,
        query: &str,
    ) -> ConnectorResult<Vec<arrow::RecordBatch>> {
        let mut statement = connection
            .new_statement()
            .context("Failed to create statement")?;
        statement
            .set_sql_query(query)
            .context("Failed to set SQL query")?;
        let reader = statement.execute().context("Failed to execute query")?;

        // Drain the Arrow reader; any batch error aborts the whole query.
        let mut batches = Vec::new();
        for batch_result in reader {
            let batch = batch_result.context("Failed to read record batch")?;
            batches.push(batch);
        }
        Ok(batches)
    }
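
    /// Convenience wrapper that builds a fresh database and connection just
    /// for this query, then buffers every resulting record batch.
    ///
    /// A minimal usage sketch (illustrative only; `props` is assumed to be an
    /// already-deserialized `AdbcSnowflakeProperties`):
    ///
    /// ```ignore
    /// let batches = props.execute_query(&props.build_select_all_query())?;
    /// let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
    /// ```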
    pub fn execute_query(&self, query: &str) -> ConnectorResult<Vec<arrow::RecordBatch>> {
        let database = self.create_database()?;
        let mut connection = self.create_connection(&database)?;
        self.execute_query_with_connection(&mut connection, query)
    }
}
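
/// A unit of work handed from the enumerator to a reader. The connector
/// currently emits exactly one split that snapshots the whole table, e.g.
/// `split_id = "0"` with `query = r#"SELECT * FROM "MY_DB"."PUBLIC"."MY_TABLE""#`.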
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Hash)]
pub struct AdbcSnowflakeSplit {
    /// Identifier of this split within the source.
    pub split_id: String,
    /// Query this split runs against Snowflake.
    pub query: String,
}

impl SplitMetaData for AdbcSnowflakeSplit {
    fn id(&self) -> SplitId {
        self.split_id.clone().into()
    }

    fn restore_from_json(value: JsonbVal) -> ConnectorResult<Self> {
        serde_json::from_value(value.take()).map_err(|e| anyhow!(e).into())
    }

    fn encode_to_json(&self) -> JsonbVal {
        serde_json::to_value(self.clone()).unwrap().into()
    }

    fn update_offset(&mut self, _last_seen_offset: String) -> ConnectorResult<()> {
        // The snapshot scan is one-shot; there is no offset to track.
        Ok(())
    }
}
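
/// Enumerator that validates connectivity against Snowflake and hands out
/// the single full-table split.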
pub struct AdbcSnowflakeSplitEnumerator {
    /// Connector options used for connecting and building the scan query.
    properties: AdbcSnowflakeProperties,
}

#[async_trait]
impl SplitEnumerator for AdbcSnowflakeSplitEnumerator {
    type Properties = AdbcSnowflakeProperties;
    type Split = AdbcSnowflakeSplit;

    async fn new(
        properties: Self::Properties,
        _context: SourceEnumeratorContextRef,
    ) -> ConnectorResult<Self> {
        Ok(Self { properties })
    }

    async fn list_splits(&mut self) -> ConnectorResult<Vec<Self::Split>> {
        // Validate the configuration by actually connecting to Snowflake.
        let database = self.properties.create_database()?;
        let mut connection = self.properties.create_connection(&database)?;

        // Probe the information schema to confirm the target schema is reachable.
        let validation_query = format!(
            "SELECT * FROM {}.information_schema.tables WHERE table_schema = '{}' LIMIT 1",
            self.properties.database, self.properties.schema
        );
        let mut statement = connection
            .new_statement()
            .context("Failed to create statement")?;
        statement
            .set_sql_query(&validation_query)
            .context("Failed to set SQL query")?;
        let _ = statement
            .execute()
            .context("Failed to validate connection")?;

        // A single split scans the entire table.
        let split = AdbcSnowflakeSplit {
            split_id: "0".to_owned(),
            query: self.properties.build_select_all_query(),
        };
        Ok(vec![split])
    }
}

/// Reader that runs the split query and adapts Arrow batches into stream
/// chunks.
pub struct AdbcSnowflakeSplitReader {
    properties: AdbcSnowflakeProperties,
    #[allow(dead_code)]
    splits: Vec<AdbcSnowflakeSplit>,
    #[allow(dead_code)]
    parser_config: ParserConfig,
    #[allow(dead_code)]
    source_ctx: SourceContextRef,
}

#[async_trait]
impl SplitReader for AdbcSnowflakeSplitReader {
    type Properties = AdbcSnowflakeProperties;
    type Split = AdbcSnowflakeSplit;

    async fn new(
        properties: Self::Properties,
        splits: Vec<Self::Split>,
        parser_config: ParserConfig,
        source_ctx: SourceContextRef,
        _columns: Option<Vec<Column>>,
    ) -> ConnectorResult<Self> {
        Ok(Self {
            properties,
            splits,
            parser_config,
            source_ctx,
        })
    }

    fn into_stream(self) -> BoxSourceChunkStream {
        self.into_chunk_stream().boxed()
    }
}

impl AdbcSnowflakeSplitReader {
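    /// Snapshot stream: run the split's full-table query once and yield each
    /// Arrow record batch as an insert-only `StreamChunk`.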
    #[try_stream(boxed, ok = StreamChunk, error = crate::error::ConnectorError)]
    async fn into_chunk_stream(self) {
        let database = self.properties.create_database()?;
        let mut connection = self.properties.create_connection(&database)?;
        let query = self.properties.build_select_all_query();
        let mut statement = self.properties.create_statement(&mut connection, &query)?;

        // The reader lazily pulls Arrow record batches from the driver.
        let reader = statement.execute().context("Failed to execute query")?;

        let converter = AdbcSnowflakeArrowConvert;

        for batch_result in reader {
            let batch = batch_result.context("Failed to read record batch")?;

            // Convert the Arrow batch into RisingWave's columnar representation.
            let data_chunk = converter.chunk_from_record_batch(&batch)?;

            // Every row of the snapshot is emitted as an insert.
            let stream_chunk = StreamChunk::from_parts(
                vec![risingwave_common::array::Op::Insert; data_chunk.capacity()],
                data_chunk,
            );

            yield stream_chunk;
        }
    }
}