risingwave_connector/connector_common/
connection.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16use std::time::Duration;
17
18use anyhow::Context;
19use opendal::Operator;
20use opendal::services::{Azblob, Gcs, S3};
21use phf::{Set, phf_set};
22use rdkafka::ClientConfig;
23use rdkafka::consumer::{BaseConsumer, Consumer};
24use risingwave_common::bail;
25use risingwave_common::secret::LocalSecretManager;
26use risingwave_common::util::env_var::env_var_is_true;
27use risingwave_pb::catalog::PbConnection;
28use serde_derive::Deserialize;
29use serde_with::serde_as;
30use tonic::async_trait;
31use url::Url;
32use with_options::WithOptions;
33
34use crate::connector_common::common::DISABLE_DEFAULT_CREDENTIAL;
35use crate::connector_common::{
36    AwsAuthProps, IcebergCommon, KafkaConnectionProps, KafkaPrivateLinkCommon,
37};
38use crate::deserialize_optional_bool_from_string;
39use crate::enforce_secret::EnforceSecret;
40use crate::error::ConnectorResult;
41use crate::schema::schema_registry::Client as ConfluentSchemaRegistryClient;
42use crate::sink::elasticsearch_opensearch::elasticsearch_opensearch_config::ElasticSearchOpenSearchConfig;
43use crate::source::build_connection;
44use crate::source::kafka::{KafkaContextCommon, RwConsumerContext};
45
/// Connection-type string identifying a schema registry connection.
pub const SCHEMA_REGISTRY_CONNECTION_TYPE: &str = "schema_registry";

// All XxxConnection structs should implement this trait as well as the `EnforceSecret` trait.
/// A connection whose configuration and reachability can be validated.
#[async_trait]
pub trait Connection: Send {
    /// Checks that the connection is usable. The implementations in this file do so by
    /// contacting the remote service (broker metadata fetch, storage/catalog probe, registry
    /// request, cluster ping).
    ///
    /// # Errors
    /// Returns an error when the configuration is invalid or the remote check fails.
    async fn validate_connection(&self) -> ConnectorResult<()>;
}
53
/// Kafka connection options: broker and security settings, private-link broker rewriting,
/// and AWS authentication, all flattened out of the same option map.
///
/// NOTE(review): serde documents `deny_unknown_fields` as unsupported in combination with
/// `flatten`; confirm that unknown keys are actually rejected for this struct.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct KafkaConnection {
    // Brokers plus the security settings applied via `set_security_properties`.
    #[serde(flatten)]
    pub inner: KafkaConnectionProps,
    // Private-link options; `build_client` reads only `broker_rewrite_map` from it.
    #[serde(flatten)]
    pub kafka_private_link_common: KafkaPrivateLinkCommon,
    // AWS auth settings handed to the consumer context (used together with MSK IAM).
    #[serde(flatten)]
    pub aws_auth_props: AwsAuthProps,
}
65
66impl EnforceSecret for KafkaConnection {
67    fn enforce_secret<'a>(prop_iter: impl Iterator<Item = &'a str>) -> ConnectorResult<()> {
68        for prop in prop_iter {
69            KafkaConnectionProps::enforce_one(prop)?;
70            AwsAuthProps::enforce_one(prop)?;
71        }
72        Ok(())
73    }
74}
75
76pub async fn validate_connection(connection: &PbConnection) -> ConnectorResult<()> {
77    if let Some(ref info) = connection.info {
78        match info {
79            risingwave_pb::catalog::connection::Info::ConnectionParams(cp) => {
80                let options = cp.properties.clone().into_iter().collect();
81                let secret_refs = cp.secret_refs.clone().into_iter().collect();
82                let props_secret_resolved =
83                    LocalSecretManager::global().fill_secrets(options, secret_refs)?;
84                let connection = build_connection(cp.connection_type(), props_secret_resolved)?;
85                connection.validate_connection().await?
86            }
87            risingwave_pb::catalog::connection::Info::PrivateLinkService(_) => unreachable!(),
88        }
89    }
90    Ok(())
91}
92
93#[async_trait]
94impl Connection for KafkaConnection {
95    async fn validate_connection(&self) -> ConnectorResult<()> {
96        let client = self.build_client().await?;
97        // describe cluster here
98        client.fetch_metadata(None, Duration::from_secs(10)).await?;
99        Ok(())
100    }
101}
102
103impl KafkaConnection {
104    async fn build_client(&self) -> ConnectorResult<BaseConsumer<RwConsumerContext>> {
105        let mut config = ClientConfig::new();
106        let bootstrap_servers = &self.inner.brokers;
107        let broker_rewrite_map = self.kafka_private_link_common.broker_rewrite_map.clone();
108        config.set("bootstrap.servers", bootstrap_servers);
109        self.inner.set_security_properties(&mut config);
110
111        // dup with Kafka Enumerator
112        let ctx_common = KafkaContextCommon::new(
113            broker_rewrite_map,
114            None,
115            None,
116            self.aws_auth_props.clone(),
117            self.inner.is_aws_msk_iam(),
118        )
119        .await?;
120        let client_ctx = RwConsumerContext::new(ctx_common);
121        let client: BaseConsumer<RwConsumerContext> =
122            config.create_with_context(client_ctx).await?;
123        if self.inner.is_aws_msk_iam() {
124            #[cfg(not(madsim))]
125            client.poll(Duration::from_secs(10)); // note: this is a blocking call
126            #[cfg(madsim)]
127            client.poll(Duration::from_secs(10)).await;
128        }
129        Ok(client)
130    }
131}
132
/// Iceberg connection options: warehouse storage (S3 / GCS / Azure Blob) and catalog settings.
/// Unknown keys are rejected.
#[serde_as]
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, WithOptions)]
#[serde(deny_unknown_fields)]
pub struct IcebergConnection {
    // Catalog implementation. Required unless `hosted_catalog` is set; "rest" / "rest_rust"
    // allow `warehouse.path` to be omitted or to be a non-URL warehouse name.
    #[serde(rename = "catalog.type")]
    pub catalog_type: Option<String>,
    // S3 region, passed to the S3 storage builder and the catalog config.
    #[serde(rename = "s3.region")]
    pub region: Option<String>,
    // Custom S3 endpoint, passed to the S3 storage builder and the catalog config.
    #[serde(rename = "s3.endpoint")]
    pub endpoint: Option<String>,
    // S3 access key id (enforced as a secret).
    #[serde(rename = "s3.access.key")]
    pub access_key: Option<String>,
    // S3 secret access key (enforced as a secret).
    #[serde(rename = "s3.secret.key")]
    pub secret_key: Option<String>,

    // Credential for a GCS warehouse (enforced as a secret).
    #[serde(rename = "gcs.credential")]
    pub gcs_credential: Option<String>,

    // Azure Blob account name, account key and endpoint for an `azblob://` warehouse.
    #[serde(rename = "azblob.account_name")]
    pub azblob_account_name: Option<String>,
    #[serde(rename = "azblob.account_key")]
    pub azblob_account_key: Option<String>,
    #[serde(rename = "azblob.endpoint_url")]
    pub azblob_endpoint_url: Option<String>,

    /// Path of iceberg warehouse.
    #[serde(rename = "warehouse.path")]
    pub warehouse_path: Option<String>,
    /// Catalog id, can be omitted for storage catalog or when
    /// caller's AWS account ID matches glue id
    #[serde(rename = "glue.id")]
    pub glue_id: Option<String>,
    /// Catalog name, default value is risingwave.
    #[serde(rename = "catalog.name")]
    pub catalog_name: Option<String>,
    /// URI of iceberg catalog, only applicable in rest catalog.
    #[serde(rename = "catalog.uri")]
    pub catalog_uri: Option<String>,
    /// Credential for accessing iceberg catalog, only applicable in rest catalog.
    /// A credential to exchange for a token in the OAuth2 client credentials flow.
    #[serde(rename = "catalog.credential")]
    pub credential: Option<String>,
    /// token for accessing iceberg catalog, only applicable in rest catalog.
    /// A Bearer token which will be used for interaction with the server.
    #[serde(rename = "catalog.token")]
    pub token: Option<String>,
    /// `oauth2_server_uri` for accessing iceberg catalog, only applicable in rest catalog.
    /// Token endpoint URI to fetch token from if the Rest Catalog is not the authorization server.
    #[serde(rename = "catalog.oauth2_server_uri")]
    pub oauth2_server_uri: Option<String>,
    /// scope for accessing iceberg catalog, only applicable in rest catalog.
    /// Additional scope for OAuth2.
    #[serde(rename = "catalog.scope")]
    pub scope: Option<String>,

    /// The signing region to use when signing requests to the REST catalog.
    #[serde(rename = "catalog.rest.signing_region")]
    pub rest_signing_region: Option<String>,

    /// The signing name to use when signing requests to the REST catalog.
    #[serde(rename = "catalog.rest.signing_name")]
    pub rest_signing_name: Option<String>,

    /// Whether to use SigV4 for signing requests to the REST catalog.
    #[serde(
        rename = "catalog.rest.sigv4_enabled",
        default,
        deserialize_with = "deserialize_optional_bool_from_string"
    )]
    pub rest_sigv4_enabled: Option<bool>,

    // S3 path-style access flag, parsed from a string. It is forwarded to the catalog config;
    // the warehouse storage check in `validate_connection` does not use it.
    #[serde(
        rename = "s3.path.style.access",
        default,
        deserialize_with = "deserialize_optional_bool_from_string"
    )]
    pub path_style_access: Option<bool>,

    // JDBC catalog user, forwarded to the catalog as `jdbc.user`.
    #[serde(rename = "catalog.jdbc.user")]
    pub jdbc_user: Option<String>,

    // JDBC catalog password, forwarded to the catalog as `jdbc.password`.
    #[serde(rename = "catalog.jdbc.password")]
    pub jdbc_password: Option<String>,

    /// Enable config load. This parameter set to true will load warehouse credentials from the environment. Only allowed to be used in a self-hosted environment.
    #[serde(default, deserialize_with = "deserialize_optional_bool_from_string")]
    pub enable_config_load: Option<bool>,

    /// This is only used by iceberg engine to enable the hosted catalog.
    #[serde(
        rename = "hosted_catalog",
        default,
        deserialize_with = "deserialize_optional_bool_from_string"
    )]
    pub hosted_catalog: Option<bool>,
}
229
230impl EnforceSecret for IcebergConnection {
231    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
232        "s3.access.key",
233        "s3.secret.key",
234        "gcs.credential",
235        "catalog.token",
236    };
237}
238
239#[async_trait]
240impl Connection for IcebergConnection {
241    async fn validate_connection(&self) -> ConnectorResult<()> {
242        let info = match &self.warehouse_path {
243            Some(warehouse_path) => {
244                let is_s3_tables = warehouse_path.starts_with("arn:aws:s3tables");
245                let url = Url::parse(warehouse_path);
246                if (url.is_err() || is_s3_tables)
247                    && matches!(self.catalog_type.as_deref(), Some("rest" | "rest_rust"))
248                {
249                    // If the warehouse path is not a valid URL, it could be a warehouse name in rest catalog,
250                    // Or it could be a s3tables path, which is not a valid URL but a valid warehouse path,
251                    // so we allow it to pass here.
252                    None
253                } else {
254                    let url =
255                        url.with_context(|| format!("Invalid warehouse path: {}", warehouse_path))?;
256                    let bucket = url
257                        .host_str()
258                        .with_context(|| {
259                            format!("Invalid s3 path: {}, bucket is missing", warehouse_path)
260                        })?
261                        .to_owned();
262                    let root = url.path().trim_start_matches('/').to_owned();
263                    Some((url.scheme().to_owned(), bucket, root))
264                }
265            }
266            None => {
267                if matches!(self.catalog_type.as_deref(), Some("rest" | "rest_rust")) {
268                    None
269                } else {
270                    bail!("`warehouse.path` must be set");
271                }
272            }
273        };
274
275        // Test warehouse
276        if let Some((scheme, bucket, root)) = info {
277            match scheme.as_str() {
278                "s3" | "s3a" => {
279                    let mut builder = S3::default();
280                    if let Some(region) = &self.region {
281                        builder = builder.region(region);
282                    }
283                    if let Some(endpoint) = &self.endpoint {
284                        builder = builder.endpoint(endpoint);
285                    }
286                    if let Some(access_key) = &self.access_key {
287                        builder = builder.access_key_id(access_key);
288                    }
289                    if let Some(secret_key) = &self.secret_key {
290                        builder = builder.secret_access_key(secret_key);
291                    }
292                    builder = builder.root(root.as_str()).bucket(bucket.as_str());
293                    let op = Operator::new(builder)?.finish();
294                    op.check().await?;
295                }
296                "gs" | "gcs" => {
297                    let mut builder = Gcs::default();
298                    if let Some(credential) = &self.gcs_credential {
299                        builder = builder.credential(credential);
300                    }
301                    builder = builder.root(root.as_str()).bucket(bucket.as_str());
302                    let op = Operator::new(builder)?.finish();
303                    op.check().await?;
304                }
305                "azblob" => {
306                    let mut builder = Azblob::default();
307                    if let Some(account_name) = &self.azblob_account_name {
308                        builder = builder.account_name(account_name);
309                    }
310                    if let Some(azblob_account_key) = &self.azblob_account_key {
311                        builder = builder.account_key(azblob_account_key);
312                    }
313                    if let Some(azblob_endpoint_url) = &self.azblob_endpoint_url {
314                        builder = builder.endpoint(azblob_endpoint_url);
315                    }
316                    builder = builder.root(root.as_str()).container(bucket.as_str());
317                    let op = Operator::new(builder)?.finish();
318                    op.check().await?;
319                }
320                _ => {
321                    bail!("Unsupported scheme: {}", scheme);
322                }
323            }
324        }
325
326        if env_var_is_true(DISABLE_DEFAULT_CREDENTIAL)
327            && matches!(self.enable_config_load, Some(true))
328        {
329            bail!("`enable_config_load` can't be enabled in this environment");
330        }
331
332        if self.hosted_catalog.unwrap_or(false) {
333            // If `hosted_catalog` is set, we don't need to test the catalog, but just ensure no catalog fields are set.
334            if self.catalog_type.is_some() {
335                bail!("`catalog.type` must not be set when `hosted_catalog` is set");
336            }
337            if self.catalog_uri.is_some() {
338                bail!("`catalog.uri` must not be set when `hosted_catalog` is set");
339            }
340            if self.catalog_name.is_some() {
341                bail!("`catalog.name` must not be set when `hosted_catalog` is set");
342            }
343            if self.jdbc_user.is_some() {
344                bail!("`catalog.jdbc.user` must not be set when `hosted_catalog` is set");
345            }
346            if self.jdbc_password.is_some() {
347                bail!("`catalog.jdbc.password` must not be set when `hosted_catalog` is set");
348            }
349            return Ok(());
350        }
351
352        if self.catalog_type.is_none() {
353            bail!("`catalog.type` must be set");
354        }
355
356        // Test catalog
357        let iceberg_common = IcebergCommon {
358            catalog_type: self.catalog_type.clone(),
359            region: self.region.clone(),
360            endpoint: self.endpoint.clone(),
361            access_key: self.access_key.clone(),
362            secret_key: self.secret_key.clone(),
363            gcs_credential: self.gcs_credential.clone(),
364            azblob_account_name: self.azblob_account_name.clone(),
365            azblob_account_key: self.azblob_account_key.clone(),
366            azblob_endpoint_url: self.azblob_endpoint_url.clone(),
367            warehouse_path: self.warehouse_path.clone(),
368            glue_id: self.glue_id.clone(),
369            catalog_name: self.catalog_name.clone(),
370            catalog_uri: self.catalog_uri.clone(),
371            credential: self.credential.clone(),
372
373            token: self.token.clone(),
374            oauth2_server_uri: self.oauth2_server_uri.clone(),
375            scope: self.scope.clone(),
376            rest_signing_region: self.rest_signing_region.clone(),
377            rest_signing_name: self.rest_signing_name.clone(),
378            rest_sigv4_enabled: self.rest_sigv4_enabled,
379            path_style_access: self.path_style_access,
380            database_name: Some("test_database".to_owned()),
381            table_name: "test_table".to_owned(),
382            enable_config_load: self.enable_config_load,
383            hosted_catalog: self.hosted_catalog,
384        };
385
386        let mut java_map = HashMap::new();
387        if let Some(jdbc_user) = &self.jdbc_user {
388            java_map.insert("jdbc.user".to_owned(), jdbc_user.to_owned());
389        }
390        if let Some(jdbc_password) = &self.jdbc_password {
391            java_map.insert("jdbc.password".to_owned(), jdbc_password.to_owned());
392        }
393        let catalog = iceberg_common.create_catalog(&java_map).await?;
394        // test catalog by `table_exists` api
395        catalog
396            .table_exists(&iceberg_common.full_table_name()?)
397            .await?;
398        Ok(())
399    }
400}
401
/// Connection to a Confluent schema registry. Unknown keys are rejected.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
#[serde(deny_unknown_fields)]
pub struct ConfluentSchemaRegistryConnection {
    // Registry address, from the `schema.registry` option.
    #[serde(rename = "schema.registry")]
    pub url: String,
    // ref `SchemaRegistryAuth`
    // Optional credentials; the password is enforced as a secret.
    #[serde(rename = "schema.registry.username")]
    pub username: Option<String>,
    #[serde(rename = "schema.registry.password")]
    pub password: Option<String>,
}
414
415#[async_trait]
416impl Connection for ConfluentSchemaRegistryConnection {
417    async fn validate_connection(&self) -> ConnectorResult<()> {
418        // GET /config to validate the connection
419        let client = ConfluentSchemaRegistryClient::try_from(self)?;
420        client.validate_connection().await?;
421        Ok(())
422    }
423}
424
425impl EnforceSecret for ConfluentSchemaRegistryConnection {
426    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
427        "schema.registry.password",
428    };
429}
430
/// Elasticsearch connection options, kept as a raw string map. During validation they are
/// converted into an `ElasticSearchOpenSearchConfig`.
#[derive(Debug, Clone, Deserialize, PartialEq, Hash, Eq)]
pub struct ElasticsearchConnection(pub BTreeMap<String, String>);
433
434#[async_trait]
435impl Connection for ElasticsearchConnection {
436    async fn validate_connection(&self) -> ConnectorResult<()> {
437        const CONNECTOR: &str = "elasticsearch";
438
439        let config = ElasticSearchOpenSearchConfig::try_from(self)?;
440        let client = config.build_client(CONNECTOR)?;
441        client.ping().await?;
442        Ok(())
443    }
444}
445
446impl EnforceSecret for ElasticsearchConnection {
447    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
448        "elasticsearch.password",
449    };
450}