risingwave_connector/connector_common/connection.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashMap};
use std::time::Duration;

use anyhow::Context;
use opendal::Operator;
use opendal::services::{Azblob, Gcs, S3};
use phf::{Set, phf_set};
use rdkafka::ClientConfig;
use rdkafka::consumer::{BaseConsumer, Consumer};
use risingwave_common::bail;
use risingwave_common::secret::LocalSecretManager;
use risingwave_pb::catalog::PbConnection;
use serde_derive::Deserialize;
use serde_with::serde_as;
use tonic::async_trait;
use url::Url;
use with_options::WithOptions;

use crate::connector_common::{
    AwsAuthProps, IcebergCommon, KafkaConnectionProps, KafkaPrivateLinkCommon,
};
use crate::deserialize_optional_bool_from_string;
use crate::enforce_secret::EnforceSecret;
use crate::error::ConnectorResult;
use crate::schema::schema_registry::Client as ConfluentSchemaRegistryClient;
use crate::sink::elasticsearch_opensearch::elasticsearch_opensearch_config::ElasticSearchOpenSearchConfig;
use crate::source::build_connection;
use crate::source::kafka::{KafkaContextCommon, RwConsumerContext};

pub const SCHEMA_REGISTRY_CONNECTION_TYPE: &str = "schema_registry";

// All `XxxConnection` structs should implement this trait as well as the `EnforceSecret` trait.
#[async_trait]
pub trait Connection: Send {
    async fn validate_connection(&self) -> ConnectorResult<()>;
}
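// A minimal sketch of what an implementor looks like (illustrative only; `MyConnection`
// and its `ping` method are hypothetical and not part of this crate):
//
// #[async_trait]
// impl Connection for MyConnection {
//     async fn validate_connection(&self) -> ConnectorResult<()> {
//         // Perform one lightweight round trip against the external system and
//         // surface any error to the caller.
//         self.ping().await?;
//         Ok(())
//     }
// }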

#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct KafkaConnection {
    #[serde(flatten)]
    pub inner: KafkaConnectionProps,
    #[serde(flatten)]
    pub kafka_private_link_common: KafkaPrivateLinkCommon,
    #[serde(flatten)]
    pub aws_auth_props: AwsAuthProps,
}

impl EnforceSecret for KafkaConnection {
    fn enforce_secret<'a>(prop_iter: impl Iterator<Item = &'a str>) -> ConnectorResult<()> {
        for prop in prop_iter {
            KafkaConnectionProps::enforce_one(prop)?;
            AwsAuthProps::enforce_one(prop)?;
        }
        Ok(())
    }
}
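// Usage sketch: enforcement walks the plaintext property names supplied by the user and
// fails if any of them is a property that must be provided through a secret instead. The
// key names below are illustrative assumptions; the authoritative lists live in
// `KafkaConnectionProps` and `AwsAuthProps`.
//
// let user_props = ["properties.sasl.password", "access_key"];
// KafkaConnection::enforce_secret(user_props.into_iter())?;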

pub async fn validate_connection(connection: &PbConnection) -> ConnectorResult<()> {
    if let Some(ref info) = connection.info {
        match info {
            risingwave_pb::catalog::connection::Info::ConnectionParams(cp) => {
                let options = cp.properties.clone().into_iter().collect();
                let secret_refs = cp.secret_refs.clone().into_iter().collect();
                let props_secret_resolved =
                    LocalSecretManager::global().fill_secrets(options, secret_refs)?;
                let connection = build_connection(cp.connection_type(), props_secret_resolved)?;
                connection.validate_connection().await?
            }
            risingwave_pb::catalog::connection::Info::PrivateLinkService(_) => unreachable!(),
        }
    }
    Ok(())
}

#[async_trait]
impl Connection for KafkaConnection {
    async fn validate_connection(&self) -> ConnectorResult<()> {
        let client = self.build_client().await?;
        // Describe the cluster by fetching metadata.
        client.fetch_metadata(None, Duration::from_secs(10)).await?;
        Ok(())
    }
}

impl KafkaConnection {
    async fn build_client(&self) -> ConnectorResult<BaseConsumer<RwConsumerContext>> {
        let mut config = ClientConfig::new();
        let bootstrap_servers = &self.inner.brokers;
        let broker_rewrite_map = self.kafka_private_link_common.broker_rewrite_map.clone();
        config.set("bootstrap.servers", bootstrap_servers);
        self.inner.set_security_properties(&mut config);

        // Duplicated with the Kafka enumerator.
        let ctx_common = KafkaContextCommon::new(
            broker_rewrite_map,
            None,
            None,
            self.aws_auth_props.clone(),
            self.inner.is_aws_msk_iam(),
        )
        .await?;
        let client_ctx = RwConsumerContext::new(ctx_common);
        let client: BaseConsumer<RwConsumerContext> =
            config.create_with_context(client_ctx).await?;
        if self.inner.is_aws_msk_iam() {
            #[cfg(not(madsim))]
            client.poll(Duration::from_secs(10)); // note: this is a blocking call
            #[cfg(madsim)]
            client.poll(Duration::from_secs(10)).await;
        }
        Ok(client)
    }
}
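// Usage sketch (illustrative only): a `KafkaConnection` is normally produced by
// `build_connection` from the user's resolved properties (see `validate_connection`
// above), after which validation is a single call:
//
// connection.validate_connection().await?;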

#[serde_as]
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, WithOptions)]
#[serde(deny_unknown_fields)]
pub struct IcebergConnection {
    #[serde(rename = "catalog.type")]
    pub catalog_type: Option<String>,
    #[serde(rename = "s3.region")]
    pub region: Option<String>,
    #[serde(rename = "s3.endpoint")]
    pub endpoint: Option<String>,
    #[serde(rename = "s3.access.key")]
    pub access_key: Option<String>,
    #[serde(rename = "s3.secret.key")]
    pub secret_key: Option<String>,

    #[serde(rename = "gcs.credential")]
    pub gcs_credential: Option<String>,

    #[serde(rename = "azblob.account_name")]
    pub azblob_account_name: Option<String>,
    #[serde(rename = "azblob.account_key")]
    pub azblob_account_key: Option<String>,
    #[serde(rename = "azblob.endpoint_url")]
    pub azblob_endpoint_url: Option<String>,

    /// Path of the Iceberg warehouse.
    #[serde(rename = "warehouse.path")]
    pub warehouse_path: Option<String>,
    /// Catalog ID; can be omitted for the storage catalog, or when the caller's AWS
    /// account ID matches the Glue catalog ID.
    #[serde(rename = "glue.id")]
    pub glue_id: Option<String>,
    /// Catalog name; the default value is `risingwave`.
    #[serde(rename = "catalog.name")]
    pub catalog_name: Option<String>,
    /// URI of the Iceberg catalog; only applicable to the REST catalog.
    #[serde(rename = "catalog.uri")]
    pub catalog_uri: Option<String>,
    /// Credential for accessing the Iceberg catalog; only applicable to the REST catalog.
    /// A credential to exchange for a token in the OAuth2 client credentials flow.
    #[serde(rename = "catalog.credential")]
    pub credential: Option<String>,
    /// Token for accessing the Iceberg catalog; only applicable to the REST catalog.
    /// A bearer token used for interaction with the server.
    #[serde(rename = "catalog.token")]
    pub token: Option<String>,
    /// `oauth2_server_uri` for accessing the Iceberg catalog; only applicable to the REST catalog.
    /// Token endpoint URI to fetch the token from when the REST catalog is not the authorization server.
    #[serde(rename = "catalog.oauth2_server_uri")]
    pub oauth2_server_uri: Option<String>,
    /// Scope for accessing the Iceberg catalog; only applicable to the REST catalog.
    /// Additional scope for OAuth2.
    #[serde(rename = "catalog.scope")]
    pub scope: Option<String>,

    /// The signing region to use when signing requests to the REST catalog.
    #[serde(rename = "catalog.rest.signing_region")]
    pub rest_signing_region: Option<String>,

    /// The signing name to use when signing requests to the REST catalog.
    #[serde(rename = "catalog.rest.signing_name")]
    pub rest_signing_name: Option<String>,

    /// Whether to use SigV4 for signing requests to the REST catalog.
    #[serde(
        rename = "catalog.rest.sigv4_enabled",
        default,
        deserialize_with = "deserialize_optional_bool_from_string"
    )]
    pub rest_sigv4_enabled: Option<bool>,

    #[serde(
        rename = "s3.path.style.access",
        default,
        deserialize_with = "deserialize_optional_bool_from_string"
    )]
    pub path_style_access: Option<bool>,

    #[serde(rename = "catalog.jdbc.user")]
    pub jdbc_user: Option<String>,

    #[serde(rename = "catalog.jdbc.password")]
    pub jdbc_password: Option<String>,

    /// Only used by the Iceberg engine to enable the hosted catalog.
    #[serde(
        rename = "hosted_catalog",
        default,
        deserialize_with = "deserialize_optional_bool_from_string"
    )]
    pub hosted_catalog: Option<bool>,
}
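// Property-map sketch: the serde renames above mean an `IcebergConnection` deserializes
// from flat string keys. A REST-catalog example (values are placeholders):
//
// "catalog.type"   => "rest"
// "catalog.uri"    => "http://rest-catalog:8181"
// "warehouse.path" => "s3://my-bucket/warehouse"
// "s3.region"      => "us-east-1"
// "s3.access.key"  => "<access key>"
// "s3.secret.key"  => "<secret key, typically supplied via a secret>"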

impl EnforceSecret for IcebergConnection {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "s3.access.key",
        "s3.secret.key",
        "gcs.credential",
        "catalog.token",
    };
}

#[async_trait]
impl Connection for IcebergConnection {
    async fn validate_connection(&self) -> ConnectorResult<()> {
        let info = match &self.warehouse_path {
            Some(warehouse_path) => {
                let is_s3_tables = warehouse_path.starts_with("arn:aws:s3tables");
                let url = Url::parse(warehouse_path);
                if (url.is_err() || is_s3_tables)
                    && matches!(self.catalog_type.as_deref(), Some("rest" | "rest_rust"))
                {
                    // If the warehouse path is not a valid URL, it could be a warehouse name in a
                    // REST catalog, or an S3 Tables ARN, which is not a valid URL but is a valid
                    // warehouse path, so we allow it to pass here.
                    None
                } else {
                    let url =
                        url.with_context(|| format!("Invalid warehouse path: {}", warehouse_path))?;
                    let bucket = url
                        .host_str()
                        .with_context(|| {
                            format!("Invalid s3 path: {}, bucket is missing", warehouse_path)
                        })?
                        .to_owned();
                    let root = url.path().trim_start_matches('/').to_owned();
                    Some((url.scheme().to_owned(), bucket, root))
                }
            }
            None => {
                if matches!(self.catalog_type.as_deref(), Some("rest" | "rest_rust")) {
                    None
                } else {
                    bail!("`warehouse.path` must be set");
                }
            }
        };
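        // At this point, a warehouse path like "s3://my-bucket/path/to/warehouse" yields
        // `info == Some(("s3", "my-bucket", "path/to/warehouse"))`; REST catalogs without a
        // URL-style warehouse path yield `None` and skip the storage check below.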

        // Test warehouse
        if let Some((scheme, bucket, root)) = info {
            match scheme.as_str() {
                "s3" | "s3a" => {
                    let mut builder = S3::default();
                    if let Some(region) = &self.region {
                        builder = builder.region(region);
                    }
                    if let Some(endpoint) = &self.endpoint {
                        builder = builder.endpoint(endpoint);
                    }
                    if let Some(access_key) = &self.access_key {
                        builder = builder.access_key_id(access_key);
                    }
                    if let Some(secret_key) = &self.secret_key {
                        builder = builder.secret_access_key(secret_key);
                    }
                    builder = builder.root(root.as_str()).bucket(bucket.as_str());
                    let op = Operator::new(builder)?.finish();
                    op.check().await?;
                }
                "gs" | "gcs" => {
                    let mut builder = Gcs::default();
                    if let Some(credential) = &self.gcs_credential {
                        builder = builder.credential(credential);
                    }
                    builder = builder.root(root.as_str()).bucket(bucket.as_str());
                    let op = Operator::new(builder)?.finish();
                    op.check().await?;
                }
                "azblob" => {
                    let mut builder = Azblob::default();
                    if let Some(account_name) = &self.azblob_account_name {
                        builder = builder.account_name(account_name);
                    }
                    if let Some(azblob_account_key) = &self.azblob_account_key {
                        builder = builder.account_key(azblob_account_key);
                    }
                    if let Some(azblob_endpoint_url) = &self.azblob_endpoint_url {
                        builder = builder.endpoint(azblob_endpoint_url);
                    }
                    builder = builder.root(root.as_str()).container(bucket.as_str());
                    let op = Operator::new(builder)?.finish();
                    op.check().await?;
                }
                _ => {
                    bail!("Unsupported scheme: {}", scheme);
                }
            }
        }

        if self.hosted_catalog.unwrap_or(false) {
            // If `hosted_catalog` is set, we don't need to test the catalog; just ensure that no
            // other catalog fields are set.
            if self.catalog_type.is_some() {
                bail!("`catalog.type` must not be set when `hosted_catalog` is set");
            }
            if self.catalog_uri.is_some() {
                bail!("`catalog.uri` must not be set when `hosted_catalog` is set");
            }
            if self.catalog_name.is_some() {
                bail!("`catalog.name` must not be set when `hosted_catalog` is set");
            }
            if self.jdbc_user.is_some() {
                bail!("`catalog.jdbc.user` must not be set when `hosted_catalog` is set");
            }
            if self.jdbc_password.is_some() {
                bail!("`catalog.jdbc.password` must not be set when `hosted_catalog` is set");
            }
            return Ok(());
        }

        if self.catalog_type.is_none() {
            bail!("`catalog.type` must be set");
        }

        // Test catalog
        let iceberg_common = IcebergCommon {
            catalog_type: self.catalog_type.clone(),
            region: self.region.clone(),
            endpoint: self.endpoint.clone(),
            access_key: self.access_key.clone(),
            secret_key: self.secret_key.clone(),
            gcs_credential: self.gcs_credential.clone(),
            azblob_account_name: self.azblob_account_name.clone(),
            azblob_account_key: self.azblob_account_key.clone(),
            azblob_endpoint_url: self.azblob_endpoint_url.clone(),
            warehouse_path: self.warehouse_path.clone(),
            glue_id: self.glue_id.clone(),
            catalog_name: self.catalog_name.clone(),
            catalog_uri: self.catalog_uri.clone(),
            credential: self.credential.clone(),
            token: self.token.clone(),
            oauth2_server_uri: self.oauth2_server_uri.clone(),
            scope: self.scope.clone(),
            rest_signing_region: self.rest_signing_region.clone(),
            rest_signing_name: self.rest_signing_name.clone(),
            rest_sigv4_enabled: self.rest_sigv4_enabled,
            path_style_access: self.path_style_access,
            database_name: Some("test_database".to_owned()),
            table_name: "test_table".to_owned(),
            enable_config_load: Some(false),
            hosted_catalog: self.hosted_catalog,
        };

        let mut java_map = HashMap::new();
        if let Some(jdbc_user) = &self.jdbc_user {
            java_map.insert("jdbc.user".to_owned(), jdbc_user.to_owned());
        }
        if let Some(jdbc_password) = &self.jdbc_password {
            java_map.insert("jdbc.password".to_owned(), jdbc_password.to_owned());
        }
        let catalog = iceberg_common.create_catalog(&java_map).await?;
        // Test the catalog via the `table_exists` API.
        catalog
            .table_exists(&iceberg_common.full_table_name()?)
            .await?;
        Ok(())
    }
}

#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
#[serde(deny_unknown_fields)]
pub struct ConfluentSchemaRegistryConnection {
    #[serde(rename = "schema.registry")]
    pub url: String,
    // ref `SchemaRegistryAuth`
    #[serde(rename = "schema.registry.username")]
    pub username: Option<String>,
    #[serde(rename = "schema.registry.password")]
    pub password: Option<String>,
}
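// Property-map sketch (values are placeholders): the schema registry connection is
// configured through the flat keys from the renames above, e.g.
//
// "schema.registry"          => "http://schema-registry:8081"
// "schema.registry.username" => "svc_user"
// "schema.registry.password" => "<typically supplied via a secret>"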

#[async_trait]
impl Connection for ConfluentSchemaRegistryConnection {
    async fn validate_connection(&self) -> ConnectorResult<()> {
        // GET /config to validate the connection
        let client = ConfluentSchemaRegistryClient::try_from(self)?;
        client.validate_connection().await?;
        Ok(())
    }
}

impl EnforceSecret for ConfluentSchemaRegistryConnection {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "schema.registry.password",
    };
}

#[derive(Debug, Clone, Deserialize, PartialEq, Hash, Eq)]
pub struct ElasticsearchConnection(pub BTreeMap<String, String>);

#[async_trait]
impl Connection for ElasticsearchConnection {
    async fn validate_connection(&self) -> ConnectorResult<()> {
        const CONNECTOR: &str = "elasticsearch";

        let config = ElasticSearchOpenSearchConfig::try_from(self)?;
        let client = config.build_client(CONNECTOR)?;
        client.ping().await?;
        Ok(())
    }
}

impl EnforceSecret for ElasticsearchConnection {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "elasticsearch.password",
    };
}