risingwave_connector/connector_common/
common.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::hash::Hash;
17use std::io::Write;
18use std::path::Path;
19use std::sync::{Arc, LazyLock, Weak};
20use std::time::Duration;
21
22use anyhow::{Context, anyhow};
23use async_nats::jetstream::consumer::DeliverPolicy;
24use async_nats::jetstream::{self};
25use aws_sdk_kinesis::Client as KinesisClient;
26use aws_sdk_kinesis::config::{AsyncSleep, SharedAsyncSleep, Sleep};
27use moka::future::Cache as MokaCache;
28use moka::ops::compute::Op;
29use phf::{Set, phf_set};
30use pulsar::authentication::oauth2::{OAuth2Authentication, OAuth2Params};
31use pulsar::{Authentication, OperationRetryOptions, Pulsar, TokioExecutor};
32use rdkafka::ClientConfig;
33use risingwave_common::bail;
34use rustls_pki_types::pem::PemObject;
35use rustls_pki_types::{CertificateDer, PrivatePkcs8KeyDer};
36use serde::Deserialize;
37use serde_with::json::JsonString;
38use serde_with::{DisplayFromStr, serde_as};
39use tempfile::NamedTempFile;
40use time::OffsetDateTime;
41use url::Url;
42use with_options::WithOptions;
43
44use crate::aws_utils::load_file_descriptor_from_s3;
45use crate::deserialize_duration_from_string;
46use crate::enforce_secret::EnforceSecret;
47use crate::error::ConnectorResult;
48use crate::sink::SinkError;
49use crate::source::nats::source::NatsOffset;
50
/// Property key under which the broker rewrite map is stored; it is the same string as the
/// serde name of `KafkaPrivateLinkCommon::broker_rewrite_map`.
pub const PRIVATE_LINK_BROKER_REWRITE_MAP_KEY: &str = "broker.rewrite.endpoints";
/// Property key for the private link targets.
// NOTE(review): presumably consumed by the frontend when generating the rewrite map — confirm.
pub const PRIVATE_LINK_TARGETS_KEY: &str = "privatelink.targets";

/// Value of the SASL mechanism option that selects AWS MSK IAM authentication
/// (compared against `sasl_mechanism` in `KafkaConnectionProps::is_aws_msk_iam`).
const AWS_MSK_IAM_AUTH: &str = "AWS_MSK_IAM";

/// The environment variable to disable using default credential from environment.
/// It's recommended to set this variable to `true` in cloud hosting environment.
///
/// When it is true and no explicit access/secret key pair is configured,
/// `AwsAuthProps::build_credential_provider` returns an error instead of falling back to the
/// SDK's default credential provider.
pub const DISABLE_DEFAULT_CREDENTIAL: &str = "DISABLE_DEFAULT_CREDENTIAL";
59
/// One AWS private link entry, deserialized from connector properties.
// NOTE(review): presumably one element of the `privatelink.targets` list — confirm against the
// code that parses that property.
#[derive(Debug, Clone, Deserialize)]
pub struct AwsPrivateLinkItem {
    /// Optional availability zone identifier (presumably an AWS AZ id — TODO confirm).
    pub az_id: Option<String>,
    /// Port number of this entry.
    pub port: u16,
}
65
66use aws_config::default_provider::region::DefaultRegionChain;
67use aws_config::sts::AssumeRoleProvider;
68use aws_credential_types::provider::SharedCredentialsProvider;
69use aws_types::SdkConfig;
70use aws_types::region::Region;
71use risingwave_common::util::env_var::env_var_is_true;
72
/// A flattened config map for AWS authentication.
///
/// Every field is optional. `build_config` turns the populated fields into an `SdkConfig`.
#[derive(Deserialize, Debug, Clone, WithOptions, PartialEq)]
pub struct AwsAuthProps {
    /// AWS region name. When absent, `build_region` falls back to the SDK's default region chain.
    #[serde(rename = "aws.region", alias = "region", alias = "s3.region")]
    pub region: Option<String>,

    /// Custom endpoint URL; applied to the config loader in `build_config` when present.
    #[serde(
        rename = "aws.endpoint_url",
        alias = "endpoint_url",
        alias = "endpoint",
        alias = "s3.endpoint"
    )]
    pub endpoint: Option<String>,
    /// Access key id. Static credentials are only used when `secret_key` is set as well.
    #[serde(
        rename = "aws.credentials.access_key_id",
        alias = "access_key",
        alias = "s3.access.key"
    )]
    pub access_key: Option<String>,
    /// Secret access key. Static credentials are only used when `access_key` is set as well.
    #[serde(
        rename = "aws.credentials.secret_access_key",
        alias = "secret_key",
        alias = "s3.secret.key"
    )]
    pub secret_key: Option<String>,
    /// Optional session token passed along with the static access/secret key pair.
    #[serde(rename = "aws.credentials.session_token", alias = "session_token")]
    pub session_token: Option<String>,
    /// IAM role
    ///
    /// When set, `with_role_provider` wraps the base credentials in an assume-role provider
    /// for this ARN.
    #[serde(rename = "aws.credentials.role.arn", alias = "arn")]
    pub arn: Option<String>,
    /// external ID in IAM role trust policy
    #[serde(rename = "aws.credentials.role.external_id", alias = "external_id")]
    pub external_id: Option<String>,
    /// Profile name handed to the default region chain when `region` is not set.
    #[serde(rename = "aws.profile", alias = "profile")]
    pub profile: Option<String>,
    /// Not read by the methods in this part of the file.
    // NOTE(review): presumably the timeout (in seconds) of the MSK IAM token signer — confirm.
    #[serde(rename = "aws.msk.signer_timeout_sec")]
    pub msk_signer_timeout_sec: Option<u64>,
}
111
112impl EnforceSecret for AwsAuthProps {
113    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
114        "access_key",
115        "aws.credentials.access_key_id",
116        "s3.access.key",
117        "secret_key",
118        "aws.credentials.secret_access_key",
119        "s3.secret.key",
120        "session_token",
121        "aws.credentials.session_token",
122    };
123}
124
125impl AwsAuthProps {
126    async fn build_region(&self) -> ConnectorResult<Region> {
127        if let Some(region_name) = &self.region {
128            Ok(Region::new(region_name.clone()))
129        } else {
130            let mut region_chain = DefaultRegionChain::builder();
131            if let Some(profile_name) = &self.profile {
132                region_chain = region_chain.profile_name(profile_name);
133            }
134
135            Ok(region_chain
136                .build()
137                .region()
138                .await
139                .context("region should be provided")?)
140        }
141    }
142
143    async fn build_credential_provider(&self) -> ConnectorResult<SharedCredentialsProvider> {
144        if self.access_key.is_some() && self.secret_key.is_some() {
145            Ok(SharedCredentialsProvider::new(
146                aws_credential_types::Credentials::from_keys(
147                    self.access_key.as_ref().unwrap(),
148                    self.secret_key.as_ref().unwrap(),
149                    self.session_token.clone(),
150                ),
151            ))
152        } else if !env_var_is_true(DISABLE_DEFAULT_CREDENTIAL) {
153            Ok(SharedCredentialsProvider::new(
154                aws_config::default_provider::credentials::default_provider().await,
155            ))
156        } else {
157            bail!("Both \"access_key\" and \"secret_key\" are required.")
158        }
159    }
160
161    async fn with_role_provider(
162        &self,
163        credential: SharedCredentialsProvider,
164    ) -> ConnectorResult<SharedCredentialsProvider> {
165        if let Some(role_name) = &self.arn {
166            let region = self.build_region().await?;
167            let mut role = AssumeRoleProvider::builder(role_name)
168                .session_name("RisingWave")
169                .region(region);
170            if let Some(id) = &self.external_id {
171                role = role.external_id(id);
172            }
173            let provider = role.build_from_provider(credential).await;
174            Ok(SharedCredentialsProvider::new(provider))
175        } else {
176            Ok(credential)
177        }
178    }
179
180    pub async fn build_config(&self) -> ConnectorResult<SdkConfig> {
181        let region = self.build_region().await?;
182        let credentials_provider = self
183            .with_role_provider(self.build_credential_provider().await?)
184            .await?;
185        let mut config_loader = aws_config::from_env()
186            .region(region)
187            .credentials_provider(credentials_provider);
188
189        if let Some(endpoint) = self.endpoint.as_ref() {
190            config_loader = config_loader.endpoint_url(endpoint);
191        }
192
193        Ok(config_loader.load().await)
194    }
195}
196
/// Connection-level Kafka properties (brokers, TLS and SASL settings).
///
/// The security-related fields are forwarded to the `rdkafka` client config by
/// `set_security_properties`.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
pub struct KafkaConnectionProps {
    /// Kafka bootstrap broker address(es).
    #[serde(rename = "properties.bootstrap.server", alias = "kafka.brokers")]
    pub brokers: String,

    /// Security protocol used for RisingWave to communicate with Kafka brokers. Could be
    /// PLAINTEXT, SSL, `SASL_PLAINTEXT` or `SASL_SSL`.
    #[serde(rename = "properties.security.protocol")]
    #[with_option(allow_alter_on_fly)]
    security_protocol: Option<String>,

    /// Endpoint identification algorithm used to validate the broker hostname; forwarded
    /// unchanged as `ssl.endpoint.identification.algorithm`.
    #[serde(rename = "properties.ssl.endpoint.identification.algorithm")]
    #[with_option(allow_alter_on_fly)]
    ssl_endpoint_identification_algorithm: Option<String>,

    // For the properties below, please refer to [librdkafka](https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md) for more information.
    /// Path to CA certificate file for verifying the broker's key.
    #[serde(rename = "properties.ssl.ca.location")]
    #[with_option(allow_alter_on_fly)]
    ssl_ca_location: Option<String>,

    /// CA certificate string (PEM format) for verifying the broker's key.
    #[serde(rename = "properties.ssl.ca.pem")]
    #[with_option(allow_alter_on_fly)]
    ssl_ca_pem: Option<String>,

    /// Path to client's certificate file (PEM).
    #[serde(rename = "properties.ssl.certificate.location")]
    #[with_option(allow_alter_on_fly)]
    ssl_certificate_location: Option<String>,

    /// Client's public key string (PEM format) used for authentication.
    #[serde(rename = "properties.ssl.certificate.pem")]
    #[with_option(allow_alter_on_fly)]
    ssl_certificate_pem: Option<String>,

    /// Path to client's private key file (PEM).
    #[serde(rename = "properties.ssl.key.location")]
    #[with_option(allow_alter_on_fly)]
    ssl_key_location: Option<String>,

    /// Client's private key string (PEM format) used for authentication.
    #[serde(rename = "properties.ssl.key.pem")]
    #[with_option(allow_alter_on_fly)]
    ssl_key_pem: Option<String>,

    /// Passphrase of client's private key.
    #[serde(rename = "properties.ssl.key.password")]
    #[with_option(allow_alter_on_fly)]
    ssl_key_password: Option<String>,

    /// SASL mechanism if SASL is enabled. Currently support PLAIN, SCRAM, GSSAPI, and `AWS_MSK_IAM`.
    #[serde(rename = "properties.sasl.mechanism")]
    #[with_option(allow_alter_on_fly)]
    sasl_mechanism: Option<String>,

    /// SASL username for SASL/PLAIN and SASL/SCRAM.
    #[serde(rename = "properties.sasl.username")]
    #[with_option(allow_alter_on_fly)]
    sasl_username: Option<String>,

    /// SASL password for SASL/PLAIN and SASL/SCRAM.
    #[serde(rename = "properties.sasl.password")]
    #[with_option(allow_alter_on_fly)]
    sasl_password: Option<String>,

    /// Kafka server's Kerberos principal name under SASL/GSSAPI, not including /hostname@REALM.
    #[serde(rename = "properties.sasl.kerberos.service.name")]
    sasl_kerberos_service_name: Option<String>,

    /// Path to client's Kerberos keytab file under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.keytab")]
    sasl_kerberos_keytab: Option<String>,

    /// Client's Kerberos principal name under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.principal")]
    sasl_kerberos_principal: Option<String>,

    /// Shell command to refresh or acquire the client's Kerberos ticket under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.kinit.cmd")]
    sasl_kerberos_kinit_cmd: Option<String>,

    /// Minimum time in milliseconds between key refresh attempts under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.min.time.before.relogin")]
    sasl_kerberos_min_time_before_relogin: Option<String>,

    /// Configurations for SASL/OAUTHBEARER.
    // NOTE(review): the field name is spelled `oathbearer` (missing `u`); only the Rust
    // identifier is affected, the property key is spelled correctly.
    #[serde(rename = "properties.sasl.oauthbearer.config")]
    sasl_oathbearer_config: Option<String>,
}
288
/// Property keys of `KafkaConnectionProps` registered as secret-only: the inline private key,
/// its passphrase, and the SASL password.
impl EnforceSecret for KafkaConnectionProps {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "properties.ssl.key.pem",
        "properties.ssl.key.password",
        "properties.sasl.password",
    };
}
296
/// Kafka properties shared by connectors that are not part of the connection itself.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct KafkaCommon {
    // connection related props are moved to `KafkaConnection`
    /// Name of the Kafka topic.
    #[serde(rename = "topic", alias = "kafka.topic")]
    pub topic: String,

    /// Timeout for synchronous calls, parsed from a string by
    /// `deserialize_duration_from_string`; defaults to the value of
    /// `default_kafka_sync_call_timeout` when the property is absent.
    #[serde(
        rename = "properties.sync.call.timeout",
        deserialize_with = "deserialize_duration_from_string",
        default = "default_kafka_sync_call_timeout"
    )]
    #[with_option(allow_alter_on_fly)]
    pub sync_call_timeout: Duration,
}
312
/// Private-link related Kafka properties.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
pub struct KafkaPrivateLinkCommon {
    /// This is generated from `private_link_targets` and `private_link_endpoint` in frontend, instead of given by users.
    ///
    /// The property value is a JSON-encoded string-to-string map (decoded through `JsonString`).
    #[serde(rename = "broker.rewrite.endpoints")]
    #[serde_as(as = "Option<JsonString>")]
    pub broker_rewrite_map: Option<BTreeMap<String, String>>,
}
321
/// Serde default for `properties.sync.call.timeout`: five seconds.
const fn default_kafka_sync_call_timeout() -> Duration {
    Duration::from_millis(5_000)
}

/// Serde default for `properties.socket.keepalive.enable`: keepalive is on.
const fn default_socket_keepalive_enable() -> bool {
    const KEEPALIVE_ON_BY_DEFAULT: bool = true;
    KEEPALIVE_ON_BY_DEFAULT
}
329
/// librdkafka client properties shared by Kafka connectors; applied to a client config by
/// `set_client`.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct RdKafkaPropertiesCommon {
    /// Maximum Kafka protocol request message size. Due to differing framing overhead between
    /// protocol versions the producer is unable to reliably enforce a strict max message limit at
    /// produce time and may exceed the maximum size by one message in protocol `ProduceRequests`,
    /// the broker will enforce the topic's max.message.bytes limit
    #[serde(rename = "properties.message.max.bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub message_max_bytes: Option<usize>,

    /// Maximum Kafka protocol response message size. This serves as a safety precaution to avoid
    /// memory exhaustion in case of protocol hiccups. This value must be at least fetch.max.bytes
    /// + 512 to allow for protocol overhead; the value is adjusted automatically unless the
    /// configuration property is explicitly set.
    #[serde(rename = "properties.receive.message.max.bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub receive_message_max_bytes: Option<usize>,

    /// Forwarded as librdkafka's `statistics.interval.ms` when set.
    #[serde(rename = "properties.statistics.interval.ms")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub statistics_interval_ms: Option<usize>,

    /// Client identifier
    #[serde(rename = "properties.client.id")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub client_id: Option<String>,

    /// Forwarded as librdkafka's `enable.ssl.certificate.verification` when set.
    #[serde(rename = "properties.enable.ssl.certificate.verification")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub enable_ssl_certificate_verification: Option<bool>,

    /// Forwarded as librdkafka's `socket.keepalive.enable`; always set, defaulting to the value
    /// of `default_socket_keepalive_enable` when the property is absent.
    #[serde(
        rename = "properties.socket.keepalive.enable",
        default = "default_socket_keepalive_enable"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub socket_keepalive_enable: bool,
}
374
375impl RdKafkaPropertiesCommon {
376    pub(crate) fn set_client(&self, c: &mut rdkafka::ClientConfig) {
377        if let Some(v) = self.statistics_interval_ms {
378            c.set("statistics.interval.ms", v.to_string());
379        }
380        if let Some(v) = self.message_max_bytes {
381            c.set("message.max.bytes", v.to_string());
382        }
383        if let Some(v) = self.receive_message_max_bytes {
384            c.set("receive.message.max.bytes", v.to_string());
385        }
386        if let Some(v) = self.client_id.as_ref() {
387            c.set("client.id", v);
388        }
389        if let Some(v) = self.enable_ssl_certificate_verification {
390            c.set("enable.ssl.certificate.verification", v.to_string());
391        }
392        c.set(
393            "socket.keepalive.enable",
394            self.socket_keepalive_enable.to_string(),
395        );
396    }
397}
398
399impl KafkaConnectionProps {
400    #[cfg(test)]
401    pub fn test_default() -> Self {
402        Self {
403            brokers: "localhost:9092".to_owned(),
404            security_protocol: None,
405            ssl_ca_location: None,
406            ssl_certificate_location: None,
407            ssl_key_location: None,
408            ssl_ca_pem: None,
409            ssl_certificate_pem: None,
410            ssl_key_pem: None,
411            ssl_key_password: None,
412            ssl_endpoint_identification_algorithm: None,
413            sasl_mechanism: None,
414            sasl_username: None,
415            sasl_password: None,
416            sasl_kerberos_service_name: None,
417            sasl_kerberos_keytab: None,
418            sasl_kerberos_principal: None,
419            sasl_kerberos_kinit_cmd: None,
420            sasl_kerberos_min_time_before_relogin: None,
421            sasl_oathbearer_config: None,
422        }
423    }
424
425    pub(crate) fn set_security_properties(&self, config: &mut ClientConfig) {
426        // AWS_MSK_IAM
427        if self.is_aws_msk_iam() {
428            config.set("security.protocol", "SASL_SSL");
429            config.set("sasl.mechanism", "OAUTHBEARER");
430            return;
431        }
432
433        // Security protocol
434        if let Some(security_protocol) = self.security_protocol.as_ref() {
435            config.set("security.protocol", security_protocol);
436        }
437
438        // SSL
439        if let Some(ssl_ca_location) = self.ssl_ca_location.as_ref() {
440            config.set("ssl.ca.location", ssl_ca_location);
441        }
442        if let Some(ssl_ca_pem) = self.ssl_ca_pem.as_ref() {
443            config.set("ssl.ca.pem", ssl_ca_pem);
444        }
445        if let Some(ssl_certificate_location) = self.ssl_certificate_location.as_ref() {
446            config.set("ssl.certificate.location", ssl_certificate_location);
447        }
448        if let Some(ssl_certificate_pem) = self.ssl_certificate_pem.as_ref() {
449            config.set("ssl.certificate.pem", ssl_certificate_pem);
450        }
451        if let Some(ssl_key_location) = self.ssl_key_location.as_ref() {
452            config.set("ssl.key.location", ssl_key_location);
453        }
454        if let Some(ssl_key_pem) = self.ssl_key_pem.as_ref() {
455            config.set("ssl.key.pem", ssl_key_pem);
456        }
457        if let Some(ssl_key_password) = self.ssl_key_password.as_ref() {
458            config.set("ssl.key.password", ssl_key_password);
459        }
460        if let Some(ssl_endpoint_identification_algorithm) =
461            self.ssl_endpoint_identification_algorithm.as_ref()
462        {
463            // accept only `none` and `http` here, let the sdk do the check
464            config.set(
465                "ssl.endpoint.identification.algorithm",
466                ssl_endpoint_identification_algorithm,
467            );
468        }
469
470        // SASL mechanism
471        if let Some(sasl_mechanism) = self.sasl_mechanism.as_ref() {
472            config.set("sasl.mechanism", sasl_mechanism);
473        }
474
475        // SASL/PLAIN & SASL/SCRAM
476        if let Some(sasl_username) = self.sasl_username.as_ref() {
477            config.set("sasl.username", sasl_username);
478        }
479        if let Some(sasl_password) = self.sasl_password.as_ref() {
480            config.set("sasl.password", sasl_password);
481        }
482
483        // SASL/GSSAPI
484        if let Some(sasl_kerberos_service_name) = self.sasl_kerberos_service_name.as_ref() {
485            config.set("sasl.kerberos.service.name", sasl_kerberos_service_name);
486        }
487        if let Some(sasl_kerberos_keytab) = self.sasl_kerberos_keytab.as_ref() {
488            config.set("sasl.kerberos.keytab", sasl_kerberos_keytab);
489        }
490        if let Some(sasl_kerberos_principal) = self.sasl_kerberos_principal.as_ref() {
491            config.set("sasl.kerberos.principal", sasl_kerberos_principal);
492        }
493        if let Some(sasl_kerberos_kinit_cmd) = self.sasl_kerberos_kinit_cmd.as_ref() {
494            config.set("sasl.kerberos.kinit.cmd", sasl_kerberos_kinit_cmd);
495        }
496        if let Some(sasl_kerberos_min_time_before_relogin) =
497            self.sasl_kerberos_min_time_before_relogin.as_ref()
498        {
499            config.set(
500                "sasl.kerberos.min.time.before.relogin",
501                sasl_kerberos_min_time_before_relogin,
502            );
503        }
504
505        // SASL/OAUTHBEARER
506        if let Some(sasl_oathbearer_config) = self.sasl_oathbearer_config.as_ref() {
507            config.set("sasl.oauthbearer.config", sasl_oathbearer_config);
508        }
509        // Currently, we only support unsecured OAUTH.
510        config.set("enable.sasl.oauthbearer.unsecure.jwt", "true");
511    }
512
513    pub(crate) fn is_aws_msk_iam(&self) -> bool {
514        if let Some(sasl_mechanism) = self.sasl_mechanism.as_ref()
515            && sasl_mechanism == AWS_MSK_IAM_AUTH
516        {
517            true
518        } else {
519            false
520        }
521    }
522}
523
/// Pulsar properties shared by connectors.
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct PulsarCommon {
    /// Pulsar topic.
    #[serde(rename = "topic", alias = "pulsar.topic")]
    pub topic: String,

    /// Service URL handed to the Pulsar client builder in `build_client`.
    #[serde(rename = "service.url", alias = "pulsar.service.url")]
    pub service_url: String,

    /// Token for token-based authentication; `build_client` only uses it when no OAuth
    /// configuration is supplied.
    #[serde(rename = "auth.token")]
    pub auth_token: Option<String>,
}
535
536impl EnforceSecret for PulsarCommon {
537    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
538        "pulsar.auth.token",
539    };
540}
541
/// OAuth2 (client credentials) settings for Pulsar authentication.
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct PulsarOauthCommon {
    /// Issuer URL passed to the OAuth2 parameters.
    #[serde(rename = "oauth.issuer.url")]
    pub issuer_url: String,

    /// Location of the credentials file: an `s3://` URL, a `file://` URL, or an absolute file
    /// path (see `PulsarCommon::resolve_pulsar_credentials_url`).
    #[serde(rename = "oauth.credentials.url")]
    pub credentials_url: String,

    /// Audience passed to the OAuth2 parameters.
    #[serde(rename = "oauth.audience")]
    pub audience: String,

    /// Optional scope passed to the OAuth2 parameters.
    #[serde(rename = "oauth.scope")]
    pub scope: Option<String>,
}
556
557fn create_credential_temp_file(credentials: &[u8]) -> std::io::Result<NamedTempFile> {
558    let mut f = NamedTempFile::new()?;
559    f.write_all(credentials)?;
560    f.as_file().sync_all()?;
561    Ok(f)
562}
563
564impl PulsarCommon {
565    pub(crate) async fn build_client(
566        &self,
567        oauth: &Option<PulsarOauthCommon>,
568        aws_auth_props: &AwsAuthProps,
569        operation_retry_options: Option<OperationRetryOptions>,
570    ) -> ConnectorResult<Pulsar<TokioExecutor>> {
571        let mut pulsar_builder = Pulsar::builder(&self.service_url, TokioExecutor);
572        let mut _temp_file = None; // Keep temp file alive
573
574        if let Some(oauth) = oauth.as_ref() {
575            let (credentials_url, temp_file) = self
576                .resolve_pulsar_credentials_url(oauth, aws_auth_props)
577                .await?;
578            _temp_file = temp_file;
579
580            let auth_params = OAuth2Params {
581                issuer_url: oauth.issuer_url.clone(),
582                credentials_url,
583                audience: Some(oauth.audience.clone()),
584                scope: oauth.scope.clone(),
585            };
586
587            pulsar_builder = pulsar_builder
588                .with_auth_provider(OAuth2Authentication::client_credentials(auth_params));
589        } else if let Some(auth_token) = &self.auth_token {
590            pulsar_builder = pulsar_builder.with_auth(Authentication {
591                name: "token".to_owned(),
592                data: Vec::from(auth_token.as_str()),
593            });
594        }
595
596        if let Some(operation_retry_options) = operation_retry_options {
597            tracing::info!(
598                max_retries = ?operation_retry_options.max_retries,
599                retry_delay_ms = operation_retry_options.retry_delay.as_millis(),
600                "applying Pulsar source operation retry override"
601            );
602            pulsar_builder = pulsar_builder.with_operation_retry_options(operation_retry_options);
603        }
604
605        let res = pulsar_builder.build().await.map_err(|e| anyhow!(e))?;
606        drop(_temp_file); // Explicitly drop temp file after client is built
607        Ok(res)
608    }
609
610    pub(crate) async fn resolve_pulsar_credentials_url(
611        &self,
612        oauth: &PulsarOauthCommon,
613        aws_auth_props: &AwsAuthProps,
614    ) -> ConnectorResult<(String, Option<NamedTempFile>)> {
615        // Try parsing as URL first
616        if let Ok(url) = Url::parse(&oauth.credentials_url) {
617            return self
618                .handle_pulsar_credentials_url(&url, aws_auth_props)
619                .await;
620        }
621
622        // If not a valid URL, check if it's an absolute file path
623        let path = Path::new(&oauth.credentials_url);
624        if !path.is_absolute() {
625            bail!("credentials_url must be a valid URL (s3://, file://) or an absolute file path");
626        }
627
628        // Verify the file exists
629        if !tokio::fs::try_exists(&oauth.credentials_url)
630            .await
631            .unwrap_or(false)
632        {
633            bail!("credentials file does not exist: {}", oauth.credentials_url);
634        }
635
636        // Return absolute path with file:// prefix
637        Ok((format!("file://{}", oauth.credentials_url), None))
638    }
639
640    pub(crate) async fn handle_pulsar_credentials_url(
641        &self,
642        url: &Url,
643        aws_auth_props: &AwsAuthProps,
644    ) -> ConnectorResult<(String, Option<NamedTempFile>)> {
645        match url.scheme() {
646            "s3" => {
647                let credentials = load_file_descriptor_from_s3(url, aws_auth_props).await?;
648                let temp_file = create_credential_temp_file(&credentials)
649                    .context("failed to create temp file for pulsar credentials")?;
650
651                let temp_path = temp_file
652                    .path()
653                    .to_str()
654                    .context("temp file path is not valid UTF-8")?;
655
656                Ok((format!("file://{}", temp_path), Some(temp_file)))
657            }
658            "file" => Ok((url.to_string(), None)),
659            _ => bail!(
660                "invalid credentials_url scheme '{}', only file://, s3://, and absolute file paths are supported",
661                url.scheme()
662            ),
663        }
664    }
665}
666
/// Kinesis properties shared by connectors: stream identity, credentials and SDK tuning.
///
/// `build_client` maps the credential fields onto `AwsAuthProps` and the `sdk_*` fields onto
/// the SDK's timeout and retry configuration.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct KinesisCommon {
    /// Name of the Kinesis stream.
    #[serde(rename = "stream", alias = "kinesis.stream.name")]
    pub stream_name: String,
    /// AWS region of the stream; used as the region of the built client.
    #[serde(rename = "aws.region", alias = "kinesis.stream.region")]
    pub stream_region: String,
    /// Optional custom endpoint URL for the client.
    #[serde(rename = "endpoint", alias = "kinesis.endpoint")]
    pub endpoint: Option<String>,
    /// Access key id for static credentials.
    #[serde(
        rename = "aws.credentials.access_key_id",
        alias = "kinesis.credentials.access"
    )]
    pub credentials_access_key: Option<String>,
    /// Secret access key for static credentials.
    #[serde(
        rename = "aws.credentials.secret_access_key",
        alias = "kinesis.credentials.secret"
    )]
    pub credentials_secret_access_key: Option<String>,
    /// Optional session token accompanying the static credentials.
    #[serde(
        rename = "aws.credentials.session_token",
        alias = "kinesis.credentials.session_token"
    )]
    pub session_token: Option<String>,
    /// ARN of the IAM role to assume, if any.
    #[serde(rename = "aws.credentials.role.arn", alias = "kinesis.assumerole.arn")]
    pub assume_role_arn: Option<String>,
    /// External ID used when assuming the role.
    #[serde(
        rename = "aws.credentials.role.external_id",
        alias = "kinesis.assumerole.external_id"
    )]
    pub assume_role_external_id: Option<String>,

    // sdk options
    /// Connect timeout in milliseconds.
    #[serde(
        rename = "kinesis.sdk.connect_timeout_ms",
        default = "kinesis_default_connect_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_connect_timeout_ms: u64,

    /// Read timeout in milliseconds.
    #[serde(
        rename = "kinesis.sdk.read_timeout_ms",
        default = "kinesis_default_read_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_read_timeout_ms: u64,

    /// Timeout of a whole operation in milliseconds.
    #[serde(
        rename = "kinesis.sdk.operation_timeout_ms",
        default = "kinesis_default_operation_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_operation_timeout_ms: u64,

    /// Timeout of a single operation attempt in milliseconds.
    #[serde(
        rename = "kinesis.sdk.operation_attempt_timeout_ms",
        default = "kinesis_default_operation_attempt_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_operation_attempt_timeout_ms: u64,

    /// Passed to the retry config as its maximum number of attempts.
    #[serde(
        rename = "kinesis.sdk.max_retry_limit",
        default = "kinesis_default_max_retry_limit"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_max_retry_limit: u32,

    /// Initial retry backoff in milliseconds.
    #[serde(
        rename = "kinesis.sdk.init_backoff_ms",
        default = "kinesis_default_init_backoff_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_init_backoff_ms: u64,

    /// Maximum retry backoff in milliseconds.
    #[serde(
        rename = "kinesis.sdk.max_backoff_ms",
        default = "kinesis_default_max_backoff_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_max_backoff_ms: u64,
}
749
750#[derive(Debug)]
751pub struct KinesisAsyncSleepImpl;
752
753impl AsyncSleep for KinesisAsyncSleepImpl {
754    fn sleep(&self, duration: Duration) -> Sleep {
755        Sleep::new(async move { tokio::time::sleep(duration).await })
756    }
757}
758
/// Serde default for `kinesis.sdk.connect_timeout_ms`.
const fn kinesis_default_connect_timeout_ms() -> u64 {
    10_000
}

/// Serde default for `kinesis.sdk.read_timeout_ms`.
const fn kinesis_default_read_timeout_ms() -> u64 {
    10_000
}

/// Serde default for `kinesis.sdk.operation_timeout_ms`.
const fn kinesis_default_operation_timeout_ms() -> u64 {
    10_000
}

/// Serde default for `kinesis.sdk.operation_attempt_timeout_ms`.
const fn kinesis_default_operation_attempt_timeout_ms() -> u64 {
    10_000
}

/// Serde default for `kinesis.sdk.init_backoff_ms`.
const fn kinesis_default_init_backoff_ms() -> u64 {
    1_000
}

/// Serde default for `kinesis.sdk.max_backoff_ms`.
const fn kinesis_default_max_backoff_ms() -> u64 {
    20_000
}

/// Serde default for `kinesis.sdk.max_retry_limit`.
const fn kinesis_default_max_retry_limit() -> u32 {
    3
}
786
787impl EnforceSecret for KinesisCommon {
788    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
789        "kinesis.credentials.access",
790        "kinesis.credentials.secret",
791        "kinesis.credentials.session_token",
792    };
793}
794
795impl KinesisCommon {
796    pub(crate) async fn build_client(&self) -> ConnectorResult<KinesisClient> {
797        let config = AwsAuthProps {
798            region: Some(self.stream_region.clone()),
799            endpoint: self.endpoint.clone(),
800            access_key: self.credentials_access_key.clone(),
801            secret_key: self.credentials_secret_access_key.clone(),
802            session_token: self.session_token.clone(),
803            arn: self.assume_role_arn.clone(),
804            external_id: self.assume_role_external_id.clone(),
805            profile: Default::default(),
806            msk_signer_timeout_sec: Default::default(),
807        };
808        let aws_config = config.build_config().await?;
809        let mut builder = aws_sdk_kinesis::config::Builder::from(&aws_config);
810        {
811            // for timeout and retry config
812            let sleep_impl = SharedAsyncSleep::new(KinesisAsyncSleepImpl);
813            builder.set_sleep_impl(Some(sleep_impl));
814            let timeout_config = aws_smithy_types::timeout::TimeoutConfig::builder()
815                .connect_timeout(Duration::from_millis(self.sdk_connect_timeout_ms))
816                .read_timeout(Duration::from_millis(self.sdk_read_timeout_ms))
817                .operation_timeout(Duration::from_millis(self.sdk_operation_timeout_ms))
818                .operation_attempt_timeout(Duration::from_millis(
819                    self.sdk_operation_attempt_timeout_ms,
820                ))
821                .build();
822            builder.set_timeout_config(Some(timeout_config));
823
824            let retry_config = aws_smithy_types::retry::RetryConfig::standard()
825                .with_initial_backoff(Duration::from_millis(self.sdk_init_backoff_ms))
826                .with_max_backoff(Duration::from_millis(self.sdk_max_backoff_ms))
827                .with_max_attempts(self.sdk_max_retry_limit);
828            builder.set_retry_config(Some(retry_config));
829        }
830        if let Some(endpoint) = &config.endpoint {
831            builder = builder.endpoint_url(endpoint);
832        }
833        Ok(KinesisClient::from_conf(builder.build()))
834    }
835}
836
/// Connection properties for NATS, used as a cache key for shared clients.
/// This includes all properties that affect the connection itself (not stream/subject specific).
///
/// Every field is a verbatim copy of the same-named field on `NatsCommon`
/// (see `NatsCommon::connection_props`), so two configurations share a cached client
/// only when all of these values are equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NatsConnectionProps {
    /// Server URL string exactly as configured (not split or normalized).
    pub server_url: String,
    /// Authentication mode string exactly as configured.
    pub connect_mode: String,
    /// User name, if configured.
    pub user: Option<String>,
    /// Password, if configured.
    pub password: Option<String>,
    /// User JWT, if configured.
    pub jwt: Option<String>,
    /// NKEY seed, if configured.
    pub nkey: Option<String>,
}
848
/// Shared NATS client cache.
/// Client connections are cached as `Weak` pointers in the cache.
/// NATS Connector can access this cache to reuse existing client connections,
/// and avoid exhausting host machine ports.
/// When reading from the cache, the connector should `upgrade` the weak pointer to an `Arc` reference.
/// After all strong (Arc) references are dropped, the client connection will be cleaned up.
/// Cache eviction naturally takes care of the dangling weak pointers.
///
/// NOTE(review): the cache is built with `MokaCache::builder().build()`, i.e. without any
/// capacity or expiry setting visible here. Confirm that dangling entries are in fact
/// evicted; otherwise they stay until the same key is rebuilt by `NatsCommon::build_client`.
pub static SHARED_NATS_CLIENT: LazyLock<MokaCache<NatsConnectionProps, Weak<async_nats::Client>>> =
    LazyLock::new(|| MokaCache::builder().build());
858
// Options shared by the NATS connectors. Connection-related fields also form the
// shared-client cache key (see `NatsCommon::connection_props`).
//
// Plain `//` comments are used on purpose: this struct derives `WithOptions`, and
// field doc comments may be picked up by generated option documentation.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct NatsCommon {
    // Comma-separated list of server addresses; each item is parsed as a server address
    // when the client is built.
    #[serde(rename = "server_url")]
    pub server_url: String,
    // Comma-separated list of subjects; used for the stream subjects and the consumer
    // filter subjects.
    #[serde(rename = "subject")]
    pub subject: String,
    // Authentication mode: `user_and_password`, `credential` or `plain`.
    #[serde(rename = "connect_mode")]
    pub connect_mode: String,
    // Required together with `password` when `connect_mode` is `user_and_password`.
    #[serde(rename = "username")]
    pub user: Option<String>,
    #[serde(rename = "password")]
    pub password: Option<String>,
    // Required together with `nkey` when `connect_mode` is `credential`.
    #[serde(rename = "jwt")]
    pub jwt: Option<String>,
    #[serde(rename = "nkey")]
    pub nkey: Option<String>,
    // The `max_*` options below are parsed from strings and only applied to the stream
    // config when a stream is created. `max_bytes` falls back to 1000000 when unset.
    #[serde(rename = "max_bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_bytes: Option<i64>,
    #[serde(rename = "max_messages")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_messages: Option<i64>,
    #[serde(rename = "max_messages_per_subject")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_messages_per_subject: Option<i64>,
    #[serde(rename = "max_consumers")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_consumers: Option<i32>,
    #[serde(rename = "max_message_size")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_message_size: Option<i32>,
    // Whether a missing stream may be created by `build_or_get_stream`; defaults to
    // `false` when the option is absent.
    #[serde(rename = "allow_create_stream", default)]
    #[serde_as(as = "DisplayFromStr")]
    pub allow_create_stream: bool,
}
895
/// Declares the NATS option keys that carry credentials.
///
/// NOTE(review): presumably these are the keys that must be supplied through secrets
/// rather than plain text; the exact enforcement lives in the `EnforceSecret` trait,
/// which is defined outside this file section.
impl EnforceSecret for NatsCommon {
    // Matches the serde names of the `password`, `jwt` and `nkey` fields of `NatsCommon`.
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "password",
        "jwt",
        "nkey",
    };
}
903
904impl NatsCommon {
905    /// Extract connection properties that can be used as a cache key.
906    pub fn connection_props(&self) -> NatsConnectionProps {
907        NatsConnectionProps {
908            server_url: self.server_url.clone(),
909            connect_mode: self.connect_mode.clone(),
910            user: self.user.clone(),
911            password: self.password.clone(),
912            jwt: self.jwt.clone(),
913            nkey: self.nkey.clone(),
914        }
915    }
916
917    /// Build a new NATS client without caching.
918    async fn build_client_inner(&self) -> ConnectorResult<async_nats::Client> {
919        let mut connect_options = async_nats::ConnectOptions::new();
920        match self.connect_mode.as_str() {
921            "user_and_password" => {
922                if let (Some(v_user), Some(v_password)) =
923                    (self.user.as_ref(), self.password.as_ref())
924                {
925                    connect_options =
926                        connect_options.user_and_password(v_user.into(), v_password.into())
927                } else {
928                    bail!("nats connect mode is user_and_password, but user or password is empty");
929                }
930            }
931
932            "credential" => {
933                if let (Some(v_nkey), Some(v_jwt)) = (self.nkey.as_ref(), self.jwt.as_ref()) {
934                    connect_options = connect_options
935                        .credentials(&self.create_credential(v_nkey, v_jwt)?)
936                        .expect("failed to parse static creds")
937                } else {
938                    bail!("nats connect mode is credential, but nkey or jwt is empty");
939                }
940            }
941            "plain" => {}
942            _ => {
943                bail!("nats connect mode only accepts user_and_password/credential/plain");
944            }
945        };
946
947        let servers = self.server_url.split(',').collect::<Vec<&str>>();
948        let client = connect_options
949            .connect(
950                servers
951                    .iter()
952                    .map(|url| url.parse())
953                    .collect::<Result<Vec<async_nats::ServerAddr>, _>>()?,
954            )
955            .await
956            .context("build nats client error")
957            .map_err(SinkError::Nats)?;
958        Ok(client)
959    }
960
    /// Build a NATS client, attempting to reuse an existing cached client if available.
    /// See `SHARED_NATS_CLIENT` for more details.
    ///
    /// The cache entry is keyed by [`Self::connection_props`]. A cached client is reused only
    /// if its weak pointer can still be upgraded and it reports the `Connected` state;
    /// otherwise a new client is built and a weak pointer to it replaces the entry.
    ///
    /// # Errors
    ///
    /// Returns an error if a new client has to be built and building it fails.
    pub(crate) async fn build_client(&self) -> ConnectorResult<Arc<async_nats::Client>> {
        let connection_props = self.connection_props();
        // Filled in by the closure below: the compute call only reports the cache operation,
        // so the strong reference is handed out through this captured variable.
        let mut client: Option<Arc<async_nats::Client>> = None;

        SHARED_NATS_CLIENT
            .entry_by_ref(&connection_props)
            .and_try_compute_with::<_, _, crate::error::ConnectorError>(|maybe_entry| async {
                if let Some(entry) = maybe_entry
                    && let entry_value = entry.into_value()
                    && let Some(existing_client) = entry_value.upgrade()
                {
                    match existing_client.connection_state() {
                        async_nats::connection::State::Connected => {
                            tracing::info!("reuse existing nats client for {}", self.server_url);
                            client = Some(existing_client);
                            // Keep the cached weak pointer as it is.
                            return Ok(Op::Nop);
                        }
                        _ => {
                            // Any other state falls through to building a fresh client.
                            tracing::warn!(
                                server_url = self.server_url,
                                "existing nats client is not connected",
                            );
                        }
                    }
                }
                // Reached when there is no entry, the weak pointer is dangling, or the cached
                // client is not connected.
                tracing::info!(
                    server_url = self.server_url,
                    "no cached client, or client disconnected, building new nats client"
                );
                let new_client = Arc::new(self.build_client_inner().await?);
                client = Some(new_client.clone());
                // Only a weak pointer is cached, so the cache itself does not keep the
                // connection alive.
                Ok(Op::Put(Arc::downgrade(&new_client)))
            })
            .await?;

        // Both successful closure paths assign `client` before returning.
        Ok(client.expect("client should be set"))
    }
1000
1001    pub(crate) async fn build_context(&self) -> ConnectorResult<jetstream::Context> {
1002        let client = self.build_client().await?;
1003        let jetstream = async_nats::jetstream::new((*client).clone());
1004        Ok(jetstream)
1005    }
1006
1007    /// Build a `JetStream` context using a pre-existing client.
1008    pub(crate) fn build_context_from_client(
1009        client: &Arc<async_nats::Client>,
1010    ) -> jetstream::Context {
1011        async_nats::jetstream::new((**client).clone())
1012    }
1013
1014    /// Build a NATS `JetStream` consumer.
1015    ///
1016    /// If `existing_client` is provided, it will be used instead of creating/fetching a new one.
1017    /// This allows callers to reuse a client they already hold.
1018    pub(crate) async fn build_consumer(
1019        &self,
1020        stream: String,
1021        durable_consumer_name: String,
1022        split_id: String,
1023        start_sequence: NatsOffset,
1024        mut config: jetstream::consumer::pull::Config,
1025        existing_client: Option<Arc<async_nats::Client>>,
1026    ) -> ConnectorResult<(
1027        async_nats::jetstream::consumer::Consumer<async_nats::jetstream::consumer::pull::Config>,
1028        Arc<async_nats::Client>,
1029    )> {
1030        let client = match existing_client {
1031            Some(c) => c,
1032            None => self.build_client().await?,
1033        };
1034        let context = Self::build_context_from_client(&client);
1035        let stream = self.build_or_get_stream(context.clone(), stream).await?;
1036        let subject_name = self
1037            .subject
1038            .replace(',', "-")
1039            .replace(['.', '>', '*', ' ', '\t'], "_");
1040        let name = format!("risingwave-consumer-{}-{}", subject_name, split_id);
1041
1042        let deliver_policy = match start_sequence {
1043            NatsOffset::Earliest => DeliverPolicy::All,
1044            NatsOffset::Latest => DeliverPolicy::New,
1045            NatsOffset::SequenceNumber(v) => {
1046                // for compatibility, we do not write to any state table now
1047                let parsed = v
1048                    .parse::<u64>()
1049                    .context("failed to parse nats offset as sequence number")?;
1050                DeliverPolicy::ByStartSequence {
1051                    start_sequence: 1 + parsed,
1052                }
1053            }
1054            NatsOffset::Timestamp(v) => DeliverPolicy::ByStartTime {
1055                start_time: OffsetDateTime::from_unix_timestamp_nanos(v as i128 * 1_000_000)
1056                    .context("invalid timestamp for nats offset")?,
1057            },
1058            NatsOffset::None => DeliverPolicy::All,
1059        };
1060
1061        let consumer = match stream.get_consumer(&name).await {
1062            Ok(consumer) => consumer,
1063            _ => {
1064                stream
1065                    .get_or_create_consumer(&name, {
1066                        config.deliver_policy = deliver_policy;
1067                        config.durable_name = Some(durable_consumer_name);
1068                        config.filter_subjects =
1069                            self.subject.split(',').map(|s| s.to_owned()).collect();
1070                        config
1071                    })
1072                    .await?
1073            }
1074        };
1075        Ok((consumer, client))
1076    }
1077
1078    pub(crate) async fn build_or_get_stream(
1079        &self,
1080        jetstream: jetstream::Context,
1081        stream_str: String,
1082    ) -> ConnectorResult<jetstream::stream::Stream> {
1083        let subjects: Vec<String> = self.subject.split(',').map(|s| s.to_owned()).collect();
1084
1085        // In `SourceEnumerator`, we may create a stream
1086        // In `SourceReader`, the desired stream MUST exist
1087        if let Ok(mut stream_instance) = jetstream.get_stream(&stream_str).await {
1088            tracing::info!(
1089                "load existing nats stream ({:?}) with config {:?}",
1090                stream_str,
1091                stream_instance.info().await?
1092            );
1093            return Ok(stream_instance);
1094        }
1095
1096        if !self.allow_create_stream {
1097            return Err(anyhow!(
1098                "stream {} not found, set `allow_create_stream` to true to create a stream",
1099                stream_str
1100            )
1101            .into());
1102        }
1103
1104        let mut config = jetstream::stream::Config {
1105            name: stream_str.clone(),
1106            max_bytes: 1000000,
1107            subjects,
1108            ..Default::default()
1109        };
1110        if let Some(v) = self.max_bytes {
1111            config.max_bytes = v;
1112        }
1113        if let Some(v) = self.max_messages {
1114            config.max_messages = v;
1115        }
1116        if let Some(v) = self.max_messages_per_subject {
1117            config.max_messages_per_subject = v;
1118        }
1119        if let Some(v) = self.max_consumers {
1120            config.max_consumers = v;
1121        }
1122        if let Some(v) = self.max_message_size {
1123            config.max_message_size = v;
1124        }
1125        tracing::info!(
1126            "create nats stream ({:?}) with config {:?}",
1127            &stream_str,
1128            config
1129        );
1130        let stream = jetstream.get_or_create_stream(config).await?;
1131        Ok(stream)
1132    }
1133
1134    pub(crate) fn create_credential(&self, seed: &str, jwt: &str) -> ConnectorResult<String> {
1135        let creds = format!(
1136            "-----BEGIN NATS USER JWT-----\n{}\n------END NATS USER JWT------\n\n\
1137                         ************************* IMPORTANT *************************\n\
1138                         NKEY Seed printed below can be used to sign and prove identity.\n\
1139                         NKEYs are sensitive and should be treated as secrets.\n\n\
1140                         -----BEGIN USER NKEY SEED-----\n{}\n------END USER NKEY SEED------\n\n\
1141                         *************************************************************",
1142            jwt, seed
1143        );
1144        Ok(creds)
1145    }
1146}
1147
1148pub(crate) fn load_certs(
1149    certificates: &str,
1150) -> ConnectorResult<Vec<rustls_pki_types::CertificateDer<'static>>> {
1151    let cert_bytes = if let Some(path) = certificates.strip_prefix("fs://") {
1152        std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())?
1153    } else {
1154        certificates.as_bytes().to_owned()
1155    };
1156
1157    CertificateDer::pem_slice_iter(&cert_bytes)
1158        .collect::<Result<Vec<_>, _>>()
1159        .context("Failed to parse certificates")
1160        .map_err(Into::into)
1161}
1162
1163pub(crate) fn load_private_key(
1164    certificate: &str,
1165) -> ConnectorResult<rustls_pki_types::PrivateKeyDer<'static>> {
1166    let cert_bytes = if let Some(path) = certificate.strip_prefix("fs://") {
1167        std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())?
1168    } else {
1169        certificate.as_bytes().to_owned()
1170    };
1171
1172    let cert = PrivatePkcs8KeyDer::pem_slice_iter(&cert_bytes)
1173        .next()
1174        .ok_or_else(|| anyhow!("No private key found"))?
1175        .context("Failed to parse private key")?;
1176    Ok(cert.into())
1177}
1178
// Options shared by the MongoDB connectors: the connection URL and the target collection.
// A plain `//` comment is used here on purpose: this struct derives `WithOptions`, and
// doc comments may be picked up by generated option documentation.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct MongodbCommon {
    /// The URL of `MongoDB`
    #[serde(rename = "mongodb.url")]
    pub connect_uri: String,
    /// The collection name where data should be written to or read from. For sinks, the format is
    /// `db_name.collection_name`. Data can also be written to dynamic collections, see `collection.name.field`
    /// for more information.
    #[serde(rename = "collection.name")]
    pub collection_name: String,
}
1191
/// Declares the MongoDB option keys that carry sensitive data.
///
/// NOTE(review): presumably these are the keys that must be supplied through secrets
/// rather than plain text; the exact enforcement lives in the `EnforceSecret` trait,
/// which is defined outside this file section.
impl EnforceSecret for MongodbCommon {
    // `mongodb.url` is the serde name of `MongodbCommon::connect_uri`.
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "mongodb.url"
    };
}
1197
1198impl MongodbCommon {
1199    pub(crate) async fn build_client(&self) -> ConnectorResult<mongodb::Client> {
1200        let client = mongodb::Client::with_uri_str(&self.connect_uri).await?;
1201
1202        Ok(client)
1203    }
1204}