risingwave_connector/connector_common/
connection.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16use std::time::Duration;
17
18use anyhow::Context;
19use opendal::Operator;
20use opendal::services::{Azblob, Gcs, S3};
21use phf::{Set, phf_set};
22use rdkafka::ClientConfig;
23use rdkafka::config::RDKafkaLogLevel;
24use rdkafka::consumer::{BaseConsumer, Consumer};
25use risingwave_common::bail;
26use risingwave_common::secret::LocalSecretManager;
27use risingwave_common::util::env_var::env_var_is_true;
28use risingwave_pb::catalog::PbConnection;
29use serde::Deserialize;
30use serde_with::serde_as;
31use tonic::async_trait;
32use url::Url;
33use with_options::WithOptions;
34
35use crate::connector_common::common::DISABLE_DEFAULT_CREDENTIAL;
36use crate::connector_common::{
37    AwsAuthProps, IcebergCommon, IcebergTableIdentifier, KafkaConnectionProps,
38    KafkaPrivateLinkCommon,
39};
40use crate::enforce_secret::EnforceSecret;
41use crate::error::ConnectorResult;
42use crate::schema::schema_registry::Client as ConfluentSchemaRegistryClient;
43use crate::sink::elasticsearch_opensearch::elasticsearch_opensearch_config::ElasticSearchOpenSearchConfig;
44use crate::source::build_connection;
45use crate::source::kafka::{KafkaContextCommon, RwConsumerContext};
46
/// Connection type string that identifies a schema registry connection.
pub const SCHEMA_REGISTRY_CONNECTION_TYPE: &str = "schema_registry";
48
/// A connector-side connection definition that can check whether it is usable.
///
/// All `XxxConnection` structs should implement this trait as well as the `EnforceSecret` trait.
#[async_trait]
pub trait Connection: Send {
    /// Checks that the connection described by `self` works, returning an error otherwise.
    async fn validate_connection(&self) -> ConnectorResult<()>;
}
54
/// Options of a Kafka connection, deserialized from the connection properties.
///
/// Fields are flattened from three option groups; unknown keys are rejected.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct KafkaConnection {
    // Broker list and security settings used to configure the client.
    #[serde(flatten)]
    pub inner: KafkaConnectionProps,
    // Private-link options; `broker_rewrite_map` is passed to the client context.
    #[serde(flatten)]
    pub kafka_private_link_common: KafkaPrivateLinkCommon,
    // AWS auth options passed to the client context (e.g. for MSK IAM).
    #[serde(flatten)]
    pub aws_auth_props: AwsAuthProps,
}
66
67impl EnforceSecret for KafkaConnection {
68    fn enforce_secret<'a>(prop_iter: impl Iterator<Item = &'a str>) -> ConnectorResult<()> {
69        for prop in prop_iter {
70            KafkaConnectionProps::enforce_one(prop)?;
71            AwsAuthProps::enforce_one(prop)?;
72        }
73        Ok(())
74    }
75}
76
77pub async fn validate_connection(connection: &PbConnection) -> ConnectorResult<()> {
78    if let Some(ref info) = connection.info {
79        match info {
80            risingwave_pb::catalog::connection::Info::ConnectionParams(cp) => {
81                let options = cp.properties.clone().into_iter().collect();
82                let secret_refs = cp.secret_refs.clone().into_iter().collect();
83                let props_secret_resolved =
84                    LocalSecretManager::global().fill_secrets(options, secret_refs)?;
85                let connection = build_connection(cp.connection_type(), props_secret_resolved)?;
86                connection.validate_connection().await?
87            }
88            #[expect(deprecated)]
89            risingwave_pb::catalog::connection::Info::PrivateLinkService(_) => unreachable!(),
90        }
91    }
92    Ok(())
93}
94
95#[async_trait]
96impl Connection for KafkaConnection {
97    async fn validate_connection(&self) -> ConnectorResult<()> {
98        let client = self.build_client().await?;
99        // describe cluster here
100        client.fetch_metadata(None, Duration::from_secs(10)).await?;
101        Ok(())
102    }
103}
104
105pub fn read_kafka_log_level() -> Option<RDKafkaLogLevel> {
106    let log_level = std::env::var("RISINGWAVE_KAFKA_LOG_LEVEL").ok()?;
107    match log_level.to_uppercase().as_str() {
108        "DEBUG" => Some(RDKafkaLogLevel::Debug),
109        "INFO" => Some(RDKafkaLogLevel::Info),
110        "WARN" => Some(RDKafkaLogLevel::Warning),
111        "ERROR" => Some(RDKafkaLogLevel::Error),
112        "CRITICAL" => Some(RDKafkaLogLevel::Critical),
113        "EMERG" => Some(RDKafkaLogLevel::Emerg),
114        "ALERT" => Some(RDKafkaLogLevel::Alert),
115        "NOTICE" => Some(RDKafkaLogLevel::Notice),
116        _ => None,
117    }
118}
119
120impl KafkaConnection {
121    async fn build_client(&self) -> ConnectorResult<BaseConsumer<RwConsumerContext>> {
122        let mut config = ClientConfig::new();
123        let bootstrap_servers = &self.inner.brokers;
124        let broker_rewrite_map = self.kafka_private_link_common.broker_rewrite_map.clone();
125        config.set("bootstrap.servers", bootstrap_servers);
126        self.inner.set_security_properties(&mut config);
127
128        // dup with Kafka Enumerator
129        let ctx_common = KafkaContextCommon::new(
130            broker_rewrite_map,
131            None,
132            None,
133            self.aws_auth_props.clone(),
134            self.inner.is_aws_msk_iam(),
135        )
136        .await?;
137        let client_ctx = RwConsumerContext::new(ctx_common);
138
139        if let Some(log_level) = read_kafka_log_level() {
140            config.set_log_level(log_level);
141        }
142        let client: BaseConsumer<RwConsumerContext> =
143            config.create_with_context(client_ctx).await?;
144        if self.inner.is_aws_msk_iam() {
145            #[cfg(not(madsim))]
146            client.poll(Duration::from_secs(10)); // note: this is a blocking call
147            #[cfg(madsim)]
148            client.poll(Duration::from_secs(10)).await;
149        }
150        Ok(client)
151    }
152}
153
/// Options of an Iceberg connection.
///
/// Combines the shared Iceberg catalog/storage options with JDBC catalog credentials;
/// unknown keys are rejected.
#[serde_as]
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, WithOptions)]
#[serde(deny_unknown_fields)]
pub struct IcebergConnection {
    #[serde(flatten)]
    pub common: IcebergCommon,

    // Forwarded to the catalog as `jdbc.user` when validating the connection.
    #[serde(rename = "catalog.jdbc.user")]
    pub jdbc_user: Option<String>,

    // Forwarded to the catalog as `jdbc.password` when validating the connection.
    #[serde(rename = "catalog.jdbc.password")]
    pub jdbc_password: Option<String>,
}
167
168impl EnforceSecret for IcebergConnection {
169    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
170        "s3.access.key",
171        "s3.secret.key",
172        "gcs.credential",
173        "catalog.token",
174    };
175}
176
177#[async_trait]
178impl Connection for IcebergConnection {
179    async fn validate_connection(&self) -> ConnectorResult<()> {
180        let common = &self.common;
181
182        let info = match &common.warehouse_path {
183            Some(warehouse_path) => {
184                let is_s3_tables = warehouse_path.starts_with("arn:aws:s3tables");
185                let url = Url::parse(warehouse_path);
186                if (url.is_err() || is_s3_tables)
187                    && matches!(common.catalog_type(), "rest" | "rest_rust")
188                {
189                    // If the warehouse path is not a valid URL, it could be a warehouse name in rest catalog,
190                    // Or it could be a s3tables path, which is not a valid URL but a valid warehouse path,
191                    // so we allow it to pass here.
192                    None
193                } else {
194                    let url =
195                        url.with_context(|| format!("Invalid warehouse path: {}", warehouse_path))?;
196                    let bucket = url
197                        .host_str()
198                        .with_context(|| {
199                            format!("Invalid s3 path: {}, bucket is missing", warehouse_path)
200                        })?
201                        .to_owned();
202                    let root = url.path().trim_start_matches('/').to_owned();
203                    Some((url.scheme().to_owned(), bucket, root))
204                }
205            }
206            None => {
207                if matches!(common.catalog_type(), "rest" | "rest_rust") {
208                    None
209                } else {
210                    bail!("`warehouse.path` must be set");
211                }
212            }
213        };
214
215        // Test warehouse
216        if let Some((scheme, bucket, root)) = info {
217            match scheme.as_str() {
218                "s3" | "s3a" => {
219                    let mut builder = S3::default();
220                    if let Some(region) = &common.s3_region {
221                        builder = builder.region(region);
222                    }
223                    if let Some(endpoint) = &common.s3_endpoint {
224                        builder = builder.endpoint(endpoint);
225                    }
226                    if let Some(access_key) = &common.s3_access_key {
227                        builder = builder.access_key_id(access_key);
228                    }
229                    if let Some(secret_key) = &common.s3_secret_key {
230                        builder = builder.secret_access_key(secret_key);
231                    }
232                    builder = builder.root(root.as_str()).bucket(bucket.as_str());
233                    let op = Operator::new(builder)?.finish();
234                    op.check().await?;
235                }
236                "gs" | "gcs" => {
237                    let mut builder = Gcs::default();
238                    if let Some(credential) = &common.gcs_credential {
239                        builder = builder.credential(credential);
240                    }
241                    builder = builder.root(root.as_str()).bucket(bucket.as_str());
242                    let op = Operator::new(builder)?.finish();
243                    op.check().await?;
244                }
245                "azblob" => {
246                    let mut builder = Azblob::default();
247                    if let Some(account_name) = &common.azblob_account_name {
248                        builder = builder.account_name(account_name);
249                    }
250                    if let Some(azblob_account_key) = &common.azblob_account_key {
251                        builder = builder.account_key(azblob_account_key);
252                    }
253                    if let Some(azblob_endpoint_url) = &common.azblob_endpoint_url {
254                        builder = builder.endpoint(azblob_endpoint_url);
255                    }
256                    builder = builder.root(root.as_str()).container(bucket.as_str());
257                    let op = Operator::new(builder)?.finish();
258                    op.check().await?;
259                }
260                _ => {
261                    bail!("Unsupported scheme: {}", scheme);
262                }
263            }
264        }
265
266        if env_var_is_true(DISABLE_DEFAULT_CREDENTIAL)
267            && matches!(common.enable_config_load, Some(true))
268        {
269            bail!("`enable_config_load` can't be enabled in this environment");
270        }
271
272        if common.hosted_catalog.unwrap_or(false) {
273            // If `hosted_catalog` is set, we don't need to test the catalog, but just ensure no catalog fields are set.
274            if common.catalog_type.is_some() {
275                bail!("`catalog.type` must not be set when `hosted_catalog` is set");
276            }
277            if common.catalog_uri.is_some() {
278                bail!("`catalog.uri` must not be set when `hosted_catalog` is set");
279            }
280            if common.catalog_name.is_some() {
281                bail!("`catalog.name` must not be set when `hosted_catalog` is set");
282            }
283            if self.jdbc_user.is_some() {
284                bail!("`catalog.jdbc.user` must not be set when `hosted_catalog` is set");
285            }
286            if self.jdbc_password.is_some() {
287                bail!("`catalog.jdbc.password` must not be set when `hosted_catalog` is set");
288            }
289            return Ok(());
290        }
291
292        if common.catalog_type.is_none() {
293            bail!("`catalog.type` must be set");
294        }
295
296        // Test catalog
297        let iceberg_common = common.clone();
298
299        let mut java_map = HashMap::new();
300        if let Some(jdbc_user) = &self.jdbc_user {
301            java_map.insert("jdbc.user".to_owned(), jdbc_user.to_owned());
302        }
303        if let Some(jdbc_password) = &self.jdbc_password {
304            java_map.insert("jdbc.password".to_owned(), jdbc_password.to_owned());
305        }
306        let catalog = iceberg_common.create_catalog(&java_map).await?;
307        // test catalog by `table_exists` api
308        let test_table_ident = IcebergTableIdentifier {
309            database_name: Some("test_database".to_owned()),
310            table_name: "test_table".to_owned(),
311        }
312        .to_table_ident()?;
313        catalog.table_exists(&test_table_ident).await?;
314        Ok(())
315    }
316}
317
/// Options of a schema registry connection, used to build a Confluent schema registry client.
///
/// Unknown keys are rejected.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
#[serde(deny_unknown_fields)]
pub struct ConfluentSchemaRegistryConnection {
    // Value of the `schema.registry` option (registry address).
    #[serde(rename = "schema.registry")]
    pub url: String,
    // Optional credentials.
    // ref `SchemaRegistryAuth`
    #[serde(rename = "schema.registry.username")]
    pub username: Option<String>,
    #[serde(rename = "schema.registry.password")]
    pub password: Option<String>,
}
330
331#[async_trait]
332impl Connection for ConfluentSchemaRegistryConnection {
333    async fn validate_connection(&self) -> ConnectorResult<()> {
334        // GET /config to validate the connection
335        let client = ConfluentSchemaRegistryClient::try_from(self)?;
336        client.validate_connection().await?;
337        Ok(())
338    }
339}
340
341impl EnforceSecret for ConfluentSchemaRegistryConnection {
342    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
343        "schema.registry.password",
344    };
345}
346
/// Elasticsearch connection options, kept as a raw key/value map.
///
/// The map is converted into an `ElasticSearchOpenSearchConfig` when the connection is
/// validated.
#[derive(Debug, Clone, Deserialize, PartialEq, Hash, Eq)]
pub struct ElasticsearchConnection(pub BTreeMap<String, String>);
349
350#[async_trait]
351impl Connection for ElasticsearchConnection {
352    async fn validate_connection(&self) -> ConnectorResult<()> {
353        const CONNECTOR: &str = "elasticsearch";
354
355        let config = ElasticSearchOpenSearchConfig::try_from(self)?;
356        let client = config.build_client(CONNECTOR)?;
357        client.ping().await?;
358        Ok(())
359    }
360}
361
362impl EnforceSecret for ElasticsearchConnection {
363    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
364        "elasticsearch.password",
365    };
366}