risingwave_connector/connector_common/connection.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16use std::time::Duration;
17
18use anyhow::Context;
19use opendal::Operator;
20use opendal::services::{Azblob, Gcs, S3};
21use phf::{Set, phf_set};
22use rdkafka::ClientConfig;
23use rdkafka::config::RDKafkaLogLevel;
24use rdkafka::consumer::{BaseConsumer, Consumer};
25use risingwave_common::bail;
26use risingwave_common::secret::LocalSecretManager;
27use risingwave_common::util::env_var::env_var_is_true;
28use risingwave_pb::catalog::PbConnection;
29use serde_derive::Deserialize;
30use serde_with::serde_as;
31use tonic::async_trait;
32use url::Url;
33use with_options::WithOptions;
34
35use crate::connector_common::common::DISABLE_DEFAULT_CREDENTIAL;
36use crate::connector_common::{
37    AwsAuthProps, IcebergCommon, KafkaConnectionProps, KafkaPrivateLinkCommon,
38};
39use crate::deserialize_optional_bool_from_string;
40use crate::enforce_secret::EnforceSecret;
41use crate::error::ConnectorResult;
42use crate::schema::schema_registry::Client as ConfluentSchemaRegistryClient;
43use crate::sink::elasticsearch_opensearch::elasticsearch_opensearch_config::ElasticSearchOpenSearchConfig;
44use crate::source::build_connection;
45use crate::source::kafka::{KafkaContextCommon, RwConsumerContext};
46
/// Connection-type identifier string for schema-registry connections.
pub const SCHEMA_REGISTRY_CONNECTION_TYPE: &str = "schema_registry";
48
/// Validation entry point for `CREATE CONNECTION` objects.
///
/// All `XxxConnection` structs should implement this trait as well as the
/// `EnforceSecret` trait (see the impls in this file).
#[async_trait]
pub trait Connection: Send {
    /// Checks that the remote service described by this connection is reachable
    /// with the given properties/credentials. Returns `Ok(())` on success.
    async fn validate_connection(&self) -> ConnectorResult<()>;
}
54
/// Properties of a Kafka connection: core broker/security settings plus
/// private-link rewrites and AWS auth, all flattened into one property map.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct KafkaConnection {
    // Core connection settings (brokers, security protocol, SASL, ...).
    #[serde(flatten)]
    pub inner: KafkaConnectionProps,
    // Broker address rewrite map used for private-link setups.
    #[serde(flatten)]
    pub kafka_private_link_common: KafkaPrivateLinkCommon,
    // AWS credentials/region, used e.g. for MSK IAM authentication.
    #[serde(flatten)]
    pub aws_auth_props: AwsAuthProps,
}
66
67impl EnforceSecret for KafkaConnection {
68    fn enforce_secret<'a>(prop_iter: impl Iterator<Item = &'a str>) -> ConnectorResult<()> {
69        for prop in prop_iter {
70            KafkaConnectionProps::enforce_one(prop)?;
71            AwsAuthProps::enforce_one(prop)?;
72        }
73        Ok(())
74    }
75}
76
77pub async fn validate_connection(connection: &PbConnection) -> ConnectorResult<()> {
78    if let Some(ref info) = connection.info {
79        match info {
80            risingwave_pb::catalog::connection::Info::ConnectionParams(cp) => {
81                let options = cp.properties.clone().into_iter().collect();
82                let secret_refs = cp.secret_refs.clone().into_iter().collect();
83                let props_secret_resolved =
84                    LocalSecretManager::global().fill_secrets(options, secret_refs)?;
85                let connection = build_connection(cp.connection_type(), props_secret_resolved)?;
86                connection.validate_connection().await?
87            }
88            risingwave_pb::catalog::connection::Info::PrivateLinkService(_) => unreachable!(),
89        }
90    }
91    Ok(())
92}
93
94#[async_trait]
95impl Connection for KafkaConnection {
96    async fn validate_connection(&self) -> ConnectorResult<()> {
97        let client = self.build_client().await?;
98        // describe cluster here
99        client.fetch_metadata(None, Duration::from_secs(10)).await?;
100        Ok(())
101    }
102}
103
104pub fn read_kafka_log_level() -> Option<RDKafkaLogLevel> {
105    let log_level = std::env::var("RISINGWAVE_KAFKA_LOG_LEVEL").ok()?;
106    match log_level.to_uppercase().as_str() {
107        "DEBUG" => Some(RDKafkaLogLevel::Debug),
108        "INFO" => Some(RDKafkaLogLevel::Info),
109        "WARN" => Some(RDKafkaLogLevel::Warning),
110        "ERROR" => Some(RDKafkaLogLevel::Error),
111        "CRITICAL" => Some(RDKafkaLogLevel::Critical),
112        "EMERG" => Some(RDKafkaLogLevel::Emerg),
113        "ALERT" => Some(RDKafkaLogLevel::Alert),
114        "NOTICE" => Some(RDKafkaLogLevel::Notice),
115        _ => None,
116    }
117}
118
impl KafkaConnection {
    /// Builds a `BaseConsumer` from these connection properties, used only to
    /// probe connectivity in `validate_connection` (never for real consumption).
    async fn build_client(&self) -> ConnectorResult<BaseConsumer<RwConsumerContext>> {
        let mut config = ClientConfig::new();
        let bootstrap_servers = &self.inner.brokers;
        let broker_rewrite_map = self.kafka_private_link_common.broker_rewrite_map.clone();
        config.set("bootstrap.servers", bootstrap_servers);
        self.inner.set_security_properties(&mut config);

        // dup with Kafka Enumerator
        let ctx_common = KafkaContextCommon::new(
            broker_rewrite_map,
            None,
            None,
            self.aws_auth_props.clone(),
            self.inner.is_aws_msk_iam(),
        )
        .await?;
        let client_ctx = RwConsumerContext::new(ctx_common);

        // Optional log-level override via RISINGWAVE_KAFKA_LOG_LEVEL env var.
        if let Some(log_level) = read_kafka_log_level() {
            config.set_log_level(log_level);
        }
        let client: BaseConsumer<RwConsumerContext> =
            config.create_with_context(client_ctx).await?;
        if self.inner.is_aws_msk_iam() {
            // Extra poll only for MSK IAM — presumably to drive the SASL token
            // refresh callback before validation; TODO confirm.
            #[cfg(not(madsim))]
            client.poll(Duration::from_secs(10)); // note: this is a blocking call
            #[cfg(madsim)]
            client.poll(Duration::from_secs(10)).await;
        }
        Ok(client)
    }
}
152
/// Properties of an iceberg connection, covering both the warehouse storage
/// backend (S3 / GCS / Azure Blob) and the catalog configuration.
#[serde_as]
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, WithOptions)]
#[serde(deny_unknown_fields)]
pub struct IcebergConnection {
    /// Catalog type; `rest` / `rest_rust` get special handling in validation.
    #[serde(rename = "catalog.type")]
    pub catalog_type: Option<String>,
    /// Region of the S3 warehouse bucket.
    #[serde(rename = "s3.region")]
    pub region: Option<String>,
    /// Custom S3 endpoint URL.
    #[serde(rename = "s3.endpoint")]
    pub endpoint: Option<String>,
    /// S3 access key id.
    #[serde(rename = "s3.access.key")]
    pub access_key: Option<String>,
    /// S3 secret access key.
    #[serde(rename = "s3.secret.key")]
    pub secret_key: Option<String>,

    /// Credential used when the warehouse lives on Google Cloud Storage.
    #[serde(rename = "gcs.credential")]
    pub gcs_credential: Option<String>,

    /// Azure Blob Storage account name.
    #[serde(rename = "azblob.account_name")]
    pub azblob_account_name: Option<String>,
    /// Azure Blob Storage account key.
    #[serde(rename = "azblob.account_key")]
    pub azblob_account_key: Option<String>,
    /// Azure Blob Storage endpoint URL.
    #[serde(rename = "azblob.endpoint_url")]
    pub azblob_endpoint_url: Option<String>,

    /// Path of iceberg warehouse.
    #[serde(rename = "warehouse.path")]
    pub warehouse_path: Option<String>,
    /// Catalog id, can be omitted for storage catalog or when
    /// caller's AWS account ID matches glue id
    #[serde(rename = "glue.id")]
    pub glue_id: Option<String>,
    /// Catalog name, default value is risingwave.
    #[serde(rename = "catalog.name")]
    pub catalog_name: Option<String>,
    /// URI of iceberg catalog, only applicable in rest catalog.
    #[serde(rename = "catalog.uri")]
    pub catalog_uri: Option<String>,
    /// Credential for accessing iceberg catalog, only applicable in rest catalog.
    /// A credential to exchange for a token in the `OAuth2` client credentials flow.
    #[serde(rename = "catalog.credential")]
    pub credential: Option<String>,
    /// token for accessing iceberg catalog, only applicable in rest catalog.
    /// A Bearer token which will be used for interaction with the server.
    #[serde(rename = "catalog.token")]
    pub token: Option<String>,
    /// `oauth2_server_uri` for accessing iceberg catalog, only applicable in rest catalog.
    /// Token endpoint URI to fetch token from if the Rest Catalog is not the authorization server.
    #[serde(rename = "catalog.oauth2_server_uri")]
    pub oauth2_server_uri: Option<String>,
    /// scope for accessing iceberg catalog, only applicable in rest catalog.
    /// Additional scope for `OAuth2`.
    #[serde(rename = "catalog.scope")]
    pub scope: Option<String>,

    /// The signing region to use when signing requests to the REST catalog.
    #[serde(rename = "catalog.rest.signing_region")]
    pub rest_signing_region: Option<String>,

    /// The signing name to use when signing requests to the REST catalog.
    #[serde(rename = "catalog.rest.signing_name")]
    pub rest_signing_name: Option<String>,

    /// Whether to use `SigV4` for signing requests to the REST catalog.
    #[serde(
        rename = "catalog.rest.sigv4_enabled",
        default,
        deserialize_with = "deserialize_optional_bool_from_string"
    )]
    pub rest_sigv4_enabled: Option<bool>,

    /// Whether to use path-style access for the S3 warehouse.
    #[serde(
        rename = "s3.path.style.access",
        default,
        deserialize_with = "deserialize_optional_bool_from_string"
    )]
    pub path_style_access: Option<bool>,

    /// User for the JDBC catalog backend (forwarded as `jdbc.user`).
    #[serde(rename = "catalog.jdbc.user")]
    pub jdbc_user: Option<String>,

    /// Password for the JDBC catalog backend (forwarded as `jdbc.password`).
    #[serde(rename = "catalog.jdbc.password")]
    pub jdbc_password: Option<String>,

    /// Enable config load. This parameter set to true will load warehouse credentials from the environment. Only allowed to be used in a self-hosted environment.
    #[serde(default, deserialize_with = "deserialize_optional_bool_from_string")]
    pub enable_config_load: Option<bool>,

    /// This is only used by iceberg engine to enable the hosted catalog.
    #[serde(
        rename = "hosted_catalog",
        default,
        deserialize_with = "deserialize_optional_bool_from_string"
    )]
    pub hosted_catalog: Option<bool>,

    /// The http header to be used in the catalog requests.
    /// Example:
    /// `catalog.header = "key1=value1;key2=value2;key3=value3"`
    /// explain the format of the header:
    /// - Each header is a key-value pair, separated by an '='.
    /// - Multiple headers can be specified, separated by a ';'.
    #[serde(rename = "catalog.header")]
    pub header: Option<String>,
}
258
impl EnforceSecret for IcebergConnection {
    // NOTE(review): `azblob.account_key`, `catalog.credential` and
    // `catalog.jdbc.password` also look secret-like but are not listed here —
    // confirm whether they should be enforced as well.
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "s3.access.key",
        "s3.secret.key",
        "gcs.credential",
        "catalog.token",
    };
}
267
268#[async_trait]
269impl Connection for IcebergConnection {
270    async fn validate_connection(&self) -> ConnectorResult<()> {
271        let info = match &self.warehouse_path {
272            Some(warehouse_path) => {
273                let is_s3_tables = warehouse_path.starts_with("arn:aws:s3tables");
274                let url = Url::parse(warehouse_path);
275                if (url.is_err() || is_s3_tables)
276                    && matches!(self.catalog_type.as_deref(), Some("rest" | "rest_rust"))
277                {
278                    // If the warehouse path is not a valid URL, it could be a warehouse name in rest catalog,
279                    // Or it could be a s3tables path, which is not a valid URL but a valid warehouse path,
280                    // so we allow it to pass here.
281                    None
282                } else {
283                    let url =
284                        url.with_context(|| format!("Invalid warehouse path: {}", warehouse_path))?;
285                    let bucket = url
286                        .host_str()
287                        .with_context(|| {
288                            format!("Invalid s3 path: {}, bucket is missing", warehouse_path)
289                        })?
290                        .to_owned();
291                    let root = url.path().trim_start_matches('/').to_owned();
292                    Some((url.scheme().to_owned(), bucket, root))
293                }
294            }
295            None => {
296                if matches!(self.catalog_type.as_deref(), Some("rest" | "rest_rust")) {
297                    None
298                } else {
299                    bail!("`warehouse.path` must be set");
300                }
301            }
302        };
303
304        // Test warehouse
305        if let Some((scheme, bucket, root)) = info {
306            match scheme.as_str() {
307                "s3" | "s3a" => {
308                    let mut builder = S3::default();
309                    if let Some(region) = &self.region {
310                        builder = builder.region(region);
311                    }
312                    if let Some(endpoint) = &self.endpoint {
313                        builder = builder.endpoint(endpoint);
314                    }
315                    if let Some(access_key) = &self.access_key {
316                        builder = builder.access_key_id(access_key);
317                    }
318                    if let Some(secret_key) = &self.secret_key {
319                        builder = builder.secret_access_key(secret_key);
320                    }
321                    builder = builder.root(root.as_str()).bucket(bucket.as_str());
322                    let op = Operator::new(builder)?.finish();
323                    op.check().await?;
324                }
325                "gs" | "gcs" => {
326                    let mut builder = Gcs::default();
327                    if let Some(credential) = &self.gcs_credential {
328                        builder = builder.credential(credential);
329                    }
330                    builder = builder.root(root.as_str()).bucket(bucket.as_str());
331                    let op = Operator::new(builder)?.finish();
332                    op.check().await?;
333                }
334                "azblob" => {
335                    let mut builder = Azblob::default();
336                    if let Some(account_name) = &self.azblob_account_name {
337                        builder = builder.account_name(account_name);
338                    }
339                    if let Some(azblob_account_key) = &self.azblob_account_key {
340                        builder = builder.account_key(azblob_account_key);
341                    }
342                    if let Some(azblob_endpoint_url) = &self.azblob_endpoint_url {
343                        builder = builder.endpoint(azblob_endpoint_url);
344                    }
345                    builder = builder.root(root.as_str()).container(bucket.as_str());
346                    let op = Operator::new(builder)?.finish();
347                    op.check().await?;
348                }
349                _ => {
350                    bail!("Unsupported scheme: {}", scheme);
351                }
352            }
353        }
354
355        if env_var_is_true(DISABLE_DEFAULT_CREDENTIAL)
356            && matches!(self.enable_config_load, Some(true))
357        {
358            bail!("`enable_config_load` can't be enabled in this environment");
359        }
360
361        if self.hosted_catalog.unwrap_or(false) {
362            // If `hosted_catalog` is set, we don't need to test the catalog, but just ensure no catalog fields are set.
363            if self.catalog_type.is_some() {
364                bail!("`catalog.type` must not be set when `hosted_catalog` is set");
365            }
366            if self.catalog_uri.is_some() {
367                bail!("`catalog.uri` must not be set when `hosted_catalog` is set");
368            }
369            if self.catalog_name.is_some() {
370                bail!("`catalog.name` must not be set when `hosted_catalog` is set");
371            }
372            if self.jdbc_user.is_some() {
373                bail!("`catalog.jdbc.user` must not be set when `hosted_catalog` is set");
374            }
375            if self.jdbc_password.is_some() {
376                bail!("`catalog.jdbc.password` must not be set when `hosted_catalog` is set");
377            }
378            return Ok(());
379        }
380
381        if self.catalog_type.is_none() {
382            bail!("`catalog.type` must be set");
383        }
384
385        // Test catalog
386        let iceberg_common = IcebergCommon {
387            catalog_type: self.catalog_type.clone(),
388            region: self.region.clone(),
389            endpoint: self.endpoint.clone(),
390            access_key: self.access_key.clone(),
391            secret_key: self.secret_key.clone(),
392            gcs_credential: self.gcs_credential.clone(),
393            azblob_account_name: self.azblob_account_name.clone(),
394            azblob_account_key: self.azblob_account_key.clone(),
395            azblob_endpoint_url: self.azblob_endpoint_url.clone(),
396            warehouse_path: self.warehouse_path.clone(),
397            glue_id: self.glue_id.clone(),
398            catalog_name: self.catalog_name.clone(),
399            catalog_uri: self.catalog_uri.clone(),
400            credential: self.credential.clone(),
401
402            token: self.token.clone(),
403            oauth2_server_uri: self.oauth2_server_uri.clone(),
404            scope: self.scope.clone(),
405            rest_signing_region: self.rest_signing_region.clone(),
406            rest_signing_name: self.rest_signing_name.clone(),
407            rest_sigv4_enabled: self.rest_sigv4_enabled,
408            path_style_access: self.path_style_access,
409            database_name: Some("test_database".to_owned()),
410            table_name: "test_table".to_owned(),
411            enable_config_load: self.enable_config_load,
412            hosted_catalog: self.hosted_catalog,
413            header: self.header.clone(),
414        };
415
416        let mut java_map = HashMap::new();
417        if let Some(jdbc_user) = &self.jdbc_user {
418            java_map.insert("jdbc.user".to_owned(), jdbc_user.to_owned());
419        }
420        if let Some(jdbc_password) = &self.jdbc_password {
421            java_map.insert("jdbc.password".to_owned(), jdbc_password.to_owned());
422        }
423        let catalog = iceberg_common.create_catalog(&java_map).await?;
424        // test catalog by `table_exists` api
425        catalog
426            .table_exists(&iceberg_common.full_table_name()?)
427            .await?;
428        Ok(())
429    }
430}
431
/// Connection properties for a Confluent schema registry.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
#[serde(deny_unknown_fields)]
pub struct ConfluentSchemaRegistryConnection {
    /// Base URL of the schema registry.
    #[serde(rename = "schema.registry")]
    pub url: String,
    // ref `SchemaRegistryAuth`
    /// Optional username for registry authentication.
    #[serde(rename = "schema.registry.username")]
    pub username: Option<String>,
    /// Optional password for registry authentication.
    #[serde(rename = "schema.registry.password")]
    pub password: Option<String>,
}
444
445#[async_trait]
446impl Connection for ConfluentSchemaRegistryConnection {
447    async fn validate_connection(&self) -> ConnectorResult<()> {
448        // GET /config to validate the connection
449        let client = ConfluentSchemaRegistryClient::try_from(self)?;
450        client.validate_connection().await?;
451        Ok(())
452    }
453}
454
impl EnforceSecret for ConfluentSchemaRegistryConnection {
    // Properties treated as secrets for this connection type.
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "schema.registry.password",
    };
}
460
/// Elasticsearch connection kept as an untyped property map; converted to an
/// `ElasticSearchOpenSearchConfig` at validation time.
#[derive(Debug, Clone, Deserialize, PartialEq, Hash, Eq)]
pub struct ElasticsearchConnection(pub BTreeMap<String, String>);
463
464#[async_trait]
465impl Connection for ElasticsearchConnection {
466    async fn validate_connection(&self) -> ConnectorResult<()> {
467        const CONNECTOR: &str = "elasticsearch";
468
469        let config = ElasticSearchOpenSearchConfig::try_from(self)?;
470        let client = config.build_client(CONNECTOR)?;
471        client.ping().await?;
472        Ok(())
473    }
474}
475
impl EnforceSecret for ElasticsearchConnection {
    // Properties treated as secrets for this connection type.
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "elasticsearch.password",
    };
}