risingwave_connector/connector_common/
common.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::hash::Hash;
17use std::io::Write;
18use std::path::Path;
19use std::time::Duration;
20
21use anyhow::{Context, anyhow};
22use async_nats::jetstream::consumer::DeliverPolicy;
23use async_nats::jetstream::{self};
24use aws_sdk_kinesis::Client as KinesisClient;
25use aws_sdk_kinesis::config::{AsyncSleep, SharedAsyncSleep, Sleep};
26use phf::{Set, phf_set};
27use pulsar::authentication::oauth2::{OAuth2Authentication, OAuth2Params};
28use pulsar::{Authentication, Pulsar, TokioExecutor};
29use rdkafka::ClientConfig;
30use risingwave_common::bail;
31use serde_derive::Deserialize;
32use serde_with::json::JsonString;
33use serde_with::{DisplayFromStr, serde_as};
34use tempfile::NamedTempFile;
35use time::OffsetDateTime;
36use url::Url;
37use with_options::WithOptions;
38
39use crate::aws_utils::load_file_descriptor_from_s3;
40use crate::deserialize_duration_from_string;
41use crate::enforce_secret::EnforceSecret;
42use crate::error::ConnectorResult;
43use crate::sink::SinkError;
44use crate::source::nats::source::NatsOffset;
45
/// Property key holding the private-link broker rewrite map. It is the same key that
/// `KafkaPrivateLinkCommon::broker_rewrite_map` is deserialized from (a JSON-encoded map).
pub const PRIVATE_LINK_BROKER_REWRITE_MAP_KEY: &str = "broker.rewrite.endpoints";
/// Property key under which private-link targets are supplied.
pub const PRIVATE_LINK_TARGETS_KEY: &str = "privatelink.targets";

/// Value of `properties.sasl.mechanism` that selects AWS MSK IAM authentication
/// (compared case-sensitively by `KafkaConnectionProps::is_aws_msk_iam`).
const AWS_MSK_IAM_AUTH: &str = "AWS_MSK_IAM";

/// The environment variable to disable using default credential from environment.
/// It's recommended to set this variable to `true` in cloud hosting environment.
///
/// When it is `true`, `AwsAuthProps` will not fall back to the SDK default credential chain
/// and instead requires both an access key and a secret key to be configured.
pub const DISABLE_DEFAULT_CREDENTIAL: &str = "DISABLE_DEFAULT_CREDENTIAL";
54
/// A single AWS PrivateLink target: an optional availability-zone id plus a port.
///
/// NOTE(review): presumably parsed from the `privatelink.targets` property; confirm at the
/// call site.
#[derive(Debug, Clone, Deserialize)]
pub struct AwsPrivateLinkItem {
    /// AWS availability-zone id of the target, if pinned to one.
    pub az_id: Option<String>,
    /// Port of the target endpoint.
    pub port: u16,
}
60
61use aws_config::default_provider::region::DefaultRegionChain;
62use aws_config::sts::AssumeRoleProvider;
63use aws_credential_types::provider::SharedCredentialsProvider;
64use aws_types::SdkConfig;
65use aws_types::region::Region;
66use risingwave_common::util::env_var::env_var_is_true;
67
/// A flatten config map for aws auth.
///
/// Every field is optional and accepts its canonical `aws.*` key plus the listed legacy
/// aliases. [`AwsAuthProps::build_config`] turns these values into an SDK config, filling
/// gaps from the AWS default provider chains.
#[derive(Deserialize, Debug, Clone, WithOptions, PartialEq)]
pub struct AwsAuthProps {
    /// AWS region. When unset, the SDK default region chain is consulted
    /// (restricted to `profile` if that is set).
    #[serde(rename = "aws.region", alias = "region", alias = "s3.region")]
    pub region: Option<String>,

    /// Custom endpoint URL applied to the generated SDK config.
    #[serde(
        rename = "aws.endpoint_url",
        alias = "endpoint_url",
        alias = "endpoint",
        alias = "s3.endpoint"
    )]
    pub endpoint: Option<String>,
    /// Static access key id; only used when `secret_key` is also set.
    #[serde(
        rename = "aws.credentials.access_key_id",
        alias = "access_key",
        alias = "s3.access.key"
    )]
    pub access_key: Option<String>,
    /// Static secret access key; only used when `access_key` is also set.
    #[serde(
        rename = "aws.credentials.secret_access_key",
        alias = "secret_key",
        alias = "s3.secret.key"
    )]
    pub secret_key: Option<String>,
    /// Optional session token sent together with the static access/secret key pair.
    #[serde(rename = "aws.credentials.session_token", alias = "session_token")]
    pub session_token: Option<String>,
    /// IAM role
    ///
    /// When set, this role is assumed on top of the base credentials.
    #[serde(rename = "aws.credentials.role.arn", alias = "arn")]
    pub arn: Option<String>,
    /// external ID in IAM role trust policy
    #[serde(rename = "aws.credentials.role.external_id", alias = "external_id")]
    pub external_id: Option<String>,
    /// Named AWS profile. Within this file it is only used to scope the default region chain.
    #[serde(rename = "aws.profile", alias = "profile")]
    pub profile: Option<String>,
    /// NOTE(review): presumably the MSK IAM signer timeout in seconds; it is not read in this
    /// part of the file, so confirm against its consumer.
    #[serde(rename = "aws.msk.signer_timeout_sec")]
    pub msk_signer_timeout_sec: Option<u64>,
}
106
107impl EnforceSecret for AwsAuthProps {
108    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
109        "access_key",
110        "aws.credentials.access_key_id",
111        "s3.access.key",
112        "secret_key",
113        "aws.credentials.secret_access_key",
114        "s3.secret.key",
115        "session_token",
116        "aws.credentials.session_token",
117    };
118}
119
120impl AwsAuthProps {
121    async fn build_region(&self) -> ConnectorResult<Region> {
122        if let Some(region_name) = &self.region {
123            Ok(Region::new(region_name.clone()))
124        } else {
125            let mut region_chain = DefaultRegionChain::builder();
126            if let Some(profile_name) = &self.profile {
127                region_chain = region_chain.profile_name(profile_name);
128            }
129
130            Ok(region_chain
131                .build()
132                .region()
133                .await
134                .context("region should be provided")?)
135        }
136    }
137
138    async fn build_credential_provider(&self) -> ConnectorResult<SharedCredentialsProvider> {
139        if self.access_key.is_some() && self.secret_key.is_some() {
140            Ok(SharedCredentialsProvider::new(
141                aws_credential_types::Credentials::from_keys(
142                    self.access_key.as_ref().unwrap(),
143                    self.secret_key.as_ref().unwrap(),
144                    self.session_token.clone(),
145                ),
146            ))
147        } else if !env_var_is_true(DISABLE_DEFAULT_CREDENTIAL) {
148            Ok(SharedCredentialsProvider::new(
149                aws_config::default_provider::credentials::default_provider().await,
150            ))
151        } else {
152            bail!("Both \"access_key\" and \"secret_key\" are required.")
153        }
154    }
155
156    async fn with_role_provider(
157        &self,
158        credential: SharedCredentialsProvider,
159    ) -> ConnectorResult<SharedCredentialsProvider> {
160        if let Some(role_name) = &self.arn {
161            let region = self.build_region().await?;
162            let mut role = AssumeRoleProvider::builder(role_name)
163                .session_name("RisingWave")
164                .region(region);
165            if let Some(id) = &self.external_id {
166                role = role.external_id(id);
167            }
168            let provider = role.build_from_provider(credential).await;
169            Ok(SharedCredentialsProvider::new(provider))
170        } else {
171            Ok(credential)
172        }
173    }
174
175    pub async fn build_config(&self) -> ConnectorResult<SdkConfig> {
176        let region = self.build_region().await?;
177        let credentials_provider = self
178            .with_role_provider(self.build_credential_provider().await?)
179            .await?;
180        let mut config_loader = aws_config::from_env()
181            .region(region)
182            .credentials_provider(credentials_provider);
183
184        if let Some(endpoint) = self.endpoint.as_ref() {
185            config_loader = config_loader.endpoint_url(endpoint);
186        }
187
188        Ok(config_loader.load().await)
189    }
190}
191
/// Kafka connection and security properties.
///
/// `KafkaConnectionProps::set_security_properties` copies the security-related options into
/// an rdkafka `ClientConfig`.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
pub struct KafkaConnectionProps {
    /// Kafka bootstrap broker address(es) (legacy alias `kafka.brokers`).
    #[serde(rename = "properties.bootstrap.server", alias = "kafka.brokers")]
    pub brokers: String,

    /// Security protocol used for RisingWave to communicate with Kafka brokers. Could be
    /// PLAINTEXT, SSL, `SASL_PLAINTEXT` or `SASL_SSL`.
    #[serde(rename = "properties.security.protocol")]
    #[with_option(allow_alter_on_fly)]
    security_protocol: Option<String>,

    /// Endpoint identification algorithm passed through to librdkafka unchecked
    /// (librdkafka accepts `none` or `https`).
    #[serde(rename = "properties.ssl.endpoint.identification.algorithm")]
    #[with_option(allow_alter_on_fly)]
    ssl_endpoint_identification_algorithm: Option<String>,

    // For the properties below, please refer to [librdkafka](https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md) for more information.
    /// Path to CA certificate file for verifying the broker's key.
    #[serde(rename = "properties.ssl.ca.location")]
    ssl_ca_location: Option<String>,

    /// CA certificate string (PEM format) for verifying the broker's key.
    #[serde(rename = "properties.ssl.ca.pem")]
    ssl_ca_pem: Option<String>,

    /// Path to client's certificate file (PEM).
    #[serde(rename = "properties.ssl.certificate.location")]
    ssl_certificate_location: Option<String>,

    /// Client's public key string (PEM format) used for authentication.
    #[serde(rename = "properties.ssl.certificate.pem")]
    ssl_certificate_pem: Option<String>,

    /// Path to client's private key file (PEM).
    #[serde(rename = "properties.ssl.key.location")]
    ssl_key_location: Option<String>,

    /// Client's private key string (PEM format) used for authentication.
    #[serde(rename = "properties.ssl.key.pem")]
    ssl_key_pem: Option<String>,

    /// Passphrase of client's private key.
    #[serde(rename = "properties.ssl.key.password")]
    ssl_key_password: Option<String>,

    /// SASL mechanism if SASL is enabled. Currently support PLAIN, SCRAM, GSSAPI, and `AWS_MSK_IAM`.
    #[serde(rename = "properties.sasl.mechanism")]
    #[with_option(allow_alter_on_fly)]
    sasl_mechanism: Option<String>,

    /// SASL username for SASL/PLAIN and SASL/SCRAM.
    #[serde(rename = "properties.sasl.username")]
    #[with_option(allow_alter_on_fly)]
    sasl_username: Option<String>,

    /// SASL password for SASL/PLAIN and SASL/SCRAM.
    #[serde(rename = "properties.sasl.password")]
    #[with_option(allow_alter_on_fly)]
    sasl_password: Option<String>,

    /// Kafka server's Kerberos principal name under SASL/GSSAPI, not including /hostname@REALM.
    #[serde(rename = "properties.sasl.kerberos.service.name")]
    sasl_kerberos_service_name: Option<String>,

    /// Path to client's Kerberos keytab file under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.keytab")]
    sasl_kerberos_keytab: Option<String>,

    /// Client's Kerberos principal name under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.principal")]
    sasl_kerberos_principal: Option<String>,

    /// Shell command to refresh or acquire the client's Kerberos ticket under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.kinit.cmd")]
    sasl_kerberos_kinit_cmd: Option<String>,

    /// Minimum time in milliseconds between key refresh attempts under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.min.time.before.relogin")]
    sasl_kerberos_min_time_before_relogin: Option<String>,

    /// Configurations for SASL/OAUTHBEARER.
    // NOTE: the field name misspells "oauthbearer"; it is kept as-is because other code in
    // this file refers to it by this name.
    #[serde(rename = "properties.sasl.oauthbearer.config")]
    sasl_oathbearer_config: Option<String>,
}
276
277impl EnforceSecret for KafkaConnectionProps {
278    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
279        "properties.ssl.key.pem",
280        "properties.ssl.key.password",
281        "properties.sasl.password",
282    };
283}
284
/// Kafka topic-level options.
///
/// Connection-level settings live in `KafkaConnectionProps`.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct KafkaCommon {
    // connection related props are moved to `KafkaConnection`
    /// Kafka topic name (legacy alias `kafka.topic`).
    #[serde(rename = "topic", alias = "kafka.topic")]
    pub topic: String,

    /// Timeout for synchronous Kafka calls, parsed by `deserialize_duration_from_string`.
    /// Defaults to 5 seconds (`default_kafka_sync_call_timeout`).
    #[serde(
        rename = "properties.sync.call.timeout",
        deserialize_with = "deserialize_duration_from_string",
        default = "default_kafka_sync_call_timeout"
    )]
    #[with_option(allow_alter_on_fly)]
    pub sync_call_timeout: Duration,
}
300
/// Private-link broker-address rewriting for Kafka.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
pub struct KafkaPrivateLinkCommon {
    /// This is generated from `private_link_targets` and `private_link_endpoint` in frontend, instead of given by users.
    ///
    /// The property value is a JSON string that is decoded into a map.
    // NOTE(review): presumably maps original broker address -> private-link endpoint; confirm.
    #[serde(rename = "broker.rewrite.endpoints")]
    #[serde_as(as = "Option<JsonString>")]
    pub broker_rewrite_map: Option<BTreeMap<String, String>>,
}
309
/// Serde default for `properties.sync.call.timeout`: five seconds.
const fn default_kafka_sync_call_timeout() -> Duration {
    Duration::from_millis(5_000)
}

/// Serde default for `properties.socket.keepalive.enable`: keepalive is on.
const fn default_socket_keepalive_enable() -> bool {
    true
}
317
/// librdkafka client tuning options shared by Kafka sources and sinks.
///
/// Numeric and boolean values are parsed from their string form (`DisplayFromStr`).
/// `RdKafkaPropertiesCommon::set_client` copies them into an rdkafka `ClientConfig`.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct RdKafkaPropertiesCommon {
    /// Maximum Kafka protocol request message size. Due to differing framing overhead between
    /// protocol versions the producer is unable to reliably enforce a strict max message limit at
    /// produce time and may exceed the maximum size by one message in protocol `ProduceRequests`,
    /// the broker will enforce the topic's max.message.bytes limit
    #[serde(rename = "properties.message.max.bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub message_max_bytes: Option<usize>,

    /// Maximum Kafka protocol response message size. This serves as a safety precaution to avoid
    /// memory exhaustion in case of protocol hiccups. This value must be at least
    /// `fetch.max.bytes` + 512 to allow for protocol overhead; the value is adjusted
    /// automatically unless the configuration property is explicitly set.
    #[serde(rename = "properties.receive.message.max.bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub receive_message_max_bytes: Option<usize>,

    /// librdkafka statistics emit interval in milliseconds (`statistics.interval.ms`).
    #[serde(rename = "properties.statistics.interval.ms")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub statistics_interval_ms: Option<usize>,

    /// Client identifier
    #[serde(rename = "properties.client.id")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub client_id: Option<String>,

    /// Whether librdkafka verifies the broker's SSL certificate
    /// (`enable.ssl.certificate.verification`).
    #[serde(rename = "properties.enable.ssl.certificate.verification")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub enable_ssl_certificate_verification: Option<bool>,

    /// TCP keepalive on broker sockets (`socket.keepalive.enable`). Defaults to `true`
    /// (`default_socket_keepalive_enable`).
    #[serde(
        rename = "properties.socket.keepalive.enable",
        default = "default_socket_keepalive_enable"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub socket_keepalive_enable: bool,
}
362
363impl RdKafkaPropertiesCommon {
364    pub(crate) fn set_client(&self, c: &mut rdkafka::ClientConfig) {
365        if let Some(v) = self.statistics_interval_ms {
366            c.set("statistics.interval.ms", v.to_string());
367        }
368        if let Some(v) = self.message_max_bytes {
369            c.set("message.max.bytes", v.to_string());
370        }
371        if let Some(v) = self.receive_message_max_bytes {
372            c.set("receive.message.max.bytes", v.to_string());
373        }
374        if let Some(v) = self.client_id.as_ref() {
375            c.set("client.id", v);
376        }
377        if let Some(v) = self.enable_ssl_certificate_verification {
378            c.set("enable.ssl.certificate.verification", v.to_string());
379        }
380        c.set(
381            "socket.keepalive.enable",
382            self.socket_keepalive_enable.to_string(),
383        );
384    }
385}
386
387impl KafkaConnectionProps {
388    #[cfg(test)]
389    pub fn test_default() -> Self {
390        Self {
391            brokers: "localhost:9092".to_owned(),
392            security_protocol: None,
393            ssl_ca_location: None,
394            ssl_certificate_location: None,
395            ssl_key_location: None,
396            ssl_ca_pem: None,
397            ssl_certificate_pem: None,
398            ssl_key_pem: None,
399            ssl_key_password: None,
400            ssl_endpoint_identification_algorithm: None,
401            sasl_mechanism: None,
402            sasl_username: None,
403            sasl_password: None,
404            sasl_kerberos_service_name: None,
405            sasl_kerberos_keytab: None,
406            sasl_kerberos_principal: None,
407            sasl_kerberos_kinit_cmd: None,
408            sasl_kerberos_min_time_before_relogin: None,
409            sasl_oathbearer_config: None,
410        }
411    }
412
413    pub(crate) fn set_security_properties(&self, config: &mut ClientConfig) {
414        // AWS_MSK_IAM
415        if self.is_aws_msk_iam() {
416            config.set("security.protocol", "SASL_SSL");
417            config.set("sasl.mechanism", "OAUTHBEARER");
418            return;
419        }
420
421        // Security protocol
422        if let Some(security_protocol) = self.security_protocol.as_ref() {
423            config.set("security.protocol", security_protocol);
424        }
425
426        // SSL
427        if let Some(ssl_ca_location) = self.ssl_ca_location.as_ref() {
428            config.set("ssl.ca.location", ssl_ca_location);
429        }
430        if let Some(ssl_ca_pem) = self.ssl_ca_pem.as_ref() {
431            config.set("ssl.ca.pem", ssl_ca_pem);
432        }
433        if let Some(ssl_certificate_location) = self.ssl_certificate_location.as_ref() {
434            config.set("ssl.certificate.location", ssl_certificate_location);
435        }
436        if let Some(ssl_certificate_pem) = self.ssl_certificate_pem.as_ref() {
437            config.set("ssl.certificate.pem", ssl_certificate_pem);
438        }
439        if let Some(ssl_key_location) = self.ssl_key_location.as_ref() {
440            config.set("ssl.key.location", ssl_key_location);
441        }
442        if let Some(ssl_key_pem) = self.ssl_key_pem.as_ref() {
443            config.set("ssl.key.pem", ssl_key_pem);
444        }
445        if let Some(ssl_key_password) = self.ssl_key_password.as_ref() {
446            config.set("ssl.key.password", ssl_key_password);
447        }
448        if let Some(ssl_endpoint_identification_algorithm) =
449            self.ssl_endpoint_identification_algorithm.as_ref()
450        {
451            // accept only `none` and `http` here, let the sdk do the check
452            config.set(
453                "ssl.endpoint.identification.algorithm",
454                ssl_endpoint_identification_algorithm,
455            );
456        }
457
458        // SASL mechanism
459        if let Some(sasl_mechanism) = self.sasl_mechanism.as_ref() {
460            config.set("sasl.mechanism", sasl_mechanism);
461        }
462
463        // SASL/PLAIN & SASL/SCRAM
464        if let Some(sasl_username) = self.sasl_username.as_ref() {
465            config.set("sasl.username", sasl_username);
466        }
467        if let Some(sasl_password) = self.sasl_password.as_ref() {
468            config.set("sasl.password", sasl_password);
469        }
470
471        // SASL/GSSAPI
472        if let Some(sasl_kerberos_service_name) = self.sasl_kerberos_service_name.as_ref() {
473            config.set("sasl.kerberos.service.name", sasl_kerberos_service_name);
474        }
475        if let Some(sasl_kerberos_keytab) = self.sasl_kerberos_keytab.as_ref() {
476            config.set("sasl.kerberos.keytab", sasl_kerberos_keytab);
477        }
478        if let Some(sasl_kerberos_principal) = self.sasl_kerberos_principal.as_ref() {
479            config.set("sasl.kerberos.principal", sasl_kerberos_principal);
480        }
481        if let Some(sasl_kerberos_kinit_cmd) = self.sasl_kerberos_kinit_cmd.as_ref() {
482            config.set("sasl.kerberos.kinit.cmd", sasl_kerberos_kinit_cmd);
483        }
484        if let Some(sasl_kerberos_min_time_before_relogin) =
485            self.sasl_kerberos_min_time_before_relogin.as_ref()
486        {
487            config.set(
488                "sasl.kerberos.min.time.before.relogin",
489                sasl_kerberos_min_time_before_relogin,
490            );
491        }
492
493        // SASL/OAUTHBEARER
494        if let Some(sasl_oathbearer_config) = self.sasl_oathbearer_config.as_ref() {
495            config.set("sasl.oauthbearer.config", sasl_oathbearer_config);
496        }
497        // Currently, we only support unsecured OAUTH.
498        config.set("enable.sasl.oauthbearer.unsecure.jwt", "true");
499    }
500
501    pub(crate) fn is_aws_msk_iam(&self) -> bool {
502        if let Some(sasl_mechanism) = self.sasl_mechanism.as_ref()
503            && sasl_mechanism == AWS_MSK_IAM_AUTH
504        {
505            true
506        } else {
507            false
508        }
509    }
510}
511
/// Common Pulsar connection properties.
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct PulsarCommon {
    /// Pulsar topic (legacy alias `pulsar.topic`).
    #[serde(rename = "topic", alias = "pulsar.topic")]
    pub topic: String,

    /// Pulsar service URL passed to `Pulsar::builder` (legacy alias `pulsar.service.url`).
    #[serde(rename = "service.url", alias = "pulsar.service.url")]
    pub service_url: String,

    /// Token for token authentication. `PulsarCommon::build_client` only uses it when no
    /// OAuth configuration is supplied.
    #[serde(rename = "auth.token")]
    pub auth_token: Option<String>,
}
523
524impl EnforceSecret for PulsarCommon {
525    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
526        "pulsar.auth.token",
527    };
528}
529
/// OAuth2 client-credentials settings for Pulsar authentication.
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct PulsarOauthCommon {
    /// OAuth2 issuer URL.
    #[serde(rename = "oauth.issuer.url")]
    pub issuer_url: String,

    /// Location of the credentials file: a `file://` or `s3://` URL, or an absolute local path
    /// (see `PulsarCommon::resolve_pulsar_credentials_url`).
    #[serde(rename = "oauth.credentials.url")]
    pub credentials_url: String,

    /// OAuth2 audience.
    #[serde(rename = "oauth.audience")]
    pub audience: String,

    /// Optional OAuth2 scope.
    #[serde(rename = "oauth.scope")]
    pub scope: Option<String>,
}
544
545fn create_credential_temp_file(credentials: &[u8]) -> std::io::Result<NamedTempFile> {
546    let mut f = NamedTempFile::new()?;
547    f.write_all(credentials)?;
548    f.as_file().sync_all()?;
549    Ok(f)
550}
551
552impl PulsarCommon {
553    pub(crate) async fn build_client(
554        &self,
555        oauth: &Option<PulsarOauthCommon>,
556        aws_auth_props: &AwsAuthProps,
557    ) -> ConnectorResult<Pulsar<TokioExecutor>> {
558        let mut pulsar_builder = Pulsar::builder(&self.service_url, TokioExecutor);
559        let mut _temp_file = None; // Keep temp file alive
560
561        if let Some(oauth) = oauth.as_ref() {
562            let (credentials_url, temp_file) = self
563                .resolve_pulsar_credentials_url(oauth, aws_auth_props)
564                .await?;
565            _temp_file = temp_file;
566
567            let auth_params = OAuth2Params {
568                issuer_url: oauth.issuer_url.clone(),
569                credentials_url,
570                audience: Some(oauth.audience.clone()),
571                scope: oauth.scope.clone(),
572            };
573
574            pulsar_builder = pulsar_builder
575                .with_auth_provider(OAuth2Authentication::client_credentials(auth_params));
576        } else if let Some(auth_token) = &self.auth_token {
577            pulsar_builder = pulsar_builder.with_auth(Authentication {
578                name: "token".to_owned(),
579                data: Vec::from(auth_token.as_str()),
580            });
581        }
582
583        let res = pulsar_builder.build().await.map_err(|e| anyhow!(e))?;
584        drop(_temp_file); // Explicitly drop temp file after client is built
585        Ok(res)
586    }
587
588    pub(crate) async fn resolve_pulsar_credentials_url(
589        &self,
590        oauth: &PulsarOauthCommon,
591        aws_auth_props: &AwsAuthProps,
592    ) -> ConnectorResult<(String, Option<NamedTempFile>)> {
593        // Try parsing as URL first
594        if let Ok(url) = Url::parse(&oauth.credentials_url) {
595            return self
596                .handle_pulsar_credentials_url(&url, aws_auth_props)
597                .await;
598        }
599
600        // If not a valid URL, check if it's an absolute file path
601        let path = Path::new(&oauth.credentials_url);
602        if !path.is_absolute() {
603            bail!("credentials_url must be a valid URL (s3://, file://) or an absolute file path");
604        }
605
606        // Verify the file exists
607        if !tokio::fs::try_exists(&oauth.credentials_url)
608            .await
609            .unwrap_or(false)
610        {
611            bail!("credentials file does not exist: {}", oauth.credentials_url);
612        }
613
614        // Return absolute path with file:// prefix
615        Ok((format!("file://{}", oauth.credentials_url), None))
616    }
617
618    pub(crate) async fn handle_pulsar_credentials_url(
619        &self,
620        url: &Url,
621        aws_auth_props: &AwsAuthProps,
622    ) -> ConnectorResult<(String, Option<NamedTempFile>)> {
623        match url.scheme() {
624            "s3" => {
625                let credentials = load_file_descriptor_from_s3(url, aws_auth_props).await?;
626                let temp_file = create_credential_temp_file(&credentials)
627                    .context("failed to create temp file for pulsar credentials")?;
628
629                let temp_path = temp_file
630                    .path()
631                    .to_str()
632                    .context("temp file path is not valid UTF-8")?;
633
634                Ok((format!("file://{}", temp_path), Some(temp_file)))
635            }
636            "file" => Ok((url.to_string(), None)),
637            _ => bail!(
638                "invalid credentials_url scheme '{}', only file://, s3://, and absolute file paths are supported",
639                url.scheme()
640            ),
641        }
642    }
643}
644
/// Common Kinesis connection properties plus AWS SDK timeout/retry tuning.
///
/// `KinesisCommon::build_client` maps the credential fields onto `AwsAuthProps`.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct KinesisCommon {
    /// Kinesis stream name (legacy alias `kinesis.stream.name`).
    #[serde(rename = "stream", alias = "kinesis.stream.name")]
    pub stream_name: String,
    /// AWS region of the stream; required (legacy alias `kinesis.stream.region`).
    #[serde(rename = "aws.region", alias = "kinesis.stream.region")]
    pub stream_region: String,
    /// Optional custom Kinesis endpoint URL (legacy alias `kinesis.endpoint`).
    #[serde(rename = "endpoint", alias = "kinesis.endpoint")]
    pub endpoint: Option<String>,
    /// Static access key id (legacy alias `kinesis.credentials.access`).
    #[serde(
        rename = "aws.credentials.access_key_id",
        alias = "kinesis.credentials.access"
    )]
    pub credentials_access_key: Option<String>,
    /// Static secret access key (legacy alias `kinesis.credentials.secret`).
    #[serde(
        rename = "aws.credentials.secret_access_key",
        alias = "kinesis.credentials.secret"
    )]
    pub credentials_secret_access_key: Option<String>,
    /// Optional session token for the static credentials.
    #[serde(
        rename = "aws.credentials.session_token",
        alias = "kinesis.credentials.session_token"
    )]
    pub session_token: Option<String>,
    /// ARN of an IAM role to assume (legacy alias `kinesis.assumerole.arn`).
    #[serde(rename = "aws.credentials.role.arn", alias = "kinesis.assumerole.arn")]
    pub assume_role_arn: Option<String>,
    /// External id for the assumed role's trust policy.
    #[serde(
        rename = "aws.credentials.role.external_id",
        alias = "kinesis.assumerole.external_id"
    )]
    pub assume_role_external_id: Option<String>,

    // sdk options
    /// SDK connect timeout in milliseconds (default 10000).
    #[serde(
        rename = "kinesis.sdk.connect_timeout_ms",
        default = "kinesis_default_connect_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_connect_timeout_ms: u64,

    /// SDK read timeout in milliseconds (default 10000).
    #[serde(
        rename = "kinesis.sdk.read_timeout_ms",
        default = "kinesis_default_read_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_read_timeout_ms: u64,

    /// SDK whole-operation timeout in milliseconds (default 10000).
    #[serde(
        rename = "kinesis.sdk.operation_timeout_ms",
        default = "kinesis_default_operation_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_operation_timeout_ms: u64,

    /// SDK per-attempt timeout in milliseconds (default 10000).
    #[serde(
        rename = "kinesis.sdk.operation_attempt_timeout_ms",
        default = "kinesis_default_operation_attempt_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_operation_attempt_timeout_ms: u64,

    /// Passed to the SDK retry config as `max_attempts` (default 3).
    // NOTE(review): the AWS SDK's max attempts includes the initial request, so this is not a
    // pure retry count.
    #[serde(
        rename = "kinesis.sdk.max_retry_limit",
        default = "kinesis_default_max_retry_limit"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_max_retry_limit: u32,

    /// Initial retry backoff in milliseconds (default 1000).
    #[serde(
        rename = "kinesis.sdk.init_backoff_ms",
        default = "kinesis_default_init_backoff_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_init_backoff_ms: u64,

    /// Maximum retry backoff in milliseconds (default 20000).
    #[serde(
        rename = "kinesis.sdk.max_backoff_ms",
        default = "kinesis_default_max_backoff_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_max_backoff_ms: u64,
}
727
728#[derive(Debug)]
729pub struct KinesisAsyncSleepImpl;
730
731impl AsyncSleep for KinesisAsyncSleepImpl {
732    fn sleep(&self, duration: Duration) -> Sleep {
733        Sleep::new(async move { tokio::time::sleep(duration).await })
734    }
735}
736
/// Shared default, in milliseconds, for all four Kinesis SDK timeouts.
const KINESIS_DEFAULT_TIMEOUT_MS: u64 = 10_000;

/// Serde default for `kinesis.sdk.connect_timeout_ms`.
const fn kinesis_default_connect_timeout_ms() -> u64 {
    KINESIS_DEFAULT_TIMEOUT_MS
}

/// Serde default for `kinesis.sdk.read_timeout_ms`.
const fn kinesis_default_read_timeout_ms() -> u64 {
    KINESIS_DEFAULT_TIMEOUT_MS
}

/// Serde default for `kinesis.sdk.operation_timeout_ms`.
const fn kinesis_default_operation_timeout_ms() -> u64 {
    KINESIS_DEFAULT_TIMEOUT_MS
}

/// Serde default for `kinesis.sdk.operation_attempt_timeout_ms`.
const fn kinesis_default_operation_attempt_timeout_ms() -> u64 {
    KINESIS_DEFAULT_TIMEOUT_MS
}

/// Serde default for `kinesis.sdk.init_backoff_ms`: one second.
const fn kinesis_default_init_backoff_ms() -> u64 {
    1_000
}

/// Serde default for `kinesis.sdk.max_backoff_ms`: twenty seconds.
const fn kinesis_default_max_backoff_ms() -> u64 {
    20_000
}

/// Serde default for `kinesis.sdk.max_retry_limit`.
const fn kinesis_default_max_retry_limit() -> u32 {
    3
}
764
765impl EnforceSecret for KinesisCommon {
766    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
767        "kinesis.credentials.access",
768        "kinesis.credentials.secret",
769        "kinesis.credentials.session_token",
770    };
771}
772
773impl KinesisCommon {
774    pub(crate) async fn build_client(&self) -> ConnectorResult<KinesisClient> {
775        let config = AwsAuthProps {
776            region: Some(self.stream_region.clone()),
777            endpoint: self.endpoint.clone(),
778            access_key: self.credentials_access_key.clone(),
779            secret_key: self.credentials_secret_access_key.clone(),
780            session_token: self.session_token.clone(),
781            arn: self.assume_role_arn.clone(),
782            external_id: self.assume_role_external_id.clone(),
783            profile: Default::default(),
784            msk_signer_timeout_sec: Default::default(),
785        };
786        let aws_config = config.build_config().await?;
787        let mut builder = aws_sdk_kinesis::config::Builder::from(&aws_config);
788        {
789            // for timeout and retry config
790            let sleep_impl = SharedAsyncSleep::new(KinesisAsyncSleepImpl);
791            builder.set_sleep_impl(Some(sleep_impl));
792            let timeout_config = aws_smithy_types::timeout::TimeoutConfig::builder()
793                .connect_timeout(Duration::from_millis(self.sdk_connect_timeout_ms))
794                .read_timeout(Duration::from_millis(self.sdk_read_timeout_ms))
795                .operation_timeout(Duration::from_millis(self.sdk_operation_timeout_ms))
796                .operation_attempt_timeout(Duration::from_millis(
797                    self.sdk_operation_attempt_timeout_ms,
798                ))
799                .build();
800            builder.set_timeout_config(Some(timeout_config));
801
802            let retry_config = aws_smithy_types::retry::RetryConfig::standard()
803                .with_initial_backoff(Duration::from_millis(self.sdk_init_backoff_ms))
804                .with_max_backoff(Duration::from_millis(self.sdk_max_backoff_ms))
805                .with_max_attempts(self.sdk_max_retry_limit);
806            builder.set_retry_config(Some(retry_config));
807        }
808        if let Some(endpoint) = &config.endpoint {
809            builder = builder.endpoint_url(endpoint);
810        }
811        Ok(KinesisClient::from_conf(builder.build()))
812    }
813}
814
/// Common NATS connection properties.
///
/// NOTE(review): the `max_*` and `allow_create_stream` options are consumed outside this part
/// of the file; confirm their exact semantics at the call sites.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct NatsCommon {
    /// NATS server address(es) to connect to.
    #[serde(rename = "server_url")]
    pub server_url: String,
    /// NATS subject used by the connector.
    #[serde(rename = "subject")]
    pub subject: String,
    /// Authentication mode selector, e.g. `user_and_password`.
    #[serde(rename = "connect_mode")]
    pub connect_mode: String,
    /// Username for `user_and_password` mode (property key `username`).
    #[serde(rename = "username")]
    pub user: Option<String>,
    /// Password for `user_and_password` mode.
    #[serde(rename = "password")]
    pub password: Option<String>,
    /// JWT credential, presumably for a JWT-based connect mode.
    #[serde(rename = "jwt")]
    pub jwt: Option<String>,
    /// NKey credential, presumably for an NKey-based connect mode.
    #[serde(rename = "nkey")]
    pub nkey: Option<String>,
    /// Presumably the JetStream stream byte limit; parsed from a string.
    #[serde(rename = "max_bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_bytes: Option<i64>,
    /// Presumably the JetStream stream message-count limit; parsed from a string.
    #[serde(rename = "max_messages")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_messages: Option<i64>,
    /// Presumably the per-subject message limit; parsed from a string.
    #[serde(rename = "max_messages_per_subject")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_messages_per_subject: Option<i64>,
    /// Presumably the consumer-count limit; parsed from a string.
    #[serde(rename = "max_consumers")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_consumers: Option<i32>,
    /// Presumably the per-message size limit; parsed from a string.
    #[serde(rename = "max_message_size")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_message_size: Option<i32>,
    /// Whether the connector may create the stream; defaults to `false` when omitted.
    #[serde(rename = "allow_create_stream", default)]
    #[serde_as(as = "DisplayFromStr")]
    pub allow_create_stream: bool,
}
851
impl EnforceSecret for NatsCommon {
    // Option keys (the `serde` renames of `NatsCommon`'s fields) whose values are sensitive.
    // NOTE(review): presumably the `EnforceSecret` machinery requires these to be supplied
    // via secrets rather than plaintext — confirm against the trait definition.
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "password",
        "jwt",
        "nkey",
    };
}
859
860impl NatsCommon {
861    pub(crate) async fn build_client(&self) -> ConnectorResult<async_nats::Client> {
862        let mut connect_options = async_nats::ConnectOptions::new();
863        match self.connect_mode.as_str() {
864            "user_and_password" => {
865                if let (Some(v_user), Some(v_password)) =
866                    (self.user.as_ref(), self.password.as_ref())
867                {
868                    connect_options =
869                        connect_options.user_and_password(v_user.into(), v_password.into())
870                } else {
871                    bail!("nats connect mode is user_and_password, but user or password is empty");
872                }
873            }
874
875            "credential" => {
876                if let (Some(v_nkey), Some(v_jwt)) = (self.nkey.as_ref(), self.jwt.as_ref()) {
877                    connect_options = connect_options
878                        .credentials(&self.create_credential(v_nkey, v_jwt)?)
879                        .expect("failed to parse static creds")
880                } else {
881                    bail!("nats connect mode is credential, but nkey or jwt is empty");
882                }
883            }
884            "plain" => {}
885            _ => {
886                bail!("nats connect mode only accepts user_and_password/credential/plain");
887            }
888        };
889
890        let servers = self.server_url.split(',').collect::<Vec<&str>>();
891        let client = connect_options
892            .connect(
893                servers
894                    .iter()
895                    .map(|url| url.parse())
896                    .collect::<Result<Vec<async_nats::ServerAddr>, _>>()?,
897            )
898            .await
899            .context("build nats client error")
900            .map_err(SinkError::Nats)?;
901        Ok(client)
902    }
903
904    pub(crate) async fn build_context(&self) -> ConnectorResult<jetstream::Context> {
905        let client = self.build_client().await?;
906        let jetstream = async_nats::jetstream::new(client);
907        Ok(jetstream)
908    }
909
910    pub(crate) async fn build_consumer(
911        &self,
912        stream: String,
913        durable_consumer_name: String,
914        split_id: String,
915        start_sequence: NatsOffset,
916        mut config: jetstream::consumer::pull::Config,
917    ) -> ConnectorResult<
918        async_nats::jetstream::consumer::Consumer<async_nats::jetstream::consumer::pull::Config>,
919    > {
920        let context = self.build_context().await?;
921        let stream = self.build_or_get_stream(context.clone(), stream).await?;
922        let subject_name = self
923            .subject
924            .replace(',', "-")
925            .replace(['.', '>', '*', ' ', '\t'], "_");
926        let name = format!("risingwave-consumer-{}-{}", subject_name, split_id);
927
928        let deliver_policy = match start_sequence {
929            NatsOffset::Earliest => DeliverPolicy::All,
930            NatsOffset::Latest => DeliverPolicy::New,
931            NatsOffset::SequenceNumber(v) => {
932                // for compatibility, we do not write to any state table now
933                let parsed = v
934                    .parse::<u64>()
935                    .context("failed to parse nats offset as sequence number")?;
936                DeliverPolicy::ByStartSequence {
937                    start_sequence: 1 + parsed,
938                }
939            }
940            NatsOffset::Timestamp(v) => DeliverPolicy::ByStartTime {
941                start_time: OffsetDateTime::from_unix_timestamp_nanos(v as i128 * 1_000_000)
942                    .context("invalid timestamp for nats offset")?,
943            },
944            NatsOffset::None => DeliverPolicy::All,
945        };
946
947        let consumer = match stream.get_consumer(&name).await {
948            Ok(consumer) => consumer,
949            _ => {
950                stream
951                    .get_or_create_consumer(&name, {
952                        config.deliver_policy = deliver_policy;
953                        config.durable_name = Some(durable_consumer_name);
954                        config.filter_subjects =
955                            self.subject.split(',').map(|s| s.to_owned()).collect();
956                        config
957                    })
958                    .await?
959            }
960        };
961        Ok(consumer)
962    }
963
964    pub(crate) async fn build_or_get_stream(
965        &self,
966        jetstream: jetstream::Context,
967        stream_str: String,
968    ) -> ConnectorResult<jetstream::stream::Stream> {
969        let subjects: Vec<String> = self.subject.split(',').map(|s| s.to_owned()).collect();
970
971        // In `SourceEnumerator`, we may create a stream
972        // In `SourceReader`, the desired stream MUST exist
973        if let Ok(mut stream_instance) = jetstream.get_stream(&stream_str).await {
974            tracing::info!(
975                "load existing nats stream ({:?}) with config {:?}",
976                stream_str,
977                stream_instance.info().await?
978            );
979            return Ok(stream_instance);
980        }
981
982        if !self.allow_create_stream {
983            return Err(anyhow!(
984                "stream {} not found, set `allow_create_stream` to true to create a stream",
985                stream_str
986            )
987            .into());
988        }
989
990        let mut config = jetstream::stream::Config {
991            name: stream_str.clone(),
992            max_bytes: 1000000,
993            subjects,
994            ..Default::default()
995        };
996        if let Some(v) = self.max_bytes {
997            config.max_bytes = v;
998        }
999        if let Some(v) = self.max_messages {
1000            config.max_messages = v;
1001        }
1002        if let Some(v) = self.max_messages_per_subject {
1003            config.max_messages_per_subject = v;
1004        }
1005        if let Some(v) = self.max_consumers {
1006            config.max_consumers = v;
1007        }
1008        if let Some(v) = self.max_message_size {
1009            config.max_message_size = v;
1010        }
1011        tracing::info!(
1012            "create nats stream ({:?}) with config {:?}",
1013            &stream_str,
1014            config
1015        );
1016        let stream = jetstream.get_or_create_stream(config).await?;
1017        Ok(stream)
1018    }
1019
1020    pub(crate) fn create_credential(&self, seed: &str, jwt: &str) -> ConnectorResult<String> {
1021        let creds = format!(
1022            "-----BEGIN NATS USER JWT-----\n{}\n------END NATS USER JWT------\n\n\
1023                         ************************* IMPORTANT *************************\n\
1024                         NKEY Seed printed below can be used to sign and prove identity.\n\
1025                         NKEYs are sensitive and should be treated as secrets.\n\n\
1026                         -----BEGIN USER NKEY SEED-----\n{}\n------END USER NKEY SEED------\n\n\
1027                         *************************************************************",
1028            jwt, seed
1029        );
1030        Ok(creds)
1031    }
1032}
1033
1034pub(crate) fn load_certs(
1035    certificates: &str,
1036) -> ConnectorResult<Vec<rustls_pki_types::CertificateDer<'static>>> {
1037    let cert_bytes = if let Some(path) = certificates.strip_prefix("fs://") {
1038        std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())?
1039    } else {
1040        certificates.as_bytes().to_owned()
1041    };
1042
1043    rustls_pemfile::certs(&mut cert_bytes.as_slice())
1044        .map(|cert| Ok(cert?))
1045        .collect()
1046}
1047
1048pub(crate) fn load_private_key(
1049    certificate: &str,
1050) -> ConnectorResult<rustls_pki_types::PrivateKeyDer<'static>> {
1051    let cert_bytes = if let Some(path) = certificate.strip_prefix("fs://") {
1052        std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())?
1053    } else {
1054        certificate.as_bytes().to_owned()
1055    };
1056
1057    let cert = rustls_pemfile::pkcs8_private_keys(&mut cert_bytes.as_slice())
1058        .next()
1059        .ok_or_else(|| anyhow!("No private key found"))?;
1060    Ok(cert?.into())
1061}
1062
/// Connection options shared by `MongoDB` connectors.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct MongodbCommon {
    /// The connection URL of `MongoDB`, passed as-is to `mongodb::Client::with_uri_str`.
    #[serde(rename = "mongodb.url")]
    pub connect_uri: String,
    /// The collection name where data should be written to or read from. For sinks, the format is
    /// `db_name.collection_name`. Data can also be written to dynamic collections, see `collection.name.field`
    /// for more information.
    #[serde(rename = "collection.name")]
    pub collection_name: String,
}
1075
impl EnforceSecret for MongodbCommon {
    // The connection URL may embed credentials, so it is listed as a sensitive option key.
    // NOTE(review): presumably the `EnforceSecret` machinery requires it to be supplied via
    // a secret rather than plaintext — confirm against the trait definition.
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "mongodb.url"
    };
}
1081
1082impl MongodbCommon {
1083    pub(crate) async fn build_client(&self) -> ConnectorResult<mongodb::Client> {
1084        let client = mongodb::Client::with_uri_str(&self.connect_uri).await?;
1085
1086        Ok(client)
1087    }
1088}