Skip to main content

risingwave_connector/connector_common/
common.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::hash::Hash;
17use std::io::Write;
18use std::path::Path;
19use std::sync::{Arc, LazyLock, Weak};
20use std::time::Duration;
21
22use anyhow::{Context, anyhow};
23use async_nats::jetstream::consumer::DeliverPolicy;
24use async_nats::jetstream::{self};
25use aws_sdk_kinesis::Client as KinesisClient;
26use aws_sdk_kinesis::config::{AsyncSleep, SharedAsyncSleep, Sleep};
27use moka::future::Cache as MokaCache;
28use moka::ops::compute::Op;
29use phf::{Set, phf_set};
30use pulsar::authentication::oauth2::{OAuth2Authentication, OAuth2Params};
31use pulsar::{Authentication, OperationRetryOptions, Pulsar, TokioExecutor};
32use rdkafka::ClientConfig;
33use risingwave_common::bail;
34use rustls_pki_types::pem::PemObject;
35use rustls_pki_types::{CertificateDer, PrivatePkcs8KeyDer};
36use serde::Deserialize;
37use serde_with::json::JsonString;
38use serde_with::{DisplayFromStr, serde_as};
39use tempfile::NamedTempFile;
40use time::OffsetDateTime;
41use url::Url;
42use with_options::WithOptions;
43
44use crate::aws_utils::load_file_descriptor_from_s3;
45use crate::deserialize_duration_from_string;
46use crate::enforce_secret::EnforceSecret;
47use crate::error::ConnectorResult;
48use crate::sink::SinkError;
49use crate::source::nats::source::NatsOffset;
50
/// Property key holding the Kafka broker address rewrite map used for private links.
/// It is deserialized into [`KafkaPrivateLinkCommon::broker_rewrite_map`] from a JSON string.
pub const PRIVATE_LINK_BROKER_REWRITE_MAP_KEY: &str = "broker.rewrite.endpoints";
/// Property key for the private link targets.
// NOTE(review): presumably consumed by the frontend when it generates the rewrite map
// (see the `broker_rewrite_map` field docs) — confirm.
pub const PRIVATE_LINK_TARGETS_KEY: &str = "privatelink.targets";

/// Value of `properties.sasl.mechanism` that selects AWS MSK IAM authentication.
/// When it matches, `KafkaConnectionProps::set_security_properties` forces
/// `SASL_SSL` + `OAUTHBEARER` and ignores the other security options.
const AWS_MSK_IAM_AUTH: &str = "AWS_MSK_IAM";

/// The environment variable to disable using default credential from environment.
/// When it is `true`, `AwsAuthProps` no longer falls back to the SDK default credential chain
/// and instead requires both an access key and a secret key.
/// It's recommended to set this variable to `true` in cloud hosting environment.
pub const DISABLE_DEFAULT_CREDENTIAL: &str = "DISABLE_DEFAULT_CREDENTIAL";
59
/// One AWS PrivateLink target entry.
// NOTE(review): presumably an element of the `privatelink.targets` list; verify against the
// code that deserializes it.
#[derive(Debug, Clone, Deserialize)]
pub struct AwsPrivateLinkItem {
    /// Optional availability zone ID of the target.
    pub az_id: Option<String>,
    /// Target port.
    pub port: u16,
}
65
66use aws_config::default_provider::region::DefaultRegionChain;
67use aws_config::sts::AssumeRoleProvider;
68use aws_credential_types::provider::SharedCredentialsProvider;
69use aws_types::SdkConfig;
70use aws_types::region::Region;
71use risingwave_common::util::env_var::env_var_is_true;
72
/// A flatten config map for aws auth.
//
// How the fields are consumed by `impl AwsAuthProps` in this file:
// - region: `region` if set, otherwise the SDK default region chain (scoped to `profile`);
// - credentials: static `access_key` + `secret_key` (+ optional `session_token`) when both keys
//   are set, otherwise the SDK default credential chain unless the
//   `DISABLE_DEFAULT_CREDENTIAL` env var is true (then an error);
// - when `arn` is set, those credentials are used to assume that role (with `external_id`);
// - `endpoint`, when set, is passed to the SDK config loader as `endpoint_url`.
#[derive(Deserialize, Debug, Clone, WithOptions, PartialEq)]
pub struct AwsAuthProps {
    #[serde(rename = "aws.region", alias = "region", alias = "s3.region")]
    pub region: Option<String>,

    #[serde(
        rename = "aws.endpoint_url",
        alias = "endpoint_url",
        alias = "endpoint",
        alias = "s3.endpoint"
    )]
    pub endpoint: Option<String>,
    #[serde(
        rename = "aws.credentials.access_key_id",
        alias = "access_key",
        alias = "s3.access.key"
    )]
    pub access_key: Option<String>,
    #[serde(
        rename = "aws.credentials.secret_access_key",
        alias = "secret_key",
        alias = "s3.secret.key"
    )]
    pub secret_key: Option<String>,
    // Only used together with `access_key` + `secret_key`.
    #[serde(rename = "aws.credentials.session_token", alias = "session_token")]
    pub session_token: Option<String>,
    /// IAM role
    #[serde(rename = "aws.credentials.role.arn", alias = "arn")]
    pub arn: Option<String>,
    /// external ID in IAM role trust policy
    #[serde(rename = "aws.credentials.role.external_id", alias = "external_id")]
    pub external_id: Option<String>,
    // Only consulted when `region` is unset, to select the profile for the region chain.
    #[serde(rename = "aws.profile", alias = "profile")]
    pub profile: Option<String>,
    // NOTE(review): not read by the auth-building code in this file (Kinesis sets it to the
    // default); presumably a timeout for the MSK IAM token signer — confirm where it is used.
    #[serde(rename = "aws.msk.signer_timeout_sec")]
    pub msk_signer_timeout_sec: Option<u64>,
}
111
112impl EnforceSecret for AwsAuthProps {
113    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
114        "access_key",
115        "aws.credentials.access_key_id",
116        "s3.access.key",
117        "secret_key",
118        "aws.credentials.secret_access_key",
119        "s3.secret.key",
120        "session_token",
121        "aws.credentials.session_token",
122    };
123}
124
125impl AwsAuthProps {
126    async fn build_region(&self) -> ConnectorResult<Region> {
127        if let Some(region_name) = &self.region {
128            Ok(Region::new(region_name.clone()))
129        } else {
130            let mut region_chain = DefaultRegionChain::builder();
131            if let Some(profile_name) = &self.profile {
132                region_chain = region_chain.profile_name(profile_name);
133            }
134
135            Ok(region_chain
136                .build()
137                .region()
138                .await
139                .context("region should be provided")?)
140        }
141    }
142
143    async fn build_credential_provider(&self) -> ConnectorResult<SharedCredentialsProvider> {
144        if let (Some(access_key), Some(secret_key)) =
145            (self.access_key.as_ref(), self.secret_key.as_ref())
146        {
147            Ok(SharedCredentialsProvider::new(
148                aws_credential_types::Credentials::from_keys(
149                    access_key,
150                    secret_key,
151                    self.session_token.clone(),
152                ),
153            ))
154        } else if !env_var_is_true(DISABLE_DEFAULT_CREDENTIAL) {
155            Ok(SharedCredentialsProvider::new(
156                aws_config::default_provider::credentials::default_provider().await,
157            ))
158        } else {
159            bail!("Both \"access_key\" and \"secret_key\" are required.")
160        }
161    }
162
163    async fn with_role_provider(
164        &self,
165        credential: SharedCredentialsProvider,
166    ) -> ConnectorResult<SharedCredentialsProvider> {
167        if let Some(role_name) = &self.arn {
168            let region = self.build_region().await?;
169            let mut role = AssumeRoleProvider::builder(role_name)
170                .session_name("RisingWave")
171                .region(region);
172            if let Some(id) = &self.external_id {
173                role = role.external_id(id);
174            }
175            let provider = role.build_from_provider(credential).await;
176            Ok(SharedCredentialsProvider::new(provider))
177        } else {
178            Ok(credential)
179        }
180    }
181
182    pub async fn build_config(&self) -> ConnectorResult<SdkConfig> {
183        let region = self.build_region().await?;
184        let credentials_provider = self
185            .with_role_provider(self.build_credential_provider().await?)
186            .await?;
187        let mut config_loader = aws_config::from_env()
188            .region(region)
189            .credentials_provider(credentials_provider);
190
191        if let Some(endpoint) = self.endpoint.as_ref() {
192            config_loader = config_loader.endpoint_url(endpoint);
193        }
194
195        Ok(config_loader.load().await)
196    }
197}
198
/// Kafka connection properties: the broker list plus SSL/SASL security settings.
///
/// The security settings are applied to an rdkafka `ClientConfig` by
/// `KafkaConnectionProps::set_security_properties`.
// NOTE(review): the derived `Debug` prints secrets such as `ssl_key_pem`, `ssl_key_password`
// and `sasl_password` verbatim; avoid logging this type with `{:?}`.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
pub struct KafkaConnectionProps {
    #[serde(rename = "properties.bootstrap.server", alias = "kafka.brokers")]
    pub brokers: String,

    /// Security protocol used for RisingWave to communicate with Kafka brokers. Could be
    /// PLAINTEXT, SSL, `SASL_PLAINTEXT` or `SASL_SSL`.
    #[serde(rename = "properties.security.protocol")]
    #[with_option(allow_alter_on_fly)]
    security_protocol: Option<String>,

    // Passed through unvalidated; librdkafka accepts `none` or `https`.
    #[serde(rename = "properties.ssl.endpoint.identification.algorithm")]
    #[with_option(allow_alter_on_fly)]
    ssl_endpoint_identification_algorithm: Option<String>,

    // For the properties below, please refer to [librdkafka](https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md) for more information.
    /// Path to CA certificate file for verifying the broker's key.
    #[serde(rename = "properties.ssl.ca.location")]
    #[with_option(allow_alter_on_fly)]
    ssl_ca_location: Option<String>,

    /// CA certificate string (PEM format) for verifying the broker's key.
    #[serde(rename = "properties.ssl.ca.pem")]
    #[with_option(allow_alter_on_fly)]
    ssl_ca_pem: Option<String>,

    /// Path to client's certificate file (PEM).
    #[serde(rename = "properties.ssl.certificate.location")]
    #[with_option(allow_alter_on_fly)]
    ssl_certificate_location: Option<String>,

    /// Client's public key string (PEM format) used for authentication.
    #[serde(rename = "properties.ssl.certificate.pem")]
    #[with_option(allow_alter_on_fly)]
    ssl_certificate_pem: Option<String>,

    /// Path to client's private key file (PEM).
    #[serde(rename = "properties.ssl.key.location")]
    #[with_option(allow_alter_on_fly)]
    ssl_key_location: Option<String>,

    /// Client's private key string (PEM format) used for authentication.
    #[serde(rename = "properties.ssl.key.pem")]
    #[with_option(allow_alter_on_fly)]
    ssl_key_pem: Option<String>,

    /// Passphrase of client's private key.
    #[serde(rename = "properties.ssl.key.password")]
    #[with_option(allow_alter_on_fly)]
    ssl_key_password: Option<String>,

    // The value `AWS_MSK_IAM` switches `set_security_properties` to MSK IAM mode.
    /// SASL mechanism if SASL is enabled. Currently support PLAIN, SCRAM, GSSAPI, and `AWS_MSK_IAM`.
    #[serde(rename = "properties.sasl.mechanism")]
    #[with_option(allow_alter_on_fly)]
    sasl_mechanism: Option<String>,

    /// SASL username for SASL/PLAIN and SASL/SCRAM.
    #[serde(rename = "properties.sasl.username")]
    #[with_option(allow_alter_on_fly)]
    sasl_username: Option<String>,

    /// SASL password for SASL/PLAIN and SASL/SCRAM.
    #[serde(rename = "properties.sasl.password")]
    #[with_option(allow_alter_on_fly)]
    sasl_password: Option<String>,

    // Unlike the fields above, the SASL/GSSAPI and SASL/OAUTHBEARER fields below are not marked
    // `allow_alter_on_fly`.
    /// Kafka server's Kerberos principal name under SASL/GSSAPI, not including /hostname@REALM.
    #[serde(rename = "properties.sasl.kerberos.service.name")]
    sasl_kerberos_service_name: Option<String>,

    /// Path to client's Kerberos keytab file under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.keytab")]
    sasl_kerberos_keytab: Option<String>,

    /// Client's Kerberos principal name under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.principal")]
    sasl_kerberos_principal: Option<String>,

    /// Shell command to refresh or acquire the client's Kerberos ticket under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.kinit.cmd")]
    sasl_kerberos_kinit_cmd: Option<String>,

    /// Minimum time in milliseconds between key refresh attempts under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.min.time.before.relogin")]
    sasl_kerberos_min_time_before_relogin: Option<String>,

    // (sic) The field name misspells "oauthbearer"; it is kept because other code (e.g.
    // `test_default`) refers to it by this name.
    /// Configurations for SASL/OAUTHBEARER.
    #[serde(rename = "properties.sasl.oauthbearer.config")]
    sasl_oathbearer_config: Option<String>,
}
290
291impl EnforceSecret for KafkaConnectionProps {
292    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
293        "properties.ssl.key.pem",
294        "properties.ssl.key.password",
295        "properties.sasl.password",
296    };
297}
298
/// Kafka properties that do not describe the connection itself: the topic and
/// `properties.sync.call.timeout`.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct KafkaCommon {
    // connection related props are moved to `KafkaConnection`
    // (the broker and security fields live in `KafkaConnectionProps`).
    #[serde(rename = "topic", alias = "kafka.topic")]
    pub topic: String,

    // Parsed with `deserialize_duration_from_string`; defaults to 5 seconds
    // (`default_kafka_sync_call_timeout`).
    #[serde(
        rename = "properties.sync.call.timeout",
        deserialize_with = "deserialize_duration_from_string",
        default = "default_kafka_sync_call_timeout"
    )]
    #[with_option(allow_alter_on_fly)]
    pub sync_call_timeout: Duration,
}
314
/// Kafka private-link settings.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
pub struct KafkaPrivateLinkCommon {
    /// This is generated from `private_link_targets` and `private_link_endpoint` in frontend, instead of given by users.
    #[serde(rename = "broker.rewrite.endpoints")]
    #[serde_as(as = "Option<JsonString>")]
    // Carried as a JSON-encoded string that decodes to a string-to-string map.
    pub broker_rewrite_map: Option<BTreeMap<String, String>>,
}
323
/// Serde default for `properties.sync.call.timeout`: five seconds.
const fn default_kafka_sync_call_timeout() -> Duration {
    Duration::from_millis(5_000)
}

/// Serde default for `properties.socket.keepalive.enable`: TCP keep-alive is enabled.
const fn default_socket_keepalive_enable() -> bool {
    true
}
331
/// Common librdkafka client properties, copied onto an rdkafka `ClientConfig` by `set_client`.
//
// Values arrive as strings and are parsed through `DisplayFromStr`.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct RdKafkaPropertiesCommon {
    /// Maximum Kafka protocol request message size. Due to differing framing overhead between
    /// protocol versions the producer is unable to reliably enforce a strict max message limit at
    /// produce time and may exceed the maximum size by one message in protocol `ProduceRequests`,
    /// the broker will enforce the topic's max.message.bytes limit
    #[serde(rename = "properties.message.max.bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub message_max_bytes: Option<usize>,

    /// Maximum Kafka protocol response message size. This serves as a safety precaution to avoid
    /// memory exhaustion in case of protocol hickups. This value must be at least fetch.max.bytes
    /// + 512 to allow for protocol overhead; the value is adjusted automatically unless the
    /// configuration property is explicitly set.
    #[serde(rename = "properties.receive.message.max.bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub receive_message_max_bytes: Option<usize>,

    // Forwarded as librdkafka `statistics.interval.ms` when set.
    #[serde(rename = "properties.statistics.interval.ms")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub statistics_interval_ms: Option<usize>,

    /// Client identifier
    #[serde(rename = "properties.client.id")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub client_id: Option<String>,

    // Forwarded as librdkafka `enable.ssl.certificate.verification` when set.
    #[serde(rename = "properties.enable.ssl.certificate.verification")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub enable_ssl_certificate_verification: Option<bool>,

    // Always forwarded as librdkafka `socket.keepalive.enable`; defaults to `true`
    // (`default_socket_keepalive_enable`). Unlike the other fields it is not marked
    // `allow_alter_on_fly`.
    #[serde(
        rename = "properties.socket.keepalive.enable",
        default = "default_socket_keepalive_enable"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub socket_keepalive_enable: bool,
}
376
377impl RdKafkaPropertiesCommon {
378    pub(crate) fn set_client(&self, c: &mut rdkafka::ClientConfig) {
379        if let Some(v) = self.statistics_interval_ms {
380            c.set("statistics.interval.ms", v.to_string());
381        }
382        if let Some(v) = self.message_max_bytes {
383            c.set("message.max.bytes", v.to_string());
384        }
385        if let Some(v) = self.receive_message_max_bytes {
386            c.set("receive.message.max.bytes", v.to_string());
387        }
388        if let Some(v) = self.client_id.as_ref() {
389            c.set("client.id", v);
390        }
391        if let Some(v) = self.enable_ssl_certificate_verification {
392            c.set("enable.ssl.certificate.verification", v.to_string());
393        }
394        c.set(
395            "socket.keepalive.enable",
396            self.socket_keepalive_enable.to_string(),
397        );
398    }
399}
400
401impl KafkaConnectionProps {
402    #[cfg(test)]
403    pub fn test_default() -> Self {
404        Self {
405            brokers: "localhost:9092".to_owned(),
406            security_protocol: None,
407            ssl_ca_location: None,
408            ssl_certificate_location: None,
409            ssl_key_location: None,
410            ssl_ca_pem: None,
411            ssl_certificate_pem: None,
412            ssl_key_pem: None,
413            ssl_key_password: None,
414            ssl_endpoint_identification_algorithm: None,
415            sasl_mechanism: None,
416            sasl_username: None,
417            sasl_password: None,
418            sasl_kerberos_service_name: None,
419            sasl_kerberos_keytab: None,
420            sasl_kerberos_principal: None,
421            sasl_kerberos_kinit_cmd: None,
422            sasl_kerberos_min_time_before_relogin: None,
423            sasl_oathbearer_config: None,
424        }
425    }
426
427    pub(crate) fn set_security_properties(&self, config: &mut ClientConfig) {
428        // AWS_MSK_IAM
429        if self.is_aws_msk_iam() {
430            config.set("security.protocol", "SASL_SSL");
431            config.set("sasl.mechanism", "OAUTHBEARER");
432            return;
433        }
434
435        // Security protocol
436        if let Some(security_protocol) = self.security_protocol.as_ref() {
437            config.set("security.protocol", security_protocol);
438        }
439
440        // SSL
441        if let Some(ssl_ca_location) = self.ssl_ca_location.as_ref() {
442            config.set("ssl.ca.location", ssl_ca_location);
443        }
444        if let Some(ssl_ca_pem) = self.ssl_ca_pem.as_ref() {
445            config.set("ssl.ca.pem", ssl_ca_pem);
446        }
447        if let Some(ssl_certificate_location) = self.ssl_certificate_location.as_ref() {
448            config.set("ssl.certificate.location", ssl_certificate_location);
449        }
450        if let Some(ssl_certificate_pem) = self.ssl_certificate_pem.as_ref() {
451            config.set("ssl.certificate.pem", ssl_certificate_pem);
452        }
453        if let Some(ssl_key_location) = self.ssl_key_location.as_ref() {
454            config.set("ssl.key.location", ssl_key_location);
455        }
456        if let Some(ssl_key_pem) = self.ssl_key_pem.as_ref() {
457            config.set("ssl.key.pem", ssl_key_pem);
458        }
459        if let Some(ssl_key_password) = self.ssl_key_password.as_ref() {
460            config.set("ssl.key.password", ssl_key_password);
461        }
462        if let Some(ssl_endpoint_identification_algorithm) =
463            self.ssl_endpoint_identification_algorithm.as_ref()
464        {
465            // accept only `none` and `http` here, let the sdk do the check
466            config.set(
467                "ssl.endpoint.identification.algorithm",
468                ssl_endpoint_identification_algorithm,
469            );
470        }
471
472        // SASL mechanism
473        if let Some(sasl_mechanism) = self.sasl_mechanism.as_ref() {
474            config.set("sasl.mechanism", sasl_mechanism);
475        }
476
477        // SASL/PLAIN & SASL/SCRAM
478        if let Some(sasl_username) = self.sasl_username.as_ref() {
479            config.set("sasl.username", sasl_username);
480        }
481        if let Some(sasl_password) = self.sasl_password.as_ref() {
482            config.set("sasl.password", sasl_password);
483        }
484
485        // SASL/GSSAPI
486        if let Some(sasl_kerberos_service_name) = self.sasl_kerberos_service_name.as_ref() {
487            config.set("sasl.kerberos.service.name", sasl_kerberos_service_name);
488        }
489        if let Some(sasl_kerberos_keytab) = self.sasl_kerberos_keytab.as_ref() {
490            config.set("sasl.kerberos.keytab", sasl_kerberos_keytab);
491        }
492        if let Some(sasl_kerberos_principal) = self.sasl_kerberos_principal.as_ref() {
493            config.set("sasl.kerberos.principal", sasl_kerberos_principal);
494        }
495        if let Some(sasl_kerberos_kinit_cmd) = self.sasl_kerberos_kinit_cmd.as_ref() {
496            config.set("sasl.kerberos.kinit.cmd", sasl_kerberos_kinit_cmd);
497        }
498        if let Some(sasl_kerberos_min_time_before_relogin) =
499            self.sasl_kerberos_min_time_before_relogin.as_ref()
500        {
501            config.set(
502                "sasl.kerberos.min.time.before.relogin",
503                sasl_kerberos_min_time_before_relogin,
504            );
505        }
506
507        // SASL/OAUTHBEARER
508        if let Some(sasl_oathbearer_config) = self.sasl_oathbearer_config.as_ref() {
509            config.set("sasl.oauthbearer.config", sasl_oathbearer_config);
510        }
511        // Currently, we only support unsecured OAUTH.
512        config.set("enable.sasl.oauthbearer.unsecure.jwt", "true");
513    }
514
515    pub(crate) fn is_aws_msk_iam(&self) -> bool {
516        if let Some(sasl_mechanism) = self.sasl_mechanism.as_ref()
517            && sasl_mechanism == AWS_MSK_IAM_AUTH
518        {
519            true
520        } else {
521            false
522        }
523    }
524}
525
/// Common Pulsar properties: topic, service URL and optional token authentication.
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct PulsarCommon {
    #[serde(rename = "topic", alias = "pulsar.topic")]
    pub topic: String,

    // URL handed to `Pulsar::builder` in `build_client`.
    #[serde(rename = "service.url", alias = "pulsar.service.url")]
    pub service_url: String,

    // Used for `token` authentication in `build_client`, and only when no OAuth config is given.
    #[serde(rename = "auth.token")]
    pub auth_token: Option<String>,
}
537
538impl EnforceSecret for PulsarCommon {
539    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
540        "pulsar.auth.token",
541    };
542}
543
/// OAuth2 client-credentials settings for Pulsar, used by `PulsarCommon::build_client`.
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct PulsarOauthCommon {
    #[serde(rename = "oauth.issuer.url")]
    pub issuer_url: String,

    // Location of the credentials file: an `s3://` URL, a `file://` URL, or an absolute path
    // (see `PulsarCommon::resolve_pulsar_credentials_url`).
    #[serde(rename = "oauth.credentials.url")]
    pub credentials_url: String,

    #[serde(rename = "oauth.audience")]
    pub audience: String,

    #[serde(rename = "oauth.scope")]
    pub scope: Option<String>,
}
558
559fn create_credential_temp_file(credentials: &[u8]) -> std::io::Result<NamedTempFile> {
560    let mut f = NamedTempFile::new()?;
561    f.write_all(credentials)?;
562    f.as_file().sync_all()?;
563    Ok(f)
564}
565
566impl PulsarCommon {
567    pub(crate) async fn build_client(
568        &self,
569        oauth: &Option<PulsarOauthCommon>,
570        aws_auth_props: &AwsAuthProps,
571        operation_retry_options: Option<OperationRetryOptions>,
572    ) -> ConnectorResult<Pulsar<TokioExecutor>> {
573        let mut pulsar_builder = Pulsar::builder(&self.service_url, TokioExecutor);
574        let mut _temp_file = None; // Keep temp file alive
575
576        if let Some(oauth) = oauth.as_ref() {
577            let (credentials_url, temp_file) = self
578                .resolve_pulsar_credentials_url(oauth, aws_auth_props)
579                .await?;
580            _temp_file = temp_file;
581
582            let auth_params = OAuth2Params {
583                issuer_url: oauth.issuer_url.clone(),
584                credentials_url,
585                audience: Some(oauth.audience.clone()),
586                scope: oauth.scope.clone(),
587            };
588
589            pulsar_builder = pulsar_builder
590                .with_auth_provider(OAuth2Authentication::client_credentials(auth_params));
591        } else if let Some(auth_token) = &self.auth_token {
592            pulsar_builder = pulsar_builder.with_auth(Authentication {
593                name: "token".to_owned(),
594                data: Vec::from(auth_token.as_str()),
595            });
596        }
597
598        if let Some(operation_retry_options) = operation_retry_options {
599            tracing::info!(
600                max_retries = ?operation_retry_options.max_retries,
601                retry_delay_ms = operation_retry_options.retry_delay.as_millis(),
602                "applying Pulsar source operation retry override"
603            );
604            pulsar_builder = pulsar_builder.with_operation_retry_options(operation_retry_options);
605        }
606
607        let res = pulsar_builder.build().await.map_err(|e| anyhow!(e))?;
608        drop(_temp_file); // Explicitly drop temp file after client is built
609        Ok(res)
610    }
611
612    pub(crate) async fn resolve_pulsar_credentials_url(
613        &self,
614        oauth: &PulsarOauthCommon,
615        aws_auth_props: &AwsAuthProps,
616    ) -> ConnectorResult<(String, Option<NamedTempFile>)> {
617        // Try parsing as URL first
618        if let Ok(url) = Url::parse(&oauth.credentials_url) {
619            return self
620                .handle_pulsar_credentials_url(&url, aws_auth_props)
621                .await;
622        }
623
624        // If not a valid URL, check if it's an absolute file path
625        let path = Path::new(&oauth.credentials_url);
626        if !path.is_absolute() {
627            bail!("credentials_url must be a valid URL (s3://, file://) or an absolute file path");
628        }
629
630        // Verify the file exists
631        if !tokio::fs::try_exists(&oauth.credentials_url)
632            .await
633            .unwrap_or(false)
634        {
635            bail!("credentials file does not exist: {}", oauth.credentials_url);
636        }
637
638        // Return absolute path with file:// prefix
639        Ok((format!("file://{}", oauth.credentials_url), None))
640    }
641
642    pub(crate) async fn handle_pulsar_credentials_url(
643        &self,
644        url: &Url,
645        aws_auth_props: &AwsAuthProps,
646    ) -> ConnectorResult<(String, Option<NamedTempFile>)> {
647        match url.scheme() {
648            "s3" => {
649                let credentials = load_file_descriptor_from_s3(url, aws_auth_props).await?;
650                let temp_file = create_credential_temp_file(&credentials)
651                    .context("failed to create temp file for pulsar credentials")?;
652
653                let temp_path = temp_file
654                    .path()
655                    .to_str()
656                    .context("temp file path is not valid UTF-8")?;
657
658                Ok((format!("file://{}", temp_path), Some(temp_file)))
659            }
660            "file" => Ok((url.to_string(), None)),
661            _ => bail!(
662                "invalid credentials_url scheme '{}', only file://, s3://, and absolute file paths are supported",
663                url.scheme()
664            ),
665        }
666    }
667}
668
/// Common Kinesis properties: the stream, AWS auth settings and AWS SDK timeout/retry tuning.
///
/// `KinesisCommon::build_client` turns these into a Kinesis client.
// NOTE(review): the derived `Debug` prints the credential fields verbatim.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct KinesisCommon {
    #[serde(rename = "stream", alias = "kinesis.stream.name")]
    pub stream_name: String,
    // Required here: passed as the explicit `region` of the AWS auth config.
    #[serde(rename = "aws.region", alias = "kinesis.stream.region")]
    pub stream_region: String,
    #[serde(rename = "endpoint", alias = "kinesis.endpoint")]
    pub endpoint: Option<String>,
    #[serde(
        rename = "aws.credentials.access_key_id",
        alias = "kinesis.credentials.access"
    )]
    pub credentials_access_key: Option<String>,
    #[serde(
        rename = "aws.credentials.secret_access_key",
        alias = "kinesis.credentials.secret"
    )]
    pub credentials_secret_access_key: Option<String>,
    #[serde(
        rename = "aws.credentials.session_token",
        alias = "kinesis.credentials.session_token"
    )]
    pub session_token: Option<String>,
    #[serde(rename = "aws.credentials.role.arn", alias = "kinesis.assumerole.arn")]
    pub assume_role_arn: Option<String>,
    #[serde(
        rename = "aws.credentials.role.external_id",
        alias = "kinesis.assumerole.external_id"
    )]
    pub assume_role_external_id: Option<String>,

    // sdk options
    // Each value below is a number parsed from its string form (`DisplayFromStr`).
    // Default: 10000 ms.
    #[serde(
        rename = "kinesis.sdk.connect_timeout_ms",
        default = "kinesis_default_connect_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_connect_timeout_ms: u64,

    // Default: 10000 ms.
    #[serde(
        rename = "kinesis.sdk.read_timeout_ms",
        default = "kinesis_default_read_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_read_timeout_ms: u64,

    // Default: 10000 ms.
    #[serde(
        rename = "kinesis.sdk.operation_timeout_ms",
        default = "kinesis_default_operation_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_operation_timeout_ms: u64,

    // Default: 10000 ms.
    #[serde(
        rename = "kinesis.sdk.operation_attempt_timeout_ms",
        default = "kinesis_default_operation_attempt_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_operation_attempt_timeout_ms: u64,

    // Default: 3. Passed to the SDK retry config as `max_attempts`.
    // NOTE(review): in the AWS SDK `max_attempts` presumably includes the initial attempt, so
    // this is a total-attempt count rather than a retry count — confirm.
    #[serde(
        rename = "kinesis.sdk.max_retry_limit",
        default = "kinesis_default_max_retry_limit"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_max_retry_limit: u32,

    // Default: 1000 ms.
    #[serde(
        rename = "kinesis.sdk.init_backoff_ms",
        default = "kinesis_default_init_backoff_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_init_backoff_ms: u64,

    // Default: 20000 ms.
    #[serde(
        rename = "kinesis.sdk.max_backoff_ms",
        default = "kinesis_default_max_backoff_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_max_backoff_ms: u64,
}
751
752#[derive(Debug)]
753pub struct KinesisAsyncSleepImpl;
754
755impl AsyncSleep for KinesisAsyncSleepImpl {
756    fn sleep(&self, duration: Duration) -> Sleep {
757        Sleep::new(async move { tokio::time::sleep(duration).await })
758    }
759}
760
/// Shared default for every `kinesis.sdk.*_timeout_ms` option: 10 seconds.
const KINESIS_DEFAULT_TIMEOUT_MS: u64 = 10_000;

/// Default for `kinesis.sdk.connect_timeout_ms`.
const fn kinesis_default_connect_timeout_ms() -> u64 {
    KINESIS_DEFAULT_TIMEOUT_MS
}

/// Default for `kinesis.sdk.read_timeout_ms`.
const fn kinesis_default_read_timeout_ms() -> u64 {
    KINESIS_DEFAULT_TIMEOUT_MS
}

/// Default for `kinesis.sdk.operation_timeout_ms`.
const fn kinesis_default_operation_timeout_ms() -> u64 {
    KINESIS_DEFAULT_TIMEOUT_MS
}

/// Default for `kinesis.sdk.operation_attempt_timeout_ms`.
const fn kinesis_default_operation_attempt_timeout_ms() -> u64 {
    KINESIS_DEFAULT_TIMEOUT_MS
}

/// Default for `kinesis.sdk.init_backoff_ms`: one second.
const fn kinesis_default_init_backoff_ms() -> u64 {
    1_000
}

/// Default for `kinesis.sdk.max_backoff_ms`: twenty seconds.
const fn kinesis_default_max_backoff_ms() -> u64 {
    20_000
}

/// Default for `kinesis.sdk.max_retry_limit`.
const fn kinesis_default_max_retry_limit() -> u32 {
    3
}
788
789impl EnforceSecret for KinesisCommon {
790    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
791        "kinesis.credentials.access",
792        "kinesis.credentials.secret",
793        "kinesis.credentials.session_token",
794    };
795}
796
797impl KinesisCommon {
798    pub(crate) async fn build_client(&self) -> ConnectorResult<KinesisClient> {
799        let config = AwsAuthProps {
800            region: Some(self.stream_region.clone()),
801            endpoint: self.endpoint.clone(),
802            access_key: self.credentials_access_key.clone(),
803            secret_key: self.credentials_secret_access_key.clone(),
804            session_token: self.session_token.clone(),
805            arn: self.assume_role_arn.clone(),
806            external_id: self.assume_role_external_id.clone(),
807            profile: Default::default(),
808            msk_signer_timeout_sec: Default::default(),
809        };
810        let aws_config = config.build_config().await?;
811        let mut builder = aws_sdk_kinesis::config::Builder::from(&aws_config);
812        {
813            // for timeout and retry config
814            let sleep_impl = SharedAsyncSleep::new(KinesisAsyncSleepImpl);
815            builder.set_sleep_impl(Some(sleep_impl));
816            let timeout_config = aws_smithy_types::timeout::TimeoutConfig::builder()
817                .connect_timeout(Duration::from_millis(self.sdk_connect_timeout_ms))
818                .read_timeout(Duration::from_millis(self.sdk_read_timeout_ms))
819                .operation_timeout(Duration::from_millis(self.sdk_operation_timeout_ms))
820                .operation_attempt_timeout(Duration::from_millis(
821                    self.sdk_operation_attempt_timeout_ms,
822                ))
823                .build();
824            builder.set_timeout_config(Some(timeout_config));
825
826            let retry_config = aws_smithy_types::retry::RetryConfig::standard()
827                .with_initial_backoff(Duration::from_millis(self.sdk_init_backoff_ms))
828                .with_max_backoff(Duration::from_millis(self.sdk_max_backoff_ms))
829                .with_max_attempts(self.sdk_max_retry_limit);
830            builder.set_retry_config(Some(retry_config));
831        }
832        if let Some(endpoint) = &config.endpoint {
833            builder = builder.endpoint_url(endpoint);
834        }
835        Ok(KinesisClient::from_conf(builder.build()))
836    }
837}
838
/// Connection-level NATS properties, used as the key of `SHARED_NATS_CLIENT`.
///
/// The key covers the server list and everything that controls authentication. It does not
/// include stream or subject settings, so connectors that differ only in those share a key.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct NatsConnectionProps {
    /// Comma-separated server address list, exactly as configured.
    pub server_url: String,
    /// Authentication mode (`user_and_password`, `credential` or `plain`).
    pub connect_mode: String,
    /// User name for `user_and_password` mode.
    pub user: Option<String>,
    /// Password for `user_and_password` mode.
    pub password: Option<String>,
    /// JWT for `credential` mode.
    pub jwt: Option<String>,
    /// NKEY seed for `credential` mode.
    pub nkey: Option<String>,
}
850
851/// Shared NATS client cache.
852/// Client connections are cached as `Weak` pointers in the cache.
853/// NATS Connector can access this cache to reuse existing client connections,
854/// and avoid exhausting host machine ports.
855/// When reading from the cache, the connector should `upgrade` the weak pointer to an `Arc` reference.
856/// After all strong (Arc) references are dropped, the client connection will be cleaned up.
857/// Cache eviction naturally takes care of the dangling weak pointers.
858pub static SHARED_NATS_CLIENT: LazyLock<MokaCache<NatsConnectionProps, Weak<async_nats::Client>>> =
859    LazyLock::new(|| MokaCache::builder().build());
860
// Properties common to NATS connectors: server list, authentication, the comma-separated
// subject list, and the JetStream stream settings used when a stream has to be created.
// NOTE(review): plain `//` comments are used on purpose. `///` doc comments on a
// `WithOptions` struct may be picked up by the options codegen, so confirm before converting.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct NatsCommon {
    // Comma-separated list of server addresses; each entry is parsed when connecting.
    #[serde(rename = "server_url")]
    pub server_url: String,
    // Comma-separated list of subjects. Used both as the subjects of a created stream and as
    // consumer filter subjects.
    #[serde(rename = "subject")]
    pub subject: String,
    // One of `user_and_password`, `credential` or `plain`; any other value is rejected when
    // building the client.
    #[serde(rename = "connect_mode")]
    pub connect_mode: String,
    // `username` / `password` are both required by the `user_and_password` connect mode.
    #[serde(rename = "username")]
    pub user: Option<String>,
    #[serde(rename = "password")]
    pub password: Option<String>,
    // `jwt` / `nkey` (the NKEY seed) are both required by the `credential` connect mode.
    #[serde(rename = "jwt")]
    pub jwt: Option<String>,
    #[serde(rename = "nkey")]
    pub nkey: Option<String>,
    // The following limits are parsed from strings. They are applied to the stream config
    // only when this connector creates the stream. Unset values keep
    // `jetstream::stream::Config::default()`, except `max_bytes`, which falls back to 1000000.
    #[serde(rename = "max_bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_bytes: Option<i64>,
    #[serde(rename = "max_messages")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_messages: Option<i64>,
    #[serde(rename = "max_messages_per_subject")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_messages_per_subject: Option<i64>,
    #[serde(rename = "max_consumers")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_consumers: Option<i32>,
    #[serde(rename = "max_message_size")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_message_size: Option<i32>,
    // Parsed from a string. When false (the serde default), a missing stream is reported as
    // an error instead of being created.
    #[serde(rename = "allow_create_stream", default)]
    #[serde_as(as = "DisplayFromStr")]
    pub allow_create_stream: bool,
}
897
898impl EnforceSecret for NatsCommon {
899    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
900        "password",
901        "jwt",
902        "nkey",
903    };
904}
905
906impl NatsCommon {
907    /// Extract connection properties that can be used as a cache key.
908    pub fn connection_props(&self) -> NatsConnectionProps {
909        NatsConnectionProps {
910            server_url: self.server_url.clone(),
911            connect_mode: self.connect_mode.clone(),
912            user: self.user.clone(),
913            password: self.password.clone(),
914            jwt: self.jwt.clone(),
915            nkey: self.nkey.clone(),
916        }
917    }
918
919    /// Build a new NATS client without caching.
920    async fn build_client_inner(&self) -> ConnectorResult<async_nats::Client> {
921        let mut connect_options = async_nats::ConnectOptions::new();
922        match self.connect_mode.as_str() {
923            "user_and_password" => {
924                if let (Some(v_user), Some(v_password)) =
925                    (self.user.as_ref(), self.password.as_ref())
926                {
927                    connect_options =
928                        connect_options.user_and_password(v_user.into(), v_password.into())
929                } else {
930                    bail!("nats connect mode is user_and_password, but user or password is empty");
931                }
932            }
933
934            "credential" => {
935                if let (Some(v_nkey), Some(v_jwt)) = (self.nkey.as_ref(), self.jwt.as_ref()) {
936                    connect_options = connect_options
937                        .credentials(&self.create_credential(v_nkey, v_jwt)?)
938                        .expect("failed to parse static creds")
939                } else {
940                    bail!("nats connect mode is credential, but nkey or jwt is empty");
941                }
942            }
943            "plain" => {}
944            _ => {
945                bail!("nats connect mode only accepts user_and_password/credential/plain");
946            }
947        };
948
949        let servers = self.server_url.split(',').collect::<Vec<&str>>();
950        let client = connect_options
951            .connect(
952                servers
953                    .iter()
954                    .map(|url| url.parse())
955                    .collect::<Result<Vec<async_nats::ServerAddr>, _>>()?,
956            )
957            .await
958            .context("build nats client error")
959            .map_err(SinkError::Nats)?;
960        Ok(client)
961    }
962
963    /// Build a NATS client, attempting to reuse an existing cached client if available.
964    /// See `SHARED_NATS_CLIENT` for more details.
965    pub(crate) async fn build_client(&self) -> ConnectorResult<Arc<async_nats::Client>> {
966        let connection_props = self.connection_props();
967        let mut client: Option<Arc<async_nats::Client>> = None;
968
969        SHARED_NATS_CLIENT
970            .entry_by_ref(&connection_props)
971            .and_try_compute_with::<_, _, crate::error::ConnectorError>(|maybe_entry| async {
972                if let Some(entry) = maybe_entry
973                    && let entry_value = entry.into_value()
974                    && let Some(existing_client) = entry_value.upgrade()
975                {
976                    match existing_client.connection_state() {
977                        async_nats::connection::State::Connected => {
978                            tracing::info!("reuse existing nats client for {}", self.server_url);
979                            client = Some(existing_client);
980                            return Ok(Op::Nop);
981                        }
982                        _ => {
983                            tracing::warn!(
984                                server_url = self.server_url,
985                                "existing nats client is not connected",
986                            );
987                        }
988                    }
989                }
990                tracing::info!(
991                    server_url = self.server_url,
992                    "no cached client, or client disconnected, building new nats client"
993                );
994                let new_client = Arc::new(self.build_client_inner().await?);
995                client = Some(new_client.clone());
996                Ok(Op::Put(Arc::downgrade(&new_client)))
997            })
998            .await?;
999
1000        Ok(client.expect("client should be set"))
1001    }
1002
1003    pub(crate) async fn build_context(&self) -> ConnectorResult<jetstream::Context> {
1004        let client = self.build_client().await?;
1005        let jetstream = async_nats::jetstream::new((*client).clone());
1006        Ok(jetstream)
1007    }
1008
1009    /// Build a `JetStream` context using a pre-existing client.
1010    pub(crate) fn build_context_from_client(
1011        client: &Arc<async_nats::Client>,
1012    ) -> jetstream::Context {
1013        async_nats::jetstream::new((**client).clone())
1014    }
1015
1016    /// Build a NATS `JetStream` consumer.
1017    ///
1018    /// If `existing_client` is provided, it will be used instead of creating/fetching a new one.
1019    /// This allows callers to reuse a client they already hold.
1020    pub(crate) async fn build_consumer(
1021        &self,
1022        stream: String,
1023        durable_consumer_name: String,
1024        split_id: String,
1025        start_sequence: NatsOffset,
1026        mut config: jetstream::consumer::pull::Config,
1027        existing_client: Option<Arc<async_nats::Client>>,
1028    ) -> ConnectorResult<(
1029        async_nats::jetstream::consumer::Consumer<async_nats::jetstream::consumer::pull::Config>,
1030        Arc<async_nats::Client>,
1031    )> {
1032        let client = match existing_client {
1033            Some(c) => c,
1034            None => self.build_client().await?,
1035        };
1036        let context = Self::build_context_from_client(&client);
1037        let stream = self.build_or_get_stream(context.clone(), stream).await?;
1038        let subject_name = self
1039            .subject
1040            .replace(',', "-")
1041            .replace(['.', '>', '*', ' ', '\t'], "_");
1042        let name = format!("risingwave-consumer-{}-{}", subject_name, split_id);
1043
1044        let deliver_policy = match start_sequence {
1045            NatsOffset::Earliest => DeliverPolicy::All,
1046            NatsOffset::Latest => DeliverPolicy::New,
1047            NatsOffset::SequenceNumber(v) => {
1048                // for compatibility, we do not write to any state table now
1049                let parsed = v
1050                    .parse::<u64>()
1051                    .context("failed to parse nats offset as sequence number")?;
1052                DeliverPolicy::ByStartSequence {
1053                    start_sequence: 1 + parsed,
1054                }
1055            }
1056            NatsOffset::Timestamp(v) => DeliverPolicy::ByStartTime {
1057                start_time: OffsetDateTime::from_unix_timestamp_nanos(v as i128 * 1_000_000)
1058                    .context("invalid timestamp for nats offset")?,
1059            },
1060            NatsOffset::None => DeliverPolicy::All,
1061        };
1062
1063        let consumer = match stream.get_consumer(&name).await {
1064            Ok(consumer) => consumer,
1065            _ => {
1066                stream
1067                    .get_or_create_consumer(&name, {
1068                        config.deliver_policy = deliver_policy;
1069                        config.durable_name = Some(durable_consumer_name);
1070                        config.filter_subjects =
1071                            self.subject.split(',').map(|s| s.to_owned()).collect();
1072                        config
1073                    })
1074                    .await?
1075            }
1076        };
1077        Ok((consumer, client))
1078    }
1079
1080    pub(crate) async fn build_or_get_stream(
1081        &self,
1082        jetstream: jetstream::Context,
1083        stream_str: String,
1084    ) -> ConnectorResult<jetstream::stream::Stream> {
1085        let subjects: Vec<String> = self.subject.split(',').map(|s| s.to_owned()).collect();
1086
1087        // In `SourceEnumerator`, we may create a stream
1088        // In `SourceReader`, the desired stream MUST exist
1089        if let Ok(mut stream_instance) = jetstream.get_stream(&stream_str).await {
1090            tracing::info!(
1091                "load existing nats stream ({:?}) with config {:?}",
1092                stream_str,
1093                stream_instance.info().await?
1094            );
1095            return Ok(stream_instance);
1096        }
1097
1098        if !self.allow_create_stream {
1099            return Err(anyhow!(
1100                "stream {} not found, set `allow_create_stream` to true to create a stream",
1101                stream_str
1102            )
1103            .into());
1104        }
1105
1106        let mut config = jetstream::stream::Config {
1107            name: stream_str.clone(),
1108            max_bytes: 1000000,
1109            subjects,
1110            ..Default::default()
1111        };
1112        if let Some(v) = self.max_bytes {
1113            config.max_bytes = v;
1114        }
1115        if let Some(v) = self.max_messages {
1116            config.max_messages = v;
1117        }
1118        if let Some(v) = self.max_messages_per_subject {
1119            config.max_messages_per_subject = v;
1120        }
1121        if let Some(v) = self.max_consumers {
1122            config.max_consumers = v;
1123        }
1124        if let Some(v) = self.max_message_size {
1125            config.max_message_size = v;
1126        }
1127        tracing::info!(
1128            "create nats stream ({:?}) with config {:?}",
1129            &stream_str,
1130            config
1131        );
1132        let stream = jetstream.get_or_create_stream(config).await?;
1133        Ok(stream)
1134    }
1135
1136    pub(crate) fn create_credential(&self, seed: &str, jwt: &str) -> ConnectorResult<String> {
1137        let creds = format!(
1138            "-----BEGIN NATS USER JWT-----\n{}\n------END NATS USER JWT------\n\n\
1139                         ************************* IMPORTANT *************************\n\
1140                         NKEY Seed printed below can be used to sign and prove identity.\n\
1141                         NKEYs are sensitive and should be treated as secrets.\n\n\
1142                         -----BEGIN USER NKEY SEED-----\n{}\n------END USER NKEY SEED------\n\n\
1143                         *************************************************************",
1144            jwt, seed
1145        );
1146        Ok(creds)
1147    }
1148}
1149
1150pub(crate) fn load_certs(
1151    certificates: &str,
1152) -> ConnectorResult<Vec<rustls_pki_types::CertificateDer<'static>>> {
1153    let cert_bytes = if let Some(path) = certificates.strip_prefix("fs://") {
1154        std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())?
1155    } else {
1156        certificates.as_bytes().to_owned()
1157    };
1158
1159    CertificateDer::pem_slice_iter(&cert_bytes)
1160        .collect::<Result<Vec<_>, _>>()
1161        .context("Failed to parse certificates")
1162        .map_err(Into::into)
1163}
1164
1165pub(crate) fn load_private_key(
1166    certificate: &str,
1167) -> ConnectorResult<rustls_pki_types::PrivateKeyDer<'static>> {
1168    let cert_bytes = if let Some(path) = certificate.strip_prefix("fs://") {
1169        std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())?
1170    } else {
1171        certificate.as_bytes().to_owned()
1172    };
1173
1174    let cert = PrivatePkcs8KeyDer::pem_slice_iter(&cert_bytes)
1175        .next()
1176        .ok_or_else(|| anyhow!("No private key found"))?
1177        .context("Failed to parse private key")?;
1178    Ok(cert.into())
1179}
1180
// Connection properties for MongoDB connectors: the connection URI and the target collection.
// NOTE(review): the existing `///` field docs are left untouched because the `WithOptions`
// derive may consume them; confirm before editing them.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct MongodbCommon {
    // Treated as a secret: `mongodb.url` is listed in this type's `EnforceSecret` impl.
    /// The URL of `MongoDB`
    #[serde(rename = "mongodb.url")]
    pub connect_uri: String,
    /// The collection name where data should be written to or read from. For sinks, the format is
    /// `db_name.collection_name`. Data can also be written to dynamic collections, see `collection.name.field`
    /// for more information.
    #[serde(rename = "collection.name")]
    pub collection_name: String,
}
1193
1194impl EnforceSecret for MongodbCommon {
1195    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
1196        "mongodb.url"
1197    };
1198}
1199
1200impl MongodbCommon {
1201    pub(crate) async fn build_client(&self) -> ConnectorResult<mongodb::Client> {
1202        let client = mongodb::Client::with_uri_str(&self.connect_uri).await?;
1203
1204        Ok(client)
1205    }
1206}