// risingwave_connector/connector_common/connection.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16use std::time::Duration;
17
18use anyhow::Context;
19use opendal::Operator;
20use opendal::services::{Gcs, S3};
21use rdkafka::ClientConfig;
22use rdkafka::consumer::{BaseConsumer, Consumer};
23use risingwave_common::bail;
24use risingwave_common::secret::LocalSecretManager;
25use risingwave_pb::catalog::PbConnection;
26use serde_derive::Deserialize;
27use serde_with::serde_as;
28use tonic::async_trait;
29use url::Url;
30use with_options::WithOptions;
31
32use crate::connector_common::{
33    AwsAuthProps, IcebergCommon, KafkaConnectionProps, KafkaPrivateLinkCommon,
34};
35use crate::deserialize_optional_bool_from_string;
36use crate::error::ConnectorResult;
37use crate::schema::schema_registry::Client as ConfluentSchemaRegistryClient;
38use crate::sink::elasticsearch_opensearch::elasticsearch_opensearch_config::ElasticSearchOpenSearchConfig;
39use crate::source::build_connection;
40use crate::source::kafka::{KafkaContextCommon, RwConsumerContext};
41
42pub const SCHEMA_REGISTRY_CONNECTION_TYPE: &str = "schema_registry";
43
44#[async_trait]
45pub trait Connection: Send {
46    async fn validate_connection(&self) -> ConnectorResult<()>;
47}
48
49#[serde_as]
50#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq)]
51#[serde(deny_unknown_fields)]
52pub struct KafkaConnection {
53    #[serde(flatten)]
54    pub inner: KafkaConnectionProps,
55    #[serde(flatten)]
56    pub kafka_private_link_common: KafkaPrivateLinkCommon,
57    #[serde(flatten)]
58    pub aws_auth_props: AwsAuthProps,
59}
60
61pub async fn validate_connection(connection: &PbConnection) -> ConnectorResult<()> {
62    if let Some(ref info) = connection.info {
63        match info {
64            risingwave_pb::catalog::connection::Info::ConnectionParams(cp) => {
65                let options = cp.properties.clone().into_iter().collect();
66                let secret_refs = cp.secret_refs.clone().into_iter().collect();
67                let props_secret_resolved =
68                    LocalSecretManager::global().fill_secrets(options, secret_refs)?;
69                let connection = build_connection(cp.connection_type(), props_secret_resolved)?;
70                connection.validate_connection().await?
71            }
72            risingwave_pb::catalog::connection::Info::PrivateLinkService(_) => unreachable!(),
73        }
74    }
75    Ok(())
76}
77
78#[async_trait]
79impl Connection for KafkaConnection {
80    async fn validate_connection(&self) -> ConnectorResult<()> {
81        let client = self.build_client().await?;
82        // describe cluster here
83        client.fetch_metadata(None, Duration::from_secs(10)).await?;
84        Ok(())
85    }
86}
87
88impl KafkaConnection {
89    async fn build_client(&self) -> ConnectorResult<BaseConsumer<RwConsumerContext>> {
90        let mut config = ClientConfig::new();
91        let bootstrap_servers = &self.inner.brokers;
92        let broker_rewrite_map = self.kafka_private_link_common.broker_rewrite_map.clone();
93        config.set("bootstrap.servers", bootstrap_servers);
94        self.inner.set_security_properties(&mut config);
95
96        // dup with Kafka Enumerator
97        let ctx_common = KafkaContextCommon::new(
98            broker_rewrite_map,
99            None,
100            None,
101            self.aws_auth_props.clone(),
102            self.inner.is_aws_msk_iam(),
103        )
104        .await?;
105        let client_ctx = RwConsumerContext::new(ctx_common);
106        let client: BaseConsumer<RwConsumerContext> =
107            config.create_with_context(client_ctx).await?;
108        if self.inner.is_aws_msk_iam() {
109            #[cfg(not(madsim))]
110            client.poll(Duration::from_secs(10)); // note: this is a blocking call
111            #[cfg(madsim)]
112            client.poll(Duration::from_secs(10)).await;
113        }
114        Ok(client)
115    }
116}
117
118#[serde_as]
119#[derive(Debug, Clone, PartialEq, Eq, Deserialize, WithOptions)]
120#[serde(deny_unknown_fields)]
121pub struct IcebergConnection {
122    #[serde(rename = "catalog.type")]
123    pub catalog_type: Option<String>,
124    #[serde(rename = "s3.region")]
125    pub region: Option<String>,
126    #[serde(rename = "s3.endpoint")]
127    pub endpoint: Option<String>,
128    #[serde(rename = "s3.access.key")]
129    pub access_key: Option<String>,
130    #[serde(rename = "s3.secret.key")]
131    pub secret_key: Option<String>,
132
133    #[serde(rename = "gcs.credential")]
134    pub gcs_credential: Option<String>,
135
136    /// Path of iceberg warehouse.
137    #[serde(rename = "warehouse.path")]
138    pub warehouse_path: Option<String>,
139    /// Catalog id, can be omitted for storage catalog or when
140    /// caller's AWS account ID matches glue id
141    #[serde(rename = "glue.id")]
142    pub glue_id: Option<String>,
143    /// Catalog name, default value is risingwave.
144    #[serde(rename = "catalog.name")]
145    pub catalog_name: Option<String>,
146    /// URI of iceberg catalog, only applicable in rest catalog.
147    #[serde(rename = "catalog.uri")]
148    pub catalog_uri: Option<String>,
149    /// Credential for accessing iceberg catalog, only applicable in rest catalog.
150    /// A credential to exchange for a token in the OAuth2 client credentials flow.
151    #[serde(rename = "catalog.credential")]
152    pub credential: Option<String>,
153    /// token for accessing iceberg catalog, only applicable in rest catalog.
154    /// A Bearer token which will be used for interaction with the server.
155    #[serde(rename = "catalog.token")]
156    pub token: Option<String>,
157    /// `oauth2_server_uri` for accessing iceberg catalog, only applicable in rest catalog.
158    /// Token endpoint URI to fetch token from if the Rest Catalog is not the authorization server.
159    #[serde(rename = "catalog.oauth2_server_uri")]
160    pub oauth2_server_uri: Option<String>,
161    /// scope for accessing iceberg catalog, only applicable in rest catalog.
162    /// Additional scope for OAuth2.
163    #[serde(rename = "catalog.scope")]
164    pub scope: Option<String>,
165
166    /// The signing region to use when signing requests to the REST catalog.
167    #[serde(rename = "catalog.rest.signing_region")]
168    pub rest_signing_region: Option<String>,
169
170    /// The signing name to use when signing requests to the REST catalog.
171    #[serde(rename = "catalog.rest.signing_name")]
172    pub rest_signing_name: Option<String>,
173
174    /// Whether to use SigV4 for signing requests to the REST catalog.
175    #[serde(
176        rename = "catalog.rest.sigv4_enabled",
177        default,
178        deserialize_with = "deserialize_optional_bool_from_string"
179    )]
180    pub rest_sigv4_enabled: Option<bool>,
181
182    #[serde(
183        rename = "s3.path.style.access",
184        default,
185        deserialize_with = "deserialize_optional_bool_from_string"
186    )]
187    pub path_style_access: Option<bool>,
188
189    #[serde(rename = "catalog.jdbc.user")]
190    pub jdbc_user: Option<String>,
191
192    #[serde(rename = "catalog.jdbc.password")]
193    pub jdbc_password: Option<String>,
194}
195
196#[async_trait]
197impl Connection for IcebergConnection {
198    async fn validate_connection(&self) -> ConnectorResult<()> {
199        let info = match &self.warehouse_path {
200            Some(warehouse_path) => {
201                let is_s3_tables = warehouse_path.starts_with("arn:aws:s3tables");
202                let url = Url::parse(warehouse_path);
203                if (url.is_err() || is_s3_tables)
204                    && matches!(self.catalog_type.as_deref(), Some("rest" | "rest_rust"))
205                {
206                    // If the warehouse path is not a valid URL, it could be a warehouse name in rest catalog,
207                    // Or it could be a s3tables path, which is not a valid URL but a valid warehouse path,
208                    // so we allow it to pass here.
209                    None
210                } else {
211                    let url =
212                        url.with_context(|| format!("Invalid warehouse path: {}", warehouse_path))?;
213                    let bucket = url
214                        .host_str()
215                        .with_context(|| {
216                            format!("Invalid s3 path: {}, bucket is missing", warehouse_path)
217                        })?
218                        .to_owned();
219                    let root = url.path().trim_start_matches('/').to_owned();
220                    Some((url.scheme().to_owned(), bucket, root))
221                }
222            }
223            None => {
224                if matches!(self.catalog_type.as_deref(), Some("rest" | "rest_rust")) {
225                    None
226                } else {
227                    bail!("`warehouse.path` must be set");
228                }
229            }
230        };
231
232        // Test warehouse
233        if let Some((scheme, bucket, root)) = info {
234            match scheme.as_str() {
235                "s3" | "s3a" => {
236                    let mut builder = S3::default();
237                    if let Some(region) = &self.region {
238                        builder = builder.region(region);
239                    }
240                    if let Some(endpoint) = &self.endpoint {
241                        builder = builder.endpoint(endpoint);
242                    }
243                    if let Some(access_key) = &self.access_key {
244                        builder = builder.access_key_id(access_key);
245                    }
246                    if let Some(secret_key) = &self.secret_key {
247                        builder = builder.secret_access_key(secret_key);
248                    }
249                    builder = builder.root(root.as_str()).bucket(bucket.as_str());
250                    let op = Operator::new(builder)?.finish();
251                    op.check().await?;
252                }
253                "gs" | "gcs" => {
254                    let mut builder = Gcs::default();
255                    if let Some(credential) = &self.gcs_credential {
256                        builder = builder.credential(credential);
257                    }
258                    builder = builder.root(root.as_str()).bucket(bucket.as_str());
259                    let op = Operator::new(builder)?.finish();
260                    op.check().await?;
261                }
262                _ => {
263                    bail!("Unsupported scheme: {}", scheme);
264                }
265            }
266        }
267
268        if self.catalog_type.is_none() {
269            bail!("`catalog.type` must be set");
270        }
271
272        // Test catalog
273        let iceberg_common = IcebergCommon {
274            catalog_type: self.catalog_type.clone(),
275            region: self.region.clone(),
276            endpoint: self.endpoint.clone(),
277            access_key: self.access_key.clone(),
278            secret_key: self.secret_key.clone(),
279            gcs_credential: self.gcs_credential.clone(),
280            warehouse_path: self.warehouse_path.clone(),
281            glue_id: self.glue_id.clone(),
282            catalog_name: self.catalog_name.clone(),
283            catalog_uri: self.catalog_uri.clone(),
284            credential: self.credential.clone(),
285            token: self.token.clone(),
286            oauth2_server_uri: self.oauth2_server_uri.clone(),
287            scope: self.scope.clone(),
288            rest_signing_region: self.rest_signing_region.clone(),
289            rest_signing_name: self.rest_signing_name.clone(),
290            rest_sigv4_enabled: self.rest_sigv4_enabled,
291            path_style_access: self.path_style_access,
292            database_name: Some("test_database".to_owned()),
293            table_name: "test_table".to_owned(),
294            enable_config_load: Some(false),
295        };
296
297        let mut java_map = HashMap::new();
298        if let Some(jdbc_user) = &self.jdbc_user {
299            java_map.insert("jdbc.user".to_owned(), jdbc_user.to_owned());
300        }
301        if let Some(jdbc_password) = &self.jdbc_password {
302            java_map.insert("jdbc.password".to_owned(), jdbc_password.to_owned());
303        }
304        let catalog = iceberg_common.create_catalog(&java_map).await?;
305        // test catalog by `table_exists` api
306        catalog
307            .table_exists(&iceberg_common.full_table_name()?)
308            .await?;
309        Ok(())
310    }
311}
312
313#[serde_as]
314#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
315#[serde(deny_unknown_fields)]
316pub struct ConfluentSchemaRegistryConnection {
317    #[serde(rename = "schema.registry")]
318    pub url: String,
319    // ref `SchemaRegistryAuth`
320    #[serde(rename = "schema.registry.username")]
321    pub username: Option<String>,
322    #[serde(rename = "schema.registry.password")]
323    pub password: Option<String>,
324}
325
326#[async_trait]
327impl Connection for ConfluentSchemaRegistryConnection {
328    async fn validate_connection(&self) -> ConnectorResult<()> {
329        // GET /config to validate the connection
330        let client = ConfluentSchemaRegistryClient::try_from(self)?;
331        client.validate_connection().await?;
332        Ok(())
333    }
334}
335
336#[derive(Debug, Clone, Deserialize, PartialEq, Hash, Eq)]
337pub struct ElasticsearchConnection(pub BTreeMap<String, String>);
338
339#[async_trait]
340impl Connection for ElasticsearchConnection {
341    async fn validate_connection(&self) -> ConnectorResult<()> {
342        const CONNECTOR: &str = "elasticsearch";
343
344        let config = ElasticSearchOpenSearchConfig::try_from(self)?;
345        let client = config.build_client(CONNECTOR)?;
346        client.ping().await?;
347        Ok(())
348    }
349}