risingwave_connector/connector_common/
connection.rs1use std::collections::{BTreeMap, HashMap};
16use std::time::Duration;
17
18use anyhow::Context;
19use opendal::Operator;
20use opendal::services::{Gcs, S3};
21use rdkafka::ClientConfig;
22use rdkafka::consumer::{BaseConsumer, Consumer};
23use risingwave_common::bail;
24use risingwave_common::secret::LocalSecretManager;
25use risingwave_pb::catalog::PbConnection;
26use serde_derive::Deserialize;
27use serde_with::serde_as;
28use tonic::async_trait;
29use url::Url;
30use with_options::WithOptions;
31
32use crate::connector_common::{
33 AwsAuthProps, IcebergCommon, KafkaConnectionProps, KafkaPrivateLinkCommon,
34};
35use crate::deserialize_optional_bool_from_string;
36use crate::error::ConnectorResult;
37use crate::schema::schema_registry::Client as ConfluentSchemaRegistryClient;
38use crate::sink::elasticsearch_opensearch::elasticsearch_opensearch_config::ElasticSearchOpenSearchConfig;
39use crate::source::build_connection;
40use crate::source::kafka::{KafkaContextCommon, RwConsumerContext};
41
/// The literal identifier `schema_registry`.
///
/// NOTE(review): presumably the connection-type name under which schema registry
/// connections are registered — confirm against the call sites.
pub const SCHEMA_REGISTRY_CONNECTION_TYPE: &str = "schema_registry";
43
44#[async_trait]
45pub trait Connection: Send {
46 async fn validate_connection(&self) -> ConnectorResult<()>;
47}
48
49#[serde_as]
50#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq)]
51#[serde(deny_unknown_fields)]
52pub struct KafkaConnection {
53 #[serde(flatten)]
54 pub inner: KafkaConnectionProps,
55 #[serde(flatten)]
56 pub kafka_private_link_common: KafkaPrivateLinkCommon,
57 #[serde(flatten)]
58 pub aws_auth_props: AwsAuthProps,
59}
60
61pub async fn validate_connection(connection: &PbConnection) -> ConnectorResult<()> {
62 if let Some(ref info) = connection.info {
63 match info {
64 risingwave_pb::catalog::connection::Info::ConnectionParams(cp) => {
65 let options = cp.properties.clone().into_iter().collect();
66 let secret_refs = cp.secret_refs.clone().into_iter().collect();
67 let props_secret_resolved =
68 LocalSecretManager::global().fill_secrets(options, secret_refs)?;
69 let connection = build_connection(cp.connection_type(), props_secret_resolved)?;
70 connection.validate_connection().await?
71 }
72 risingwave_pb::catalog::connection::Info::PrivateLinkService(_) => unreachable!(),
73 }
74 }
75 Ok(())
76}
77
78#[async_trait]
79impl Connection for KafkaConnection {
80 async fn validate_connection(&self) -> ConnectorResult<()> {
81 let client = self.build_client().await?;
82 client.fetch_metadata(None, Duration::from_secs(10)).await?;
84 Ok(())
85 }
86}
87
88impl KafkaConnection {
89 async fn build_client(&self) -> ConnectorResult<BaseConsumer<RwConsumerContext>> {
90 let mut config = ClientConfig::new();
91 let bootstrap_servers = &self.inner.brokers;
92 let broker_rewrite_map = self.kafka_private_link_common.broker_rewrite_map.clone();
93 config.set("bootstrap.servers", bootstrap_servers);
94 self.inner.set_security_properties(&mut config);
95
96 let ctx_common = KafkaContextCommon::new(
98 broker_rewrite_map,
99 None,
100 None,
101 self.aws_auth_props.clone(),
102 self.inner.is_aws_msk_iam(),
103 )
104 .await?;
105 let client_ctx = RwConsumerContext::new(ctx_common);
106 let client: BaseConsumer<RwConsumerContext> =
107 config.create_with_context(client_ctx).await?;
108 if self.inner.is_aws_msk_iam() {
109 #[cfg(not(madsim))]
110 client.poll(Duration::from_secs(10)); #[cfg(madsim)]
112 client.poll(Duration::from_secs(10)).await;
113 }
114 Ok(client)
115 }
116}
117
118#[serde_as]
119#[derive(Debug, Clone, PartialEq, Eq, Deserialize, WithOptions)]
120#[serde(deny_unknown_fields)]
121pub struct IcebergConnection {
122 #[serde(rename = "catalog.type")]
123 pub catalog_type: Option<String>,
124 #[serde(rename = "s3.region")]
125 pub region: Option<String>,
126 #[serde(rename = "s3.endpoint")]
127 pub endpoint: Option<String>,
128 #[serde(rename = "s3.access.key")]
129 pub access_key: Option<String>,
130 #[serde(rename = "s3.secret.key")]
131 pub secret_key: Option<String>,
132
133 #[serde(rename = "gcs.credential")]
134 pub gcs_credential: Option<String>,
135
136 #[serde(rename = "warehouse.path")]
138 pub warehouse_path: Option<String>,
139 #[serde(rename = "glue.id")]
142 pub glue_id: Option<String>,
143 #[serde(rename = "catalog.name")]
145 pub catalog_name: Option<String>,
146 #[serde(rename = "catalog.uri")]
148 pub catalog_uri: Option<String>,
149 #[serde(rename = "catalog.credential")]
152 pub credential: Option<String>,
153 #[serde(rename = "catalog.token")]
156 pub token: Option<String>,
157 #[serde(rename = "catalog.oauth2_server_uri")]
160 pub oauth2_server_uri: Option<String>,
161 #[serde(rename = "catalog.scope")]
164 pub scope: Option<String>,
165
166 #[serde(rename = "catalog.rest.signing_region")]
168 pub rest_signing_region: Option<String>,
169
170 #[serde(rename = "catalog.rest.signing_name")]
172 pub rest_signing_name: Option<String>,
173
174 #[serde(
176 rename = "catalog.rest.sigv4_enabled",
177 default,
178 deserialize_with = "deserialize_optional_bool_from_string"
179 )]
180 pub rest_sigv4_enabled: Option<bool>,
181
182 #[serde(
183 rename = "s3.path.style.access",
184 default,
185 deserialize_with = "deserialize_optional_bool_from_string"
186 )]
187 pub path_style_access: Option<bool>,
188
189 #[serde(rename = "catalog.jdbc.user")]
190 pub jdbc_user: Option<String>,
191
192 #[serde(rename = "catalog.jdbc.password")]
193 pub jdbc_password: Option<String>,
194}
195
196#[async_trait]
197impl Connection for IcebergConnection {
198 async fn validate_connection(&self) -> ConnectorResult<()> {
199 let info = match &self.warehouse_path {
200 Some(warehouse_path) => {
201 let is_s3_tables = warehouse_path.starts_with("arn:aws:s3tables");
202 let url = Url::parse(warehouse_path);
203 if (url.is_err() || is_s3_tables)
204 && matches!(self.catalog_type.as_deref(), Some("rest" | "rest_rust"))
205 {
206 None
210 } else {
211 let url =
212 url.with_context(|| format!("Invalid warehouse path: {}", warehouse_path))?;
213 let bucket = url
214 .host_str()
215 .with_context(|| {
216 format!("Invalid s3 path: {}, bucket is missing", warehouse_path)
217 })?
218 .to_owned();
219 let root = url.path().trim_start_matches('/').to_owned();
220 Some((url.scheme().to_owned(), bucket, root))
221 }
222 }
223 None => {
224 if matches!(self.catalog_type.as_deref(), Some("rest" | "rest_rust")) {
225 None
226 } else {
227 bail!("`warehouse.path` must be set");
228 }
229 }
230 };
231
232 if let Some((scheme, bucket, root)) = info {
234 match scheme.as_str() {
235 "s3" | "s3a" => {
236 let mut builder = S3::default();
237 if let Some(region) = &self.region {
238 builder = builder.region(region);
239 }
240 if let Some(endpoint) = &self.endpoint {
241 builder = builder.endpoint(endpoint);
242 }
243 if let Some(access_key) = &self.access_key {
244 builder = builder.access_key_id(access_key);
245 }
246 if let Some(secret_key) = &self.secret_key {
247 builder = builder.secret_access_key(secret_key);
248 }
249 builder = builder.root(root.as_str()).bucket(bucket.as_str());
250 let op = Operator::new(builder)?.finish();
251 op.check().await?;
252 }
253 "gs" | "gcs" => {
254 let mut builder = Gcs::default();
255 if let Some(credential) = &self.gcs_credential {
256 builder = builder.credential(credential);
257 }
258 builder = builder.root(root.as_str()).bucket(bucket.as_str());
259 let op = Operator::new(builder)?.finish();
260 op.check().await?;
261 }
262 _ => {
263 bail!("Unsupported scheme: {}", scheme);
264 }
265 }
266 }
267
268 if self.catalog_type.is_none() {
269 bail!("`catalog.type` must be set");
270 }
271
272 let iceberg_common = IcebergCommon {
274 catalog_type: self.catalog_type.clone(),
275 region: self.region.clone(),
276 endpoint: self.endpoint.clone(),
277 access_key: self.access_key.clone(),
278 secret_key: self.secret_key.clone(),
279 gcs_credential: self.gcs_credential.clone(),
280 warehouse_path: self.warehouse_path.clone(),
281 glue_id: self.glue_id.clone(),
282 catalog_name: self.catalog_name.clone(),
283 catalog_uri: self.catalog_uri.clone(),
284 credential: self.credential.clone(),
285 token: self.token.clone(),
286 oauth2_server_uri: self.oauth2_server_uri.clone(),
287 scope: self.scope.clone(),
288 rest_signing_region: self.rest_signing_region.clone(),
289 rest_signing_name: self.rest_signing_name.clone(),
290 rest_sigv4_enabled: self.rest_sigv4_enabled,
291 path_style_access: self.path_style_access,
292 database_name: Some("test_database".to_owned()),
293 table_name: "test_table".to_owned(),
294 enable_config_load: Some(false),
295 };
296
297 let mut java_map = HashMap::new();
298 if let Some(jdbc_user) = &self.jdbc_user {
299 java_map.insert("jdbc.user".to_owned(), jdbc_user.to_owned());
300 }
301 if let Some(jdbc_password) = &self.jdbc_password {
302 java_map.insert("jdbc.password".to_owned(), jdbc_password.to_owned());
303 }
304 let catalog = iceberg_common.create_catalog(&java_map).await?;
305 catalog
307 .table_exists(&iceberg_common.full_table_name()?)
308 .await?;
309 Ok(())
310 }
311}
312
313#[serde_as]
314#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
315#[serde(deny_unknown_fields)]
316pub struct ConfluentSchemaRegistryConnection {
317 #[serde(rename = "schema.registry")]
318 pub url: String,
319 #[serde(rename = "schema.registry.username")]
321 pub username: Option<String>,
322 #[serde(rename = "schema.registry.password")]
323 pub password: Option<String>,
324}
325
326#[async_trait]
327impl Connection for ConfluentSchemaRegistryConnection {
328 async fn validate_connection(&self) -> ConnectorResult<()> {
329 let client = ConfluentSchemaRegistryClient::try_from(self)?;
331 client.validate_connection().await?;
332 Ok(())
333 }
334}
335
336#[derive(Debug, Clone, Deserialize, PartialEq, Hash, Eq)]
337pub struct ElasticsearchConnection(pub BTreeMap<String, String>);
338
339#[async_trait]
340impl Connection for ElasticsearchConnection {
341 async fn validate_connection(&self) -> ConnectorResult<()> {
342 const CONNECTOR: &str = "elasticsearch";
343
344 let config = ElasticSearchOpenSearchConfig::try_from(self)?;
345 let client = config.build_client(CONNECTOR)?;
346 client.ping().await?;
347 Ok(())
348 }
349}