// risingwave_connector/connector_common/connection.rs

use std::collections::{BTreeMap, HashMap};
use std::time::Duration;

use anyhow::Context;
use opendal::Operator;
use opendal::services::{Azblob, Gcs, S3};
use phf::{Set, phf_set};
use rdkafka::ClientConfig;
use rdkafka::config::RDKafkaLogLevel;
use rdkafka::consumer::{BaseConsumer, Consumer};
use risingwave_common::bail;
use risingwave_common::secret::LocalSecretManager;
use risingwave_common::util::env_var::env_var_is_true;
use risingwave_pb::catalog::PbConnection;
use serde_derive::Deserialize;
use serde_with::serde_as;
use tonic::async_trait;
use url::Url;
use with_options::WithOptions;

use crate::connector_common::common::DISABLE_DEFAULT_CREDENTIAL;
use crate::connector_common::{
    AwsAuthProps, IcebergCommon, KafkaConnectionProps, KafkaPrivateLinkCommon,
};
use crate::deserialize_optional_bool_from_string;
use crate::enforce_secret::EnforceSecret;
use crate::error::ConnectorResult;
use crate::schema::schema_registry::Client as ConfluentSchemaRegistryClient;
use crate::sink::elasticsearch_opensearch::elasticsearch_opensearch_config::ElasticSearchOpenSearchConfig;
use crate::source::build_connection;
use crate::source::kafka::{KafkaContextCommon, RwConsumerContext};

pub const SCHEMA_REGISTRY_CONNECTION_TYPE: &str = "schema_registry";
48
/// A validatable external connection (Kafka, Iceberg, schema registry, ...).
///
/// Implementors perform a best-effort reachability/credentials check against
/// the remote system; `Ok(())` means the configured properties are usable.
#[async_trait]
pub trait Connection: Send {
    /// Checks that the remote system is reachable with the stored properties.
    async fn validate_connection(&self) -> ConnectorResult<()>;
}
54
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq)]
#[serde(deny_unknown_fields)]
// Properties of a Kafka connection: core broker/security options, optional
// private-link broker rewrites, and AWS auth config (used for MSK IAM).
// All three groups are flattened into one flat key/value map; unknown keys
// are rejected by `deny_unknown_fields`.
pub struct KafkaConnection {
    // Core Kafka properties (brokers, security protocol, SASL, ...).
    #[serde(flatten)]
    pub inner: KafkaConnectionProps,
    // Broker address rewrite map for private-link setups.
    #[serde(flatten)]
    pub kafka_private_link_common: KafkaPrivateLinkCommon,
    // AWS credential/role configuration, consumed by `KafkaContextCommon`.
    #[serde(flatten)]
    pub aws_auth_props: AwsAuthProps,
}
66
67impl EnforceSecret for KafkaConnection {
68 fn enforce_secret<'a>(prop_iter: impl Iterator<Item = &'a str>) -> ConnectorResult<()> {
69 for prop in prop_iter {
70 KafkaConnectionProps::enforce_one(prop)?;
71 AwsAuthProps::enforce_one(prop)?;
72 }
73 Ok(())
74 }
75}
76
77pub async fn validate_connection(connection: &PbConnection) -> ConnectorResult<()> {
78 if let Some(ref info) = connection.info {
79 match info {
80 risingwave_pb::catalog::connection::Info::ConnectionParams(cp) => {
81 let options = cp.properties.clone().into_iter().collect();
82 let secret_refs = cp.secret_refs.clone().into_iter().collect();
83 let props_secret_resolved =
84 LocalSecretManager::global().fill_secrets(options, secret_refs)?;
85 let connection = build_connection(cp.connection_type(), props_secret_resolved)?;
86 connection.validate_connection().await?
87 }
88 risingwave_pb::catalog::connection::Info::PrivateLinkService(_) => unreachable!(),
89 }
90 }
91 Ok(())
92}
93
#[async_trait]
impl Connection for KafkaConnection {
    /// Validates the Kafka connection by building a consumer and fetching
    /// cluster metadata with a 10-second timeout. A successful metadata call
    /// implies the brokers are reachable and authentication succeeded.
    async fn validate_connection(&self) -> ConnectorResult<()> {
        let client = self.build_client().await?;
        client.fetch_metadata(None, Duration::from_secs(10)).await?;
        Ok(())
    }
}
103
104pub fn read_kafka_log_level() -> Option<RDKafkaLogLevel> {
105 let log_level = std::env::var("RISINGWAVE_KAFKA_LOG_LEVEL").ok()?;
106 match log_level.to_uppercase().as_str() {
107 "DEBUG" => Some(RDKafkaLogLevel::Debug),
108 "INFO" => Some(RDKafkaLogLevel::Info),
109 "WARN" => Some(RDKafkaLogLevel::Warning),
110 "ERROR" => Some(RDKafkaLogLevel::Error),
111 "CRITICAL" => Some(RDKafkaLogLevel::Critical),
112 "EMERG" => Some(RDKafkaLogLevel::Emerg),
113 "ALERT" => Some(RDKafkaLogLevel::Alert),
114 "NOTICE" => Some(RDKafkaLogLevel::Notice),
115 _ => None,
116 }
117}
118
119impl KafkaConnection {
120 async fn build_client(&self) -> ConnectorResult<BaseConsumer<RwConsumerContext>> {
121 let mut config = ClientConfig::new();
122 let bootstrap_servers = &self.inner.brokers;
123 let broker_rewrite_map = self.kafka_private_link_common.broker_rewrite_map.clone();
124 config.set("bootstrap.servers", bootstrap_servers);
125 self.inner.set_security_properties(&mut config);
126
127 let ctx_common = KafkaContextCommon::new(
129 broker_rewrite_map,
130 None,
131 None,
132 self.aws_auth_props.clone(),
133 self.inner.is_aws_msk_iam(),
134 )
135 .await?;
136 let client_ctx = RwConsumerContext::new(ctx_common);
137
138 if let Some(log_level) = read_kafka_log_level() {
139 config.set_log_level(log_level);
140 }
141 let client: BaseConsumer<RwConsumerContext> =
142 config.create_with_context(client_ctx).await?;
143 if self.inner.is_aws_msk_iam() {
144 #[cfg(not(madsim))]
145 client.poll(Duration::from_secs(10)); #[cfg(madsim)]
147 client.poll(Duration::from_secs(10)).await;
148 }
149 Ok(client)
150 }
151}
152
#[serde_as]
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, WithOptions)]
#[serde(deny_unknown_fields)]
// Flat property map for an Iceberg connection: object-store credentials
// (S3 / GCS / Azure Blob) plus catalog configuration. Unknown keys are
// rejected by `deny_unknown_fields`.
pub struct IcebergConnection {
    // Catalog flavor, e.g. "rest", "rest_rust", "glue", "jdbc"; must be unset
    // when `hosted_catalog` is enabled (see `validate_connection`).
    #[serde(rename = "catalog.type")]
    pub catalog_type: Option<String>,
    // --- S3 warehouse credentials ---
    #[serde(rename = "s3.region")]
    pub region: Option<String>,
    #[serde(rename = "s3.endpoint")]
    pub endpoint: Option<String>,
    #[serde(rename = "s3.access.key")]
    pub access_key: Option<String>,
    #[serde(rename = "s3.secret.key")]
    pub secret_key: Option<String>,

    // --- GCS warehouse credential ---
    #[serde(rename = "gcs.credential")]
    pub gcs_credential: Option<String>,

    // --- Azure Blob warehouse credentials ---
    #[serde(rename = "azblob.account_name")]
    pub azblob_account_name: Option<String>,
    #[serde(rename = "azblob.account_key")]
    pub azblob_account_key: Option<String>,
    #[serde(rename = "azblob.endpoint_url")]
    pub azblob_endpoint_url: Option<String>,

    // Warehouse location: either an object-store URL (s3://, gs://, azblob://)
    // or an AWS S3 Tables ARN; optional for REST catalogs.
    #[serde(rename = "warehouse.path")]
    pub warehouse_path: Option<String>,
    // AWS Glue catalog id.
    #[serde(rename = "glue.id")]
    pub glue_id: Option<String>,
    #[serde(rename = "catalog.name")]
    pub catalog_name: Option<String>,
    #[serde(rename = "catalog.uri")]
    pub catalog_uri: Option<String>,
    #[serde(rename = "catalog.credential")]
    pub credential: Option<String>,
    #[serde(rename = "catalog.token")]
    pub token: Option<String>,
    #[serde(rename = "catalog.oauth2_server_uri")]
    pub oauth2_server_uri: Option<String>,
    #[serde(rename = "catalog.scope")]
    pub scope: Option<String>,

    // --- REST catalog SigV4 signing options ---
    #[serde(rename = "catalog.rest.signing_region")]
    pub rest_signing_region: Option<String>,

    #[serde(rename = "catalog.rest.signing_name")]
    pub rest_signing_name: Option<String>,

    #[serde(
        rename = "catalog.rest.sigv4_enabled",
        default,
        deserialize_with = "deserialize_optional_bool_from_string"
    )]
    pub rest_sigv4_enabled: Option<bool>,

    #[serde(
        rename = "s3.path.style.access",
        default,
        deserialize_with = "deserialize_optional_bool_from_string"
    )]
    pub path_style_access: Option<bool>,

    // --- JDBC catalog credentials (passed through to the Java catalog) ---
    #[serde(rename = "catalog.jdbc.user")]
    pub jdbc_user: Option<String>,

    #[serde(rename = "catalog.jdbc.password")]
    pub jdbc_password: Option<String>,

    // Whether to load config/credentials from the environment; rejected when
    // DISABLE_DEFAULT_CREDENTIAL is set (see `validate_connection`).
    #[serde(default, deserialize_with = "deserialize_optional_bool_from_string")]
    pub enable_config_load: Option<bool>,

    // Use RisingWave's built-in hosted catalog; mutually exclusive with the
    // `catalog.*` options above.
    #[serde(
        rename = "hosted_catalog",
        default,
        deserialize_with = "deserialize_optional_bool_from_string"
    )]
    pub hosted_catalog: Option<bool>,

    // Extra header(s) forwarded to the catalog.
    #[serde(rename = "catalog.header")]
    pub header: Option<String>,
}
258
impl EnforceSecret for IcebergConnection {
    // Properties that must be supplied via secret references, not plaintext.
    // NOTE(review): `catalog.jdbc.password` and `azblob.account_key` are not
    // listed here — confirm whether they should also be enforced.
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "s3.access.key",
        "s3.secret.key",
        "gcs.credential",
        "catalog.token",
    };
}
267
#[async_trait]
impl Connection for IcebergConnection {
    /// Validates the Iceberg connection in three stages:
    /// 1. if a warehouse path is set, checks the backing object store is
    ///    reachable (S3 / GCS / Azure Blob via OpenDAL);
    /// 2. enforces environment and hosted-catalog option constraints;
    /// 3. otherwise, builds the configured catalog and issues a probe query.
    async fn validate_connection(&self) -> ConnectorResult<()> {
        // Parse `warehouse.path` into (scheme, bucket, root) for the object
        // store check. REST catalogs may omit the path (or use an S3 Tables
        // ARN, which is not a URL), in which case the check is skipped.
        let info = match &self.warehouse_path {
            Some(warehouse_path) => {
                let is_s3_tables = warehouse_path.starts_with("arn:aws:s3tables");
                let url = Url::parse(warehouse_path);
                if (url.is_err() || is_s3_tables)
                    && matches!(self.catalog_type.as_deref(), Some("rest" | "rest_rust"))
                {
                    // Non-URL path is tolerated only for REST catalogs.
                    None
                } else {
                    let url =
                        url.with_context(|| format!("Invalid warehouse path: {}", warehouse_path))?;
                    let bucket = url
                        .host_str()
                        .with_context(|| {
                            format!("Invalid s3 path: {}, bucket is missing", warehouse_path)
                        })?
                        .to_owned();
                    let root = url.path().trim_start_matches('/').to_owned();
                    Some((url.scheme().to_owned(), bucket, root))
                }
            }
            None => {
                if matches!(self.catalog_type.as_deref(), Some("rest" | "rest_rust")) {
                    None
                } else {
                    bail!("`warehouse.path` must be set");
                }
            }
        };

        // Stage 1: probe the object store with the configured credentials.
        if let Some((scheme, bucket, root)) = info {
            match scheme.as_str() {
                "s3" | "s3a" => {
                    let mut builder = S3::default();
                    if let Some(region) = &self.region {
                        builder = builder.region(region);
                    }
                    if let Some(endpoint) = &self.endpoint {
                        builder = builder.endpoint(endpoint);
                    }
                    if let Some(access_key) = &self.access_key {
                        builder = builder.access_key_id(access_key);
                    }
                    if let Some(secret_key) = &self.secret_key {
                        builder = builder.secret_access_key(secret_key);
                    }
                    builder = builder.root(root.as_str()).bucket(bucket.as_str());
                    let op = Operator::new(builder)?.finish();
                    op.check().await?;
                }
                "gs" | "gcs" => {
                    let mut builder = Gcs::default();
                    if let Some(credential) = &self.gcs_credential {
                        builder = builder.credential(credential);
                    }
                    builder = builder.root(root.as_str()).bucket(bucket.as_str());
                    let op = Operator::new(builder)?.finish();
                    op.check().await?;
                }
                "azblob" => {
                    let mut builder = Azblob::default();
                    if let Some(account_name) = &self.azblob_account_name {
                        builder = builder.account_name(account_name);
                    }
                    if let Some(azblob_account_key) = &self.azblob_account_key {
                        builder = builder.account_key(azblob_account_key);
                    }
                    if let Some(azblob_endpoint_url) = &self.azblob_endpoint_url {
                        builder = builder.endpoint(azblob_endpoint_url);
                    }
                    builder = builder.root(root.as_str()).container(bucket.as_str());
                    let op = Operator::new(builder)?.finish();
                    op.check().await?;
                }
                _ => {
                    bail!("Unsupported scheme: {}", scheme);
                }
            }
        }

        // Stage 2: environment constraint — config load from the environment
        // is forbidden when default credentials are disabled.
        if env_var_is_true(DISABLE_DEFAULT_CREDENTIAL)
            && matches!(self.enable_config_load, Some(true))
        {
            bail!("`enable_config_load` can't be enabled in this environment");
        }

        // Hosted catalog is mutually exclusive with user-supplied catalog
        // options; when enabled, no further catalog probing is needed.
        if self.hosted_catalog.unwrap_or(false) {
            if self.catalog_type.is_some() {
                bail!("`catalog.type` must not be set when `hosted_catalog` is set");
            }
            if self.catalog_uri.is_some() {
                bail!("`catalog.uri` must not be set when `hosted_catalog` is set");
            }
            if self.catalog_name.is_some() {
                bail!("`catalog.name` must not be set when `hosted_catalog` is set");
            }
            if self.jdbc_user.is_some() {
                bail!("`catalog.jdbc.user` must not be set when `hosted_catalog` is set");
            }
            if self.jdbc_password.is_some() {
                bail!("`catalog.jdbc.password` must not be set when `hosted_catalog` is set");
            }
            return Ok(());
        }

        if self.catalog_type.is_none() {
            bail!("`catalog.type` must be set");
        }

        // Stage 3: build an `IcebergCommon` mirroring this connection and
        // probe the catalog. The database/table names are placeholders used
        // only to exercise catalog access.
        let iceberg_common = IcebergCommon {
            catalog_type: self.catalog_type.clone(),
            region: self.region.clone(),
            endpoint: self.endpoint.clone(),
            access_key: self.access_key.clone(),
            secret_key: self.secret_key.clone(),
            gcs_credential: self.gcs_credential.clone(),
            azblob_account_name: self.azblob_account_name.clone(),
            azblob_account_key: self.azblob_account_key.clone(),
            azblob_endpoint_url: self.azblob_endpoint_url.clone(),
            warehouse_path: self.warehouse_path.clone(),
            glue_id: self.glue_id.clone(),
            catalog_name: self.catalog_name.clone(),
            catalog_uri: self.catalog_uri.clone(),
            credential: self.credential.clone(),

            token: self.token.clone(),
            oauth2_server_uri: self.oauth2_server_uri.clone(),
            scope: self.scope.clone(),
            rest_signing_region: self.rest_signing_region.clone(),
            rest_signing_name: self.rest_signing_name.clone(),
            rest_sigv4_enabled: self.rest_sigv4_enabled,
            path_style_access: self.path_style_access,
            database_name: Some("test_database".to_owned()),
            table_name: "test_table".to_owned(),
            enable_config_load: self.enable_config_load,
            hosted_catalog: self.hosted_catalog,
            header: self.header.clone(),
        };

        // JDBC credentials are passed through to the (Java) catalog impl.
        let mut java_map = HashMap::new();
        if let Some(jdbc_user) = &self.jdbc_user {
            java_map.insert("jdbc.user".to_owned(), jdbc_user.to_owned());
        }
        if let Some(jdbc_password) = &self.jdbc_password {
            java_map.insert("jdbc.password".to_owned(), jdbc_password.to_owned());
        }
        let catalog = iceberg_common.create_catalog(&java_map).await?;
        // Probe query: only errors matter here — the boolean result of
        // `table_exists` is deliberately discarded.
        catalog
            .table_exists(&iceberg_common.full_table_name()?)
            .await?;
        Ok(())
    }
}
431
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
#[serde(deny_unknown_fields)]
// Connection properties for a Confluent-compatible schema registry:
// endpoint URL plus optional basic-auth credentials.
pub struct ConfluentSchemaRegistryConnection {
    #[serde(rename = "schema.registry")]
    pub url: String,
    #[serde(rename = "schema.registry.username")]
    pub username: Option<String>,
    #[serde(rename = "schema.registry.password")]
    pub password: Option<String>,
}
444
445#[async_trait]
446impl Connection for ConfluentSchemaRegistryConnection {
447 async fn validate_connection(&self) -> ConnectorResult<()> {
448 let client = ConfluentSchemaRegistryClient::try_from(self)?;
450 client.validate_connection().await?;
451 Ok(())
452 }
453}
454
impl EnforceSecret for ConfluentSchemaRegistryConnection {
    // The registry password must be supplied via a secret reference.
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "schema.registry.password",
    };
}
460
/// Untyped Elasticsearch connection: a raw property map, interpreted by
/// converting it into an `ElasticSearchOpenSearchConfig` during validation.
#[derive(Debug, Clone, Deserialize, PartialEq, Hash, Eq)]
pub struct ElasticsearchConnection(pub BTreeMap<String, String>);
463
464#[async_trait]
465impl Connection for ElasticsearchConnection {
466 async fn validate_connection(&self) -> ConnectorResult<()> {
467 const CONNECTOR: &str = "elasticsearch";
468
469 let config = ElasticSearchOpenSearchConfig::try_from(self)?;
470 let client = config.build_client(CONNECTOR)?;
471 client.ping().await?;
472 Ok(())
473 }
474}
475
impl EnforceSecret for ElasticsearchConnection {
    // The Elasticsearch password must be supplied via a secret reference.
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "elasticsearch.password",
    };
}