risingwave_connector/connector_common/
connection.rs1use std::collections::{BTreeMap, HashMap};
16use std::time::Duration;
17
18use anyhow::Context;
19use opendal::Operator;
20use opendal::services::{Azblob, Gcs, S3};
21use phf::{Set, phf_set};
22use rdkafka::ClientConfig;
23use rdkafka::consumer::{BaseConsumer, Consumer};
24use risingwave_common::bail;
25use risingwave_common::secret::LocalSecretManager;
26use risingwave_pb::catalog::PbConnection;
27use serde_derive::Deserialize;
28use serde_with::serde_as;
29use tonic::async_trait;
30use url::Url;
31use with_options::WithOptions;
32
33use crate::connector_common::{
34 AwsAuthProps, IcebergCommon, KafkaConnectionProps, KafkaPrivateLinkCommon,
35};
36use crate::deserialize_optional_bool_from_string;
37use crate::enforce_secret::EnforceSecret;
38use crate::error::ConnectorResult;
39use crate::schema::schema_registry::Client as ConfluentSchemaRegistryClient;
40use crate::sink::elasticsearch_opensearch::elasticsearch_opensearch_config::ElasticSearchOpenSearchConfig;
41use crate::source::build_connection;
42use crate::source::kafka::{KafkaContextCommon, RwConsumerContext};
43
/// Connection type name used for schema registry connections (`"schema_registry"`).
pub const SCHEMA_REGISTRY_CONNECTION_TYPE: &str = "schema_registry";
45
46#[async_trait]
48pub trait Connection: Send {
49 async fn validate_connection(&self) -> ConnectorResult<()>;
50}
51
52#[serde_as]
53#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq)]
54#[serde(deny_unknown_fields)]
55pub struct KafkaConnection {
56 #[serde(flatten)]
57 pub inner: KafkaConnectionProps,
58 #[serde(flatten)]
59 pub kafka_private_link_common: KafkaPrivateLinkCommon,
60 #[serde(flatten)]
61 pub aws_auth_props: AwsAuthProps,
62}
63
64impl EnforceSecret for KafkaConnection {
65 fn enforce_secret<'a>(prop_iter: impl Iterator<Item = &'a str>) -> ConnectorResult<()> {
66 for prop in prop_iter {
67 KafkaConnectionProps::enforce_one(prop)?;
68 AwsAuthProps::enforce_one(prop)?;
69 }
70 Ok(())
71 }
72}
73
74pub async fn validate_connection(connection: &PbConnection) -> ConnectorResult<()> {
75 if let Some(ref info) = connection.info {
76 match info {
77 risingwave_pb::catalog::connection::Info::ConnectionParams(cp) => {
78 let options = cp.properties.clone().into_iter().collect();
79 let secret_refs = cp.secret_refs.clone().into_iter().collect();
80 let props_secret_resolved =
81 LocalSecretManager::global().fill_secrets(options, secret_refs)?;
82 let connection = build_connection(cp.connection_type(), props_secret_resolved)?;
83 connection.validate_connection().await?
84 }
85 risingwave_pb::catalog::connection::Info::PrivateLinkService(_) => unreachable!(),
86 }
87 }
88 Ok(())
89}
90
91#[async_trait]
92impl Connection for KafkaConnection {
93 async fn validate_connection(&self) -> ConnectorResult<()> {
94 let client = self.build_client().await?;
95 client.fetch_metadata(None, Duration::from_secs(10)).await?;
97 Ok(())
98 }
99}
100
101impl KafkaConnection {
102 async fn build_client(&self) -> ConnectorResult<BaseConsumer<RwConsumerContext>> {
103 let mut config = ClientConfig::new();
104 let bootstrap_servers = &self.inner.brokers;
105 let broker_rewrite_map = self.kafka_private_link_common.broker_rewrite_map.clone();
106 config.set("bootstrap.servers", bootstrap_servers);
107 self.inner.set_security_properties(&mut config);
108
109 let ctx_common = KafkaContextCommon::new(
111 broker_rewrite_map,
112 None,
113 None,
114 self.aws_auth_props.clone(),
115 self.inner.is_aws_msk_iam(),
116 )
117 .await?;
118 let client_ctx = RwConsumerContext::new(ctx_common);
119 let client: BaseConsumer<RwConsumerContext> =
120 config.create_with_context(client_ctx).await?;
121 if self.inner.is_aws_msk_iam() {
122 #[cfg(not(madsim))]
123 client.poll(Duration::from_secs(10)); #[cfg(madsim)]
125 client.poll(Duration::from_secs(10)).await;
126 }
127 Ok(client)
128 }
129}
130
131#[serde_as]
132#[derive(Debug, Clone, PartialEq, Eq, Deserialize, WithOptions)]
133#[serde(deny_unknown_fields)]
134pub struct IcebergConnection {
135 #[serde(rename = "catalog.type")]
136 pub catalog_type: Option<String>,
137 #[serde(rename = "s3.region")]
138 pub region: Option<String>,
139 #[serde(rename = "s3.endpoint")]
140 pub endpoint: Option<String>,
141 #[serde(rename = "s3.access.key")]
142 pub access_key: Option<String>,
143 #[serde(rename = "s3.secret.key")]
144 pub secret_key: Option<String>,
145
146 #[serde(rename = "gcs.credential")]
147 pub gcs_credential: Option<String>,
148
149 #[serde(rename = "azblob.account_name")]
150 pub azblob_account_name: Option<String>,
151 #[serde(rename = "azblob.account_key")]
152 pub azblob_account_key: Option<String>,
153 #[serde(rename = "azblob.endpoint_url")]
154 pub azblob_endpoint_url: Option<String>,
155
156 #[serde(rename = "warehouse.path")]
158 pub warehouse_path: Option<String>,
159 #[serde(rename = "glue.id")]
162 pub glue_id: Option<String>,
163 #[serde(rename = "catalog.name")]
165 pub catalog_name: Option<String>,
166 #[serde(rename = "catalog.uri")]
168 pub catalog_uri: Option<String>,
169 #[serde(rename = "catalog.credential")]
172 pub credential: Option<String>,
173 #[serde(rename = "catalog.token")]
176 pub token: Option<String>,
177 #[serde(rename = "catalog.oauth2_server_uri")]
180 pub oauth2_server_uri: Option<String>,
181 #[serde(rename = "catalog.scope")]
184 pub scope: Option<String>,
185
186 #[serde(rename = "catalog.rest.signing_region")]
188 pub rest_signing_region: Option<String>,
189
190 #[serde(rename = "catalog.rest.signing_name")]
192 pub rest_signing_name: Option<String>,
193
194 #[serde(
196 rename = "catalog.rest.sigv4_enabled",
197 default,
198 deserialize_with = "deserialize_optional_bool_from_string"
199 )]
200 pub rest_sigv4_enabled: Option<bool>,
201
202 #[serde(
203 rename = "s3.path.style.access",
204 default,
205 deserialize_with = "deserialize_optional_bool_from_string"
206 )]
207 pub path_style_access: Option<bool>,
208
209 #[serde(rename = "catalog.jdbc.user")]
210 pub jdbc_user: Option<String>,
211
212 #[serde(rename = "catalog.jdbc.password")]
213 pub jdbc_password: Option<String>,
214
215 #[serde(
217 rename = "hosted_catalog",
218 default,
219 deserialize_with = "deserialize_optional_bool_from_string"
220 )]
221 pub hosted_catalog: Option<bool>,
222}
223
224impl EnforceSecret for IcebergConnection {
225 const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
226 "s3.access.key",
227 "s3.secret.key",
228 "gcs.credential",
229 "catalog.token",
230 };
231}
232
233#[async_trait]
234impl Connection for IcebergConnection {
235 async fn validate_connection(&self) -> ConnectorResult<()> {
236 let info = match &self.warehouse_path {
237 Some(warehouse_path) => {
238 let is_s3_tables = warehouse_path.starts_with("arn:aws:s3tables");
239 let url = Url::parse(warehouse_path);
240 if (url.is_err() || is_s3_tables)
241 && matches!(self.catalog_type.as_deref(), Some("rest" | "rest_rust"))
242 {
243 None
247 } else {
248 let url =
249 url.with_context(|| format!("Invalid warehouse path: {}", warehouse_path))?;
250 let bucket = url
251 .host_str()
252 .with_context(|| {
253 format!("Invalid s3 path: {}, bucket is missing", warehouse_path)
254 })?
255 .to_owned();
256 let root = url.path().trim_start_matches('/').to_owned();
257 Some((url.scheme().to_owned(), bucket, root))
258 }
259 }
260 None => {
261 if matches!(self.catalog_type.as_deref(), Some("rest" | "rest_rust")) {
262 None
263 } else {
264 bail!("`warehouse.path` must be set");
265 }
266 }
267 };
268
269 if let Some((scheme, bucket, root)) = info {
271 match scheme.as_str() {
272 "s3" | "s3a" => {
273 let mut builder = S3::default();
274 if let Some(region) = &self.region {
275 builder = builder.region(region);
276 }
277 if let Some(endpoint) = &self.endpoint {
278 builder = builder.endpoint(endpoint);
279 }
280 if let Some(access_key) = &self.access_key {
281 builder = builder.access_key_id(access_key);
282 }
283 if let Some(secret_key) = &self.secret_key {
284 builder = builder.secret_access_key(secret_key);
285 }
286 builder = builder.root(root.as_str()).bucket(bucket.as_str());
287 let op = Operator::new(builder)?.finish();
288 op.check().await?;
289 }
290 "gs" | "gcs" => {
291 let mut builder = Gcs::default();
292 if let Some(credential) = &self.gcs_credential {
293 builder = builder.credential(credential);
294 }
295 builder = builder.root(root.as_str()).bucket(bucket.as_str());
296 let op = Operator::new(builder)?.finish();
297 op.check().await?;
298 }
299 "azblob" => {
300 let mut builder = Azblob::default();
301 if let Some(account_name) = &self.azblob_account_name {
302 builder = builder.account_name(account_name);
303 }
304 if let Some(azblob_account_key) = &self.azblob_account_key {
305 builder = builder.account_key(azblob_account_key);
306 }
307 if let Some(azblob_endpoint_url) = &self.azblob_endpoint_url {
308 builder = builder.endpoint(azblob_endpoint_url);
309 }
310 builder = builder.root(root.as_str()).container(bucket.as_str());
311 let op = Operator::new(builder)?.finish();
312 op.check().await?;
313 }
314 _ => {
315 bail!("Unsupported scheme: {}", scheme);
316 }
317 }
318 }
319
320 if self.hosted_catalog.unwrap_or(false) {
321 if self.catalog_type.is_some() {
323 bail!("`catalog.type` must not be set when `hosted_catalog` is set");
324 }
325 if self.catalog_uri.is_some() {
326 bail!("`catalog.uri` must not be set when `hosted_catalog` is set");
327 }
328 if self.catalog_name.is_some() {
329 bail!("`catalog.name` must not be set when `hosted_catalog` is set");
330 }
331 if self.jdbc_user.is_some() {
332 bail!("`catalog.jdbc.user` must not be set when `hosted_catalog` is set");
333 }
334 if self.jdbc_password.is_some() {
335 bail!("`catalog.jdbc.password` must not be set when `hosted_catalog` is set");
336 }
337 return Ok(());
338 }
339
340 if self.catalog_type.is_none() {
341 bail!("`catalog.type` must be set");
342 }
343
344 let iceberg_common = IcebergCommon {
346 catalog_type: self.catalog_type.clone(),
347 region: self.region.clone(),
348 endpoint: self.endpoint.clone(),
349 access_key: self.access_key.clone(),
350 secret_key: self.secret_key.clone(),
351 gcs_credential: self.gcs_credential.clone(),
352 azblob_account_name: self.azblob_account_name.clone(),
353 azblob_account_key: self.azblob_account_key.clone(),
354 azblob_endpoint_url: self.azblob_endpoint_url.clone(),
355 warehouse_path: self.warehouse_path.clone(),
356 glue_id: self.glue_id.clone(),
357 catalog_name: self.catalog_name.clone(),
358 catalog_uri: self.catalog_uri.clone(),
359 credential: self.credential.clone(),
360
361 token: self.token.clone(),
362 oauth2_server_uri: self.oauth2_server_uri.clone(),
363 scope: self.scope.clone(),
364 rest_signing_region: self.rest_signing_region.clone(),
365 rest_signing_name: self.rest_signing_name.clone(),
366 rest_sigv4_enabled: self.rest_sigv4_enabled,
367 path_style_access: self.path_style_access,
368 database_name: Some("test_database".to_owned()),
369 table_name: "test_table".to_owned(),
370 enable_config_load: Some(false),
371 hosted_catalog: self.hosted_catalog,
372 };
373
374 let mut java_map = HashMap::new();
375 if let Some(jdbc_user) = &self.jdbc_user {
376 java_map.insert("jdbc.user".to_owned(), jdbc_user.to_owned());
377 }
378 if let Some(jdbc_password) = &self.jdbc_password {
379 java_map.insert("jdbc.password".to_owned(), jdbc_password.to_owned());
380 }
381 let catalog = iceberg_common.create_catalog(&java_map).await?;
382 catalog
384 .table_exists(&iceberg_common.full_table_name()?)
385 .await?;
386 Ok(())
387 }
388}
389
390#[serde_as]
391#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
392#[serde(deny_unknown_fields)]
393pub struct ConfluentSchemaRegistryConnection {
394 #[serde(rename = "schema.registry")]
395 pub url: String,
396 #[serde(rename = "schema.registry.username")]
398 pub username: Option<String>,
399 #[serde(rename = "schema.registry.password")]
400 pub password: Option<String>,
401}
402
403#[async_trait]
404impl Connection for ConfluentSchemaRegistryConnection {
405 async fn validate_connection(&self) -> ConnectorResult<()> {
406 let client = ConfluentSchemaRegistryClient::try_from(self)?;
408 client.validate_connection().await?;
409 Ok(())
410 }
411}
412
413impl EnforceSecret for ConfluentSchemaRegistryConnection {
414 const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
415 "schema.registry.password",
416 };
417}
418
419#[derive(Debug, Clone, Deserialize, PartialEq, Hash, Eq)]
420pub struct ElasticsearchConnection(pub BTreeMap<String, String>);
421
422#[async_trait]
423impl Connection for ElasticsearchConnection {
424 async fn validate_connection(&self) -> ConnectorResult<()> {
425 const CONNECTOR: &str = "elasticsearch";
426
427 let config = ElasticSearchOpenSearchConfig::try_from(self)?;
428 let client = config.build_client(CONNECTOR)?;
429 client.ping().await?;
430 Ok(())
431 }
432}
433
434impl EnforceSecret for ElasticsearchConnection {
435 const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
436 "elasticsearch.password",
437 };
438}