risingwave_connector/source/adbc_snowflake/

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;

use adbc_core::options::{OptionDatabase, OptionValue};
use adbc_core::{
    Connection as AdbcCoreConnection, Database as AdbcCoreDatabase, Statement as AdbcCoreStatement,
};
use adbc_snowflake::database::Builder as DatabaseBuilder;
pub use adbc_snowflake::{Connection, Database, Driver, Statement};
use anyhow::{Context, anyhow};
use async_trait::async_trait;
use futures::StreamExt;
use futures_async_stream::try_stream;
use risingwave_common::array::StreamChunk;
use risingwave_common::array::arrow::{
    Arrow56FromArrow, arrow_array_56 as arrow, arrow_schema_56 as arrow_schema,
};
use risingwave_common::types::JsonbVal;
use serde::{Deserialize, Serialize};

use crate::error::ConnectorResult;
use crate::parser::ParserConfig;
use crate::source::{
    BoxSourceChunkStream, Column, SourceContextRef, SourceEnumeratorContextRef, SourceProperties,
    SplitEnumerator, SplitId, SplitMetaData, SplitReader, UnknownFields,
};

pub const ADBC_SNOWFLAKE_CONNECTOR: &str = "adbc_snowflake";

mod schema;

#[derive(Default)]
pub struct AdbcSnowflakeArrowConvert;

impl Arrow56FromArrow for AdbcSnowflakeArrowConvert {}

impl AdbcSnowflakeArrowConvert {
    pub fn chunk_from_record_batch(
        &self,
        batch: &arrow::RecordBatch,
    ) -> Result<risingwave_common::array::DataChunk, risingwave_common::array::ArrayError> {
        Arrow56FromArrow::from_record_batch(self, batch)
    }

    pub fn type_from_field(
        &self,
        field: &arrow_schema::Field,
    ) -> Result<risingwave_common::types::DataType, risingwave_common::array::ArrayError> {
        Arrow56FromArrow::from_field(self, field)
    }
}
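
// A minimal sketch of the Arrow -> DataChunk conversion path used by the reader
// below. The single-column Int32 schema and the values are illustrative only.
#[cfg(test)]
mod arrow_convert_tests {
    use std::sync::Arc;

    use super::*;

    #[test]
    fn record_batch_converts_to_data_chunk() {
        let schema = Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
            "v",
            arrow_schema::DataType::Int32,
            true,
        )]));
        let batch = arrow::RecordBatch::try_new(
            schema,
            vec![Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3])) as arrow::array::ArrayRef],
        )
        .unwrap();

        let chunk = AdbcSnowflakeArrowConvert
            .chunk_from_record_batch(&batch)
            .unwrap();
        assert_eq!(chunk.cardinality(), 3);
    }
}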

/// Properties for the ADBC Snowflake source connector.
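///
/// Each option is supplied under the `adbc_snowflake.` prefix shown in the
/// corresponding `serde(rename = ...)` attribute below, e.g.
/// `adbc_snowflake.account` or `adbc_snowflake.warehouse`.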
#[derive(Clone, Debug, Deserialize, with_options::WithOptions)]
pub struct AdbcSnowflakeProperties {
    /// The Snowflake account identifier (e.g., "myaccount" or "myaccount.us-east-1").
    #[serde(rename = "adbc_snowflake.account")]
    pub account: String,

    /// The username for authentication.
    #[serde(rename = "adbc_snowflake.username")]
    pub username: String,

    /// The password for authentication.
    #[serde(rename = "adbc_snowflake.password")]
    pub password: Option<String>,

    /// The name of the database to use.
    #[serde(rename = "adbc_snowflake.database")]
    pub database: String,

    /// The name of the schema to use.
    #[serde(rename = "adbc_snowflake.schema")]
    pub schema: String,

    /// The name of the warehouse to use.
    #[serde(rename = "adbc_snowflake.warehouse")]
    pub warehouse: String,

    /// The name of the table to load (the entire table is read).
    #[serde(rename = "adbc_snowflake.table")]
    pub table: String,

    /// The role to use (optional).
    #[serde(rename = "adbc_snowflake.role")]
    pub role: Option<String>,

    /// The host to connect to (optional, defaults to Snowflake cloud).
    #[serde(rename = "adbc_snowflake.host")]
    pub host: Option<String>,

    /// The port to connect to (optional).
    #[serde(rename = "adbc_snowflake.port")]
    pub port: Option<u16>,

    /// The protocol to use (optional, defaults to "https").
    #[serde(rename = "adbc_snowflake.protocol")]
    pub protocol: Option<String>,

    /// The authentication type (optional, defaults to `auth_snowflake`).
    /// Possible values: `auth_snowflake`, `auth_oauth`, `auth_ext_browser`, `auth_okta`, `auth_jwt`, `auth_mfa`, `auth_pat`, `auth_wif`.
    #[serde(rename = "adbc_snowflake.auth_type")]
    pub auth_type: Option<String>,

    /// `OAuth` token for authentication (when using `auth_oauth`).
    #[serde(rename = "adbc_snowflake.auth_token")]
    pub auth_token: Option<String>,

    /// Path to the JWT private key file (when using `auth_jwt`).
    #[serde(rename = "adbc_snowflake.jwt_private_key_path")]
    pub jwt_private_key_path: Option<String>,

    /// PKCS#8 value of the JWT private key (when using `auth_jwt`).
    #[serde(rename = "adbc_snowflake.jwt_private_key_pkcs8_value")]
    pub jwt_private_key_pkcs8_value: Option<String>,

    /// Passphrase for the PKCS#8 JWT private key, if it is encrypted.
    #[serde(rename = "adbc_snowflake.jwt_private_key_pkcs8_password")]
    pub jwt_private_key_pkcs8_password: Option<String>,

    /// Unknown fields for forward compatibility.
    #[serde(flatten)]
    pub unknown_fields: HashMap<String, String>,
}

impl crate::enforce_secret::EnforceSecret for AdbcSnowflakeProperties {
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf::phf_set! {
        "adbc_snowflake.password",
        "adbc_snowflake.auth_token",
        "adbc_snowflake.jwt_private_key_path",
        "adbc_snowflake.jwt_private_key_pkcs8_value",
        "adbc_snowflake.jwt_private_key_pkcs8_password"
    };
}

impl UnknownFields for AdbcSnowflakeProperties {
    fn unknown_fields(&self) -> HashMap<String, String> {
        self.unknown_fields.clone()
    }
}

impl SourceProperties for AdbcSnowflakeProperties {
    type Split = AdbcSnowflakeSplit;
    type SplitEnumerator = AdbcSnowflakeSplitEnumerator;
    type SplitReader = AdbcSnowflakeSplitReader;

    const SOURCE_NAME: &'static str = ADBC_SNOWFLAKE_CONNECTOR;
}

impl AdbcSnowflakeProperties {
    /// Qualified table reference used in all generated queries.
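    /// Identifiers are double-quoted verbatim (e.g. `"DB"."PUBLIC"."T"`), so the
    /// configured database, schema, and table names are case-sensitive on the
    /// Snowflake side.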
    pub fn table_ref(&self) -> String {
        format!(r#""{}"."{}"."{}""#, self.database, self.schema, self.table)
    }

    /// The default full-table `SELECT` statement.
    pub fn build_select_all_query(&self) -> String {
        format!("SELECT * FROM {}", self.table_ref())
    }

    /// Build a database builder from properties.
    fn build_database_builder(&self) -> ConnectorResult<DatabaseBuilder> {
        let mut builder = DatabaseBuilder::default()
            .with_account(&self.account)
            .with_username(&self.username)
            .with_database(&self.database)
            .with_schema(&self.schema)
            .with_warehouse(&self.warehouse);

        if let Some(ref password) = self.password {
            builder = builder.with_password(password);
        }

        // Set the max timestamp precision to microseconds, as RisingWave supports at most microsecond precision.
        builder.other.push((
            OptionDatabase::Other(
                "adbc.snowflake.sql.client_option.max_timestamp_precision".to_owned(),
            ),
            OptionValue::String("microseconds".to_owned()),
        ));

        if let Some(ref role) = self.role {
            builder = builder.with_role(role);
        }

        if let Some(ref host) = self.host {
            builder = builder
                .with_parse_host(host)
                .context("Failed to parse host")?;
        }

        if let Some(port) = self.port {
            builder = builder.with_port(port);
        }

        if let Some(ref protocol) = self.protocol {
            builder = builder
                .with_parse_protocol(protocol)
                .context("Failed to parse protocol")?;
        }

        if let Some(ref auth_type) = self.auth_type {
            builder = builder
                .with_parse_auth_type(auth_type)
                .context("Failed to parse auth type")?;
        }

        if let Some(ref auth_token) = self.auth_token {
            builder = builder.with_auth_token(auth_token);
        }

        if let Some(ref jwt_private_key_path) = self.jwt_private_key_path {
            builder = builder.with_jwt_private_key(jwt_private_key_path.into());
        }

        if let Some(ref jwt_private_key_pkcs8_value) = self.jwt_private_key_pkcs8_value {
            builder = builder.with_jwt_private_key_pkcs8_value(jwt_private_key_pkcs8_value.into());
        }

        if let Some(ref jwt_private_key_pkcs8_password) = self.jwt_private_key_pkcs8_password {
            builder =
                builder.with_jwt_private_key_pkcs8_password(jwt_private_key_pkcs8_password.into());
        }

        Ok(builder)
    }

    /// Create an ADBC Snowflake database handle.
    /// This validates that the driver library is available before attempting to connect.
    pub fn create_database(&self) -> ConnectorResult<Database> {
        // Validate driver availability and load the driver
        let mut driver = Driver::try_load().context(
            "Failed to load ADBC Snowflake driver shared library. \
                Check the following:\n\
                1. The ADBC Snowflake driver is installed correctly\n\
                2. The shared library (libadbc_driver_snowflake.so on Linux, \
                   libadbc_driver_snowflake.dylib on macOS, or \
                   adbc_driver_snowflake.dll on Windows) is in your library path\n\
                3. Environment variables like LD_LIBRARY_PATH (Linux), \
                   DYLD_LIBRARY_PATH (macOS), or PATH (Windows) are set correctly\n\
                4. All required dependencies of the ADBC Snowflake driver are installed",
        )?;

        let builder = self.build_database_builder()?;
        let database = builder
            .build(&mut driver)
            .context("Failed to build database")?;
        Ok(database)
    }

    /// Create a connection from the database.
    pub fn create_connection(&self, database: &Database) -> ConnectorResult<Connection> {
        let connection = database
            .new_connection()
            .context("Failed to create connection")?;
        Ok(connection)
    }

    /// Create a statement from the connection and set the SQL query.
    pub fn create_statement(
        &self,
        connection: &mut Connection,
        query: &str,
    ) -> ConnectorResult<Statement> {
        let mut statement = connection
            .new_statement()
            .context("Failed to create statement")?;
        statement
            .set_sql_query(query)
            .context("Failed to set SQL query")?;
        Ok(statement)
    }

    /// Execute a custom query using a provided connection.
    /// This is useful for metadata queries needed for split generation while reusing an existing connection.
    pub fn execute_query_with_connection(
        &self,
        connection: &mut Connection,
        query: &str,
    ) -> ConnectorResult<Vec<arrow::RecordBatch>> {
        let mut statement = connection
            .new_statement()
            .context("Failed to create statement")?;
        statement
            .set_sql_query(query)
            .context("Failed to set SQL query")?;
        let reader = statement.execute().context("Failed to execute query")?;

        // Collect all batches into a vector
        let mut batches = Vec::new();
        for batch_result in reader {
            let batch = batch_result.context("Failed to read record batch")?;
            batches.push(batch);
        }
        Ok(batches)
    }

    /// Execute a custom query and return the results as a vector of Arrow record batches.
    /// This is useful for metadata queries needed for split generation.
    /// A new connection is created for each call; prefer `execute_query_with_connection` when a connection can be reused.
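    ///
    /// A minimal usage sketch (not a doctest: it needs a live Snowflake account
    /// and the ADBC driver shared library to be available):
    ///
    /// ```ignore
    /// let batches = props.execute_query("SELECT CURRENT_VERSION()")?;
    /// for batch in &batches {
    ///     println!("rows: {}", batch.num_rows());
    /// }
    /// ```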
    pub fn execute_query(&self, query: &str) -> ConnectorResult<Vec<arrow::RecordBatch>> {
        let database = self.create_database()?;
        let mut connection = self.create_connection(&database)?;
        self.execute_query_with_connection(&mut connection, query)
    }
}
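
// A small sanity check of the query-building helpers above. The property values
// are illustrative only and do not refer to a real Snowflake deployment.
#[cfg(test)]
mod query_building_tests {
    use super::*;

    fn sample_props() -> AdbcSnowflakeProperties {
        AdbcSnowflakeProperties {
            account: "myaccount".to_owned(),
            username: "user".to_owned(),
            password: None,
            database: "DB".to_owned(),
            schema: "PUBLIC".to_owned(),
            warehouse: "WH".to_owned(),
            table: "T".to_owned(),
            role: None,
            host: None,
            port: None,
            protocol: None,
            auth_type: None,
            auth_token: None,
            jwt_private_key_path: None,
            jwt_private_key_pkcs8_value: None,
            jwt_private_key_pkcs8_password: None,
            unknown_fields: HashMap::new(),
        }
    }

    #[test]
    fn select_all_query_is_fully_qualified() {
        let props = sample_props();
        assert_eq!(props.table_ref(), r#""DB"."PUBLIC"."T""#);
        assert_eq!(
            props.build_select_all_query(),
            r#"SELECT * FROM "DB"."PUBLIC"."T""#
        );
    }
}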

/// Split for the ADBC Snowflake source.
/// The query is executed as a whole, so a single split carrying the full query is used.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Hash)]
pub struct AdbcSnowflakeSplit {
    /// The split identifier.
    pub split_id: String,
    /// The SQL query to execute.
    pub query: String,
}

impl SplitMetaData for AdbcSnowflakeSplit {
    fn id(&self) -> SplitId {
        self.split_id.clone().into()
    }

    fn restore_from_json(value: JsonbVal) -> ConnectorResult<Self> {
        serde_json::from_value(value.take()).map_err(|e| anyhow!(e).into())
    }

    fn encode_to_json(&self) -> JsonbVal {
        serde_json::to_value(self.clone()).unwrap().into()
    }

    fn update_offset(&mut self, _last_seen_offset: String) -> ConnectorResult<()> {
        // ADBC Snowflake doesn't have offset-based reading for now
        Ok(())
    }
}
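
// A round-trip sketch of the split's checkpoint encoding via `SplitMetaData`;
// the query string below is illustrative.
#[cfg(test)]
mod split_serde_tests {
    use super::*;

    #[test]
    fn split_round_trips_through_json() {
        let split = AdbcSnowflakeSplit {
            split_id: "0".to_owned(),
            query: r#"SELECT * FROM "DB"."PUBLIC"."T""#.to_owned(),
        };
        let restored = AdbcSnowflakeSplit::restore_from_json(split.encode_to_json()).unwrap();
        assert_eq!(restored, split);
    }
}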

/// Split enumerator for the ADBC Snowflake source.
pub struct AdbcSnowflakeSplitEnumerator {
    properties: AdbcSnowflakeProperties,
}

#[async_trait]
impl SplitEnumerator for AdbcSnowflakeSplitEnumerator {
    type Properties = AdbcSnowflakeProperties;
    type Split = AdbcSnowflakeSplit;

    async fn new(
        properties: Self::Properties,
        _context: SourceEnumeratorContextRef,
    ) -> ConnectorResult<Self> {
        Ok(Self { properties })
    }

    async fn list_splits(&mut self) -> ConnectorResult<Vec<Self::Split>> {
        // Validate the connection and access rights before returning the split,
        // so that bad credentials or a missing database/schema surface early.
        let database = self.properties.create_database()?;
        let mut connection = self.properties.create_connection(&database)?;

        // Probe the target database/schema with a cheap query against information_schema
        let validation_query = format!(
            "SELECT * FROM {}.information_schema.tables WHERE table_schema = '{}' LIMIT 1",
            self.properties.database, self.properties.schema
        );
        let mut statement = connection
            .new_statement()
            .context("Failed to create statement")?;
        statement
            .set_sql_query(&validation_query)
            .context("Failed to set SQL query")?;
        let _ = statement
            .execute()
            .context("Failed to validate connection")?;

        // Connection and access are valid; return a single split covering the whole table
        let split = AdbcSnowflakeSplit {
            split_id: "0".to_owned(),
            query: self.properties.build_select_all_query(),
        };
        Ok(vec![split])
    }
}

/// Split reader for the ADBC Snowflake source.
pub struct AdbcSnowflakeSplitReader {
    properties: AdbcSnowflakeProperties,
    #[allow(dead_code)]
    splits: Vec<AdbcSnowflakeSplit>,
    #[allow(dead_code)]
    parser_config: ParserConfig,
    #[allow(dead_code)]
    source_ctx: SourceContextRef,
}

#[async_trait]
impl SplitReader for AdbcSnowflakeSplitReader {
    type Properties = AdbcSnowflakeProperties;
    type Split = AdbcSnowflakeSplit;

    async fn new(
        properties: Self::Properties,
        splits: Vec<Self::Split>,
        parser_config: ParserConfig,
        source_ctx: SourceContextRef,
        _columns: Option<Vec<Column>>,
    ) -> ConnectorResult<Self> {
        Ok(Self {
            properties,
            splits,
            parser_config,
            source_ctx,
        })
    }

    fn into_stream(self) -> BoxSourceChunkStream {
        self.into_chunk_stream().boxed()
    }
}

impl AdbcSnowflakeSplitReader {
    #[try_stream(boxed, ok = StreamChunk, error = crate::error::ConnectorError)]
    async fn into_chunk_stream(self) {
        // Execute the query and read the results as Arrow record batches
        let database = self.properties.create_database()?;
        let mut connection = self.properties.create_connection(&database)?;
        let query = self.properties.build_select_all_query();
        let mut statement = self.properties.create_statement(&mut connection, &query)?;

        // Execute the query and get a record batch reader
        let reader = statement.execute().context("Failed to execute query")?;

        let converter = AdbcSnowflakeArrowConvert;

        // Iterate over the record batches and convert them to StreamChunks
        for batch_result in reader {
            let batch = batch_result.context("Failed to read record batch")?;

            // Convert the Arrow RecordBatch to a RisingWave DataChunk using the converter
            let data_chunk = converter.chunk_from_record_batch(&batch)?;

            // Convert the DataChunk to a StreamChunk (all rows are inserts)
            let stream_chunk = StreamChunk::from_parts(
                vec![risingwave_common::array::Op::Insert; data_chunk.capacity()],
                data_chunk,
            );

            yield stream_chunk;
        }
    }
}