risingwave_connector/connector_common/
common.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::hash::Hash;
17use std::io::Write;
18use std::time::Duration;
19
20use anyhow::{Context, anyhow};
21use async_nats::jetstream::consumer::DeliverPolicy;
22use async_nats::jetstream::{self};
23use aws_sdk_kinesis::Client as KinesisClient;
24use aws_sdk_kinesis::config::{AsyncSleep, SharedAsyncSleep, Sleep};
25use phf::{Set, phf_set};
26use pulsar::authentication::oauth2::{OAuth2Authentication, OAuth2Params};
27use pulsar::{Authentication, Pulsar, TokioExecutor};
28use rdkafka::ClientConfig;
29use risingwave_common::bail;
30use serde_derive::Deserialize;
31use serde_with::json::JsonString;
32use serde_with::{DisplayFromStr, serde_as};
33use tempfile::NamedTempFile;
34use time::OffsetDateTime;
35use url::Url;
36use with_options::WithOptions;
37
38use crate::aws_utils::load_file_descriptor_from_s3;
39use crate::deserialize_duration_from_string;
40use crate::enforce_secret::EnforceSecret;
41use crate::error::ConnectorResult;
42use crate::sink::SinkError;
43use crate::source::nats::source::NatsOffset;
44
/// Property key carrying the JSON-encoded broker rewrite map (see `KafkaPrivateLinkCommon`).
pub const PRIVATE_LINK_BROKER_REWRITE_MAP_KEY: &str = "broker.rewrite.endpoints";
/// Property key carrying the private link targets.
pub const PRIVATE_LINK_TARGETS_KEY: &str = "privatelink.targets";

/// Value of `properties.sasl.mechanism` that selects AWS MSK IAM authentication.
const AWS_MSK_IAM_AUTH: &str = "AWS_MSK_IAM";

/// The environment variable to disable using default credential from environment.
/// It's recommended to set this variable to `true` in cloud hosting environment.
///
/// When it is true and static AWS keys are incomplete, building credentials fails instead of
/// falling back to the SDK default provider chain.
pub const DISABLE_DEFAULT_CREDENTIAL: &str = "DISABLE_DEFAULT_CREDENTIAL";
53
54#[derive(Debug, Clone, Deserialize)]
55pub struct AwsPrivateLinkItem {
56    pub az_id: Option<String>,
57    pub port: u16,
58}
59
60use aws_config::default_provider::region::DefaultRegionChain;
61use aws_config::sts::AssumeRoleProvider;
62use aws_credential_types::provider::SharedCredentialsProvider;
63use aws_types::SdkConfig;
64use aws_types::region::Region;
65use risingwave_common::util::env_var::env_var_is_true;
66
67/// A flatten config map for aws auth.
68#[derive(Deserialize, Debug, Clone, WithOptions, PartialEq)]
69pub struct AwsAuthProps {
70    #[serde(rename = "aws.region", alias = "region", alias = "s3.region")]
71    pub region: Option<String>,
72
73    #[serde(
74        rename = "aws.endpoint_url",
75        alias = "endpoint_url",
76        alias = "endpoint",
77        alias = "s3.endpoint"
78    )]
79    pub endpoint: Option<String>,
80    #[serde(
81        rename = "aws.credentials.access_key_id",
82        alias = "access_key",
83        alias = "s3.access.key"
84    )]
85    pub access_key: Option<String>,
86    #[serde(
87        rename = "aws.credentials.secret_access_key",
88        alias = "secret_key",
89        alias = "s3.secret.key"
90    )]
91    pub secret_key: Option<String>,
92    #[serde(rename = "aws.credentials.session_token", alias = "session_token")]
93    pub session_token: Option<String>,
94    /// IAM role
95    #[serde(rename = "aws.credentials.role.arn", alias = "arn")]
96    pub arn: Option<String>,
97    /// external ID in IAM role trust policy
98    #[serde(rename = "aws.credentials.role.external_id", alias = "external_id")]
99    pub external_id: Option<String>,
100    #[serde(rename = "aws.profile", alias = "profile")]
101    pub profile: Option<String>,
102    #[serde(rename = "aws.msk.signer_timeout_sec")]
103    pub msk_signer_timeout_sec: Option<u64>,
104}
105
106impl EnforceSecret for AwsAuthProps {
107    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
108        "access_key",
109        "aws.credentials.access_key_id",
110        "s3.access.key",
111        "secret_key",
112        "aws.credentials.secret_access_key",
113        "s3.secret.key",
114        "session_token",
115        "aws.credentials.session_token",
116    };
117}
118
119impl AwsAuthProps {
120    async fn build_region(&self) -> ConnectorResult<Region> {
121        if let Some(region_name) = &self.region {
122            Ok(Region::new(region_name.clone()))
123        } else {
124            let mut region_chain = DefaultRegionChain::builder();
125            if let Some(profile_name) = &self.profile {
126                region_chain = region_chain.profile_name(profile_name);
127            }
128
129            Ok(region_chain
130                .build()
131                .region()
132                .await
133                .context("region should be provided")?)
134        }
135    }
136
137    async fn build_credential_provider(&self) -> ConnectorResult<SharedCredentialsProvider> {
138        if self.access_key.is_some() && self.secret_key.is_some() {
139            Ok(SharedCredentialsProvider::new(
140                aws_credential_types::Credentials::from_keys(
141                    self.access_key.as_ref().unwrap(),
142                    self.secret_key.as_ref().unwrap(),
143                    self.session_token.clone(),
144                ),
145            ))
146        } else if !env_var_is_true(DISABLE_DEFAULT_CREDENTIAL) {
147            Ok(SharedCredentialsProvider::new(
148                aws_config::default_provider::credentials::default_provider().await,
149            ))
150        } else {
151            bail!("Both \"access_key\" and \"secret_key\" are required.")
152        }
153    }
154
155    async fn with_role_provider(
156        &self,
157        credential: SharedCredentialsProvider,
158    ) -> ConnectorResult<SharedCredentialsProvider> {
159        if let Some(role_name) = &self.arn {
160            let region = self.build_region().await?;
161            let mut role = AssumeRoleProvider::builder(role_name)
162                .session_name("RisingWave")
163                .region(region);
164            if let Some(id) = &self.external_id {
165                role = role.external_id(id);
166            }
167            let provider = role.build_from_provider(credential).await;
168            Ok(SharedCredentialsProvider::new(provider))
169        } else {
170            Ok(credential)
171        }
172    }
173
174    pub async fn build_config(&self) -> ConnectorResult<SdkConfig> {
175        let region = self.build_region().await?;
176        let credentials_provider = self
177            .with_role_provider(self.build_credential_provider().await?)
178            .await?;
179        let mut config_loader = aws_config::from_env()
180            .region(region)
181            .credentials_provider(credentials_provider);
182
183        if let Some(endpoint) = self.endpoint.as_ref() {
184            config_loader = config_loader.endpoint_url(endpoint);
185        }
186
187        Ok(config_loader.load().await)
188    }
189}
190
191#[serde_as]
192#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
193pub struct KafkaConnectionProps {
194    #[serde(rename = "properties.bootstrap.server", alias = "kafka.brokers")]
195    pub brokers: String,
196
197    /// Security protocol used for RisingWave to communicate with Kafka brokers. Could be
198    /// PLAINTEXT, SSL, SASL_PLAINTEXT or SASL_SSL.
199    #[serde(rename = "properties.security.protocol")]
200    security_protocol: Option<String>,
201
202    #[serde(rename = "properties.ssl.endpoint.identification.algorithm")]
203    ssl_endpoint_identification_algorithm: Option<String>,
204
205    // For the properties below, please refer to [librdkafka](https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md) for more information.
206    /// Path to CA certificate file for verifying the broker's key.
207    #[serde(rename = "properties.ssl.ca.location")]
208    ssl_ca_location: Option<String>,
209
210    /// CA certificate string (PEM format) for verifying the broker's key.
211    #[serde(rename = "properties.ssl.ca.pem")]
212    ssl_ca_pem: Option<String>,
213
214    /// Path to client's certificate file (PEM).
215    #[serde(rename = "properties.ssl.certificate.location")]
216    ssl_certificate_location: Option<String>,
217
218    /// Client's public key string (PEM format) used for authentication.
219    #[serde(rename = "properties.ssl.certificate.pem")]
220    ssl_certificate_pem: Option<String>,
221
222    /// Path to client's private key file (PEM).
223    #[serde(rename = "properties.ssl.key.location")]
224    ssl_key_location: Option<String>,
225
226    /// Client's private key string (PEM format) used for authentication.
227    #[serde(rename = "properties.ssl.key.pem")]
228    ssl_key_pem: Option<String>,
229
230    /// Passphrase of client's private key.
231    #[serde(rename = "properties.ssl.key.password")]
232    ssl_key_password: Option<String>,
233
234    /// SASL mechanism if SASL is enabled. Currently support PLAIN, SCRAM, GSSAPI, and AWS_MSK_IAM.
235    #[serde(rename = "properties.sasl.mechanism")]
236    sasl_mechanism: Option<String>,
237
238    /// SASL username for SASL/PLAIN and SASL/SCRAM.
239    #[serde(rename = "properties.sasl.username")]
240    sasl_username: Option<String>,
241
242    /// SASL password for SASL/PLAIN and SASL/SCRAM.
243    #[serde(rename = "properties.sasl.password")]
244    sasl_password: Option<String>,
245
246    /// Kafka server's Kerberos principal name under SASL/GSSAPI, not including /hostname@REALM.
247    #[serde(rename = "properties.sasl.kerberos.service.name")]
248    sasl_kerberos_service_name: Option<String>,
249
250    /// Path to client's Kerberos keytab file under SASL/GSSAPI.
251    #[serde(rename = "properties.sasl.kerberos.keytab")]
252    sasl_kerberos_keytab: Option<String>,
253
254    /// Client's Kerberos principal name under SASL/GSSAPI.
255    #[serde(rename = "properties.sasl.kerberos.principal")]
256    sasl_kerberos_principal: Option<String>,
257
258    /// Shell command to refresh or acquire the client's Kerberos ticket under SASL/GSSAPI.
259    #[serde(rename = "properties.sasl.kerberos.kinit.cmd")]
260    sasl_kerberos_kinit_cmd: Option<String>,
261
262    /// Minimum time in milliseconds between key refresh attempts under SASL/GSSAPI.
263    #[serde(rename = "properties.sasl.kerberos.min.time.before.relogin")]
264    sasl_kerberos_min_time_before_relogin: Option<String>,
265
266    /// Configurations for SASL/OAUTHBEARER.
267    #[serde(rename = "properties.sasl.oauthbearer.config")]
268    sasl_oathbearer_config: Option<String>,
269}
270
271impl EnforceSecret for KafkaConnectionProps {
272    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
273        "properties.ssl.key.pem",
274        "properties.ssl.key.password",
275        "properties.sasl.password",
276    };
277}
278
279#[serde_as]
280#[derive(Debug, Clone, Deserialize, WithOptions)]
281pub struct KafkaCommon {
282    // connection related props are moved to `KafkaConnection`
283    #[serde(rename = "topic", alias = "kafka.topic")]
284    pub topic: String,
285
286    #[serde(
287        rename = "properties.sync.call.timeout",
288        deserialize_with = "deserialize_duration_from_string",
289        default = "default_kafka_sync_call_timeout"
290    )]
291    pub sync_call_timeout: Duration,
292}
293
294#[serde_as]
295#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
296pub struct KafkaPrivateLinkCommon {
297    /// This is generated from `private_link_targets` and `private_link_endpoint` in frontend, instead of given by users.
298    #[serde(rename = "broker.rewrite.endpoints")]
299    #[serde_as(as = "Option<JsonString>")]
300    pub broker_rewrite_map: Option<BTreeMap<String, String>>,
301}
302
/// Serde default for `properties.sync.call.timeout`: five seconds.
const fn default_kafka_sync_call_timeout() -> Duration {
    Duration::from_millis(5_000)
}

/// Serde default for `properties.socket.keepalive.enable`: keepalive is on.
const fn default_socket_keepalive_enable() -> bool {
    !false
}
310
311#[serde_as]
312#[derive(Debug, Clone, Deserialize, WithOptions)]
313pub struct RdKafkaPropertiesCommon {
314    /// Maximum Kafka protocol request message size. Due to differing framing overhead between
315    /// protocol versions the producer is unable to reliably enforce a strict max message limit at
316    /// produce time and may exceed the maximum size by one message in protocol ProduceRequests,
317    /// the broker will enforce the topic's max.message.bytes limit
318    #[serde(rename = "properties.message.max.bytes")]
319    #[serde_as(as = "Option<DisplayFromStr>")]
320    pub message_max_bytes: Option<usize>,
321
322    /// Maximum Kafka protocol response message size. This serves as a safety precaution to avoid
323    /// memory exhaustion in case of protocol hickups. This value must be at least fetch.max.bytes
324    /// + 512 to allow for protocol overhead; the value is adjusted automatically unless the
325    /// configuration property is explicitly set.
326    #[serde(rename = "properties.receive.message.max.bytes")]
327    #[serde_as(as = "Option<DisplayFromStr>")]
328    pub receive_message_max_bytes: Option<usize>,
329
330    #[serde(rename = "properties.statistics.interval.ms")]
331    #[serde_as(as = "Option<DisplayFromStr>")]
332    pub statistics_interval_ms: Option<usize>,
333
334    /// Client identifier
335    #[serde(rename = "properties.client.id")]
336    #[serde_as(as = "Option<DisplayFromStr>")]
337    pub client_id: Option<String>,
338
339    #[serde(rename = "properties.enable.ssl.certificate.verification")]
340    #[serde_as(as = "Option<DisplayFromStr>")]
341    pub enable_ssl_certificate_verification: Option<bool>,
342
343    #[serde(
344        rename = "properties.socket.keepalive.enable",
345        default = "default_socket_keepalive_enable"
346    )]
347    #[serde_as(as = "DisplayFromStr")]
348    pub socket_keepalive_enable: bool,
349}
350
351impl RdKafkaPropertiesCommon {
352    pub(crate) fn set_client(&self, c: &mut rdkafka::ClientConfig) {
353        if let Some(v) = self.statistics_interval_ms {
354            c.set("statistics.interval.ms", v.to_string());
355        }
356        if let Some(v) = self.message_max_bytes {
357            c.set("message.max.bytes", v.to_string());
358        }
359        if let Some(v) = self.receive_message_max_bytes {
360            c.set("receive.message.max.bytes", v.to_string());
361        }
362        if let Some(v) = self.client_id.as_ref() {
363            c.set("client.id", v);
364        }
365        if let Some(v) = self.enable_ssl_certificate_verification {
366            c.set("enable.ssl.certificate.verification", v.to_string());
367        }
368        c.set(
369            "socket.keepalive.enable",
370            self.socket_keepalive_enable.to_string(),
371        );
372    }
373}
374
375impl KafkaConnectionProps {
376    #[cfg(test)]
377    pub fn test_default() -> Self {
378        Self {
379            brokers: "localhost:9092".to_owned(),
380            security_protocol: None,
381            ssl_ca_location: None,
382            ssl_certificate_location: None,
383            ssl_key_location: None,
384            ssl_ca_pem: None,
385            ssl_certificate_pem: None,
386            ssl_key_pem: None,
387            ssl_key_password: None,
388            ssl_endpoint_identification_algorithm: None,
389            sasl_mechanism: None,
390            sasl_username: None,
391            sasl_password: None,
392            sasl_kerberos_service_name: None,
393            sasl_kerberos_keytab: None,
394            sasl_kerberos_principal: None,
395            sasl_kerberos_kinit_cmd: None,
396            sasl_kerberos_min_time_before_relogin: None,
397            sasl_oathbearer_config: None,
398        }
399    }
400
401    pub(crate) fn set_security_properties(&self, config: &mut ClientConfig) {
402        // AWS_MSK_IAM
403        if self.is_aws_msk_iam() {
404            config.set("security.protocol", "SASL_SSL");
405            config.set("sasl.mechanism", "OAUTHBEARER");
406            return;
407        }
408
409        // Security protocol
410        if let Some(security_protocol) = self.security_protocol.as_ref() {
411            config.set("security.protocol", security_protocol);
412        }
413
414        // SSL
415        if let Some(ssl_ca_location) = self.ssl_ca_location.as_ref() {
416            config.set("ssl.ca.location", ssl_ca_location);
417        }
418        if let Some(ssl_ca_pem) = self.ssl_ca_pem.as_ref() {
419            config.set("ssl.ca.pem", ssl_ca_pem);
420        }
421        if let Some(ssl_certificate_location) = self.ssl_certificate_location.as_ref() {
422            config.set("ssl.certificate.location", ssl_certificate_location);
423        }
424        if let Some(ssl_certificate_pem) = self.ssl_certificate_pem.as_ref() {
425            config.set("ssl.certificate.pem", ssl_certificate_pem);
426        }
427        if let Some(ssl_key_location) = self.ssl_key_location.as_ref() {
428            config.set("ssl.key.location", ssl_key_location);
429        }
430        if let Some(ssl_key_pem) = self.ssl_key_pem.as_ref() {
431            config.set("ssl.key.pem", ssl_key_pem);
432        }
433        if let Some(ssl_key_password) = self.ssl_key_password.as_ref() {
434            config.set("ssl.key.password", ssl_key_password);
435        }
436        if let Some(ssl_endpoint_identification_algorithm) =
437            self.ssl_endpoint_identification_algorithm.as_ref()
438        {
439            // accept only `none` and `http` here, let the sdk do the check
440            config.set(
441                "ssl.endpoint.identification.algorithm",
442                ssl_endpoint_identification_algorithm,
443            );
444        }
445
446        // SASL mechanism
447        if let Some(sasl_mechanism) = self.sasl_mechanism.as_ref() {
448            config.set("sasl.mechanism", sasl_mechanism);
449        }
450
451        // SASL/PLAIN & SASL/SCRAM
452        if let Some(sasl_username) = self.sasl_username.as_ref() {
453            config.set("sasl.username", sasl_username);
454        }
455        if let Some(sasl_password) = self.sasl_password.as_ref() {
456            config.set("sasl.password", sasl_password);
457        }
458
459        // SASL/GSSAPI
460        if let Some(sasl_kerberos_service_name) = self.sasl_kerberos_service_name.as_ref() {
461            config.set("sasl.kerberos.service.name", sasl_kerberos_service_name);
462        }
463        if let Some(sasl_kerberos_keytab) = self.sasl_kerberos_keytab.as_ref() {
464            config.set("sasl.kerberos.keytab", sasl_kerberos_keytab);
465        }
466        if let Some(sasl_kerberos_principal) = self.sasl_kerberos_principal.as_ref() {
467            config.set("sasl.kerberos.principal", sasl_kerberos_principal);
468        }
469        if let Some(sasl_kerberos_kinit_cmd) = self.sasl_kerberos_kinit_cmd.as_ref() {
470            config.set("sasl.kerberos.kinit.cmd", sasl_kerberos_kinit_cmd);
471        }
472        if let Some(sasl_kerberos_min_time_before_relogin) =
473            self.sasl_kerberos_min_time_before_relogin.as_ref()
474        {
475            config.set(
476                "sasl.kerberos.min.time.before.relogin",
477                sasl_kerberos_min_time_before_relogin,
478            );
479        }
480
481        // SASL/OAUTHBEARER
482        if let Some(sasl_oathbearer_config) = self.sasl_oathbearer_config.as_ref() {
483            config.set("sasl.oauthbearer.config", sasl_oathbearer_config);
484        }
485        // Currently, we only support unsecured OAUTH.
486        config.set("enable.sasl.oauthbearer.unsecure.jwt", "true");
487    }
488
489    pub(crate) fn is_aws_msk_iam(&self) -> bool {
490        if let Some(sasl_mechanism) = self.sasl_mechanism.as_ref()
491            && sasl_mechanism == AWS_MSK_IAM_AUTH
492        {
493            true
494        } else {
495            false
496        }
497    }
498}
499
500#[derive(Clone, Debug, Deserialize, WithOptions)]
501pub struct PulsarCommon {
502    #[serde(rename = "topic", alias = "pulsar.topic")]
503    pub topic: String,
504
505    #[serde(rename = "service.url", alias = "pulsar.service.url")]
506    pub service_url: String,
507
508    #[serde(rename = "auth.token")]
509    pub auth_token: Option<String>,
510}
511
512impl EnforceSecret for PulsarCommon {
513    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
514        "pulsar.auth.token",
515    };
516}
517
518#[derive(Clone, Debug, Deserialize, WithOptions)]
519pub struct PulsarOauthCommon {
520    #[serde(rename = "oauth.issuer.url")]
521    pub issuer_url: String,
522
523    #[serde(rename = "oauth.credentials.url")]
524    pub credentials_url: String,
525
526    #[serde(rename = "oauth.audience")]
527    pub audience: String,
528
529    #[serde(rename = "oauth.scope")]
530    pub scope: Option<String>,
531}
532
533fn create_credential_temp_file(credentials: &[u8]) -> std::io::Result<NamedTempFile> {
534    let mut f = NamedTempFile::new()?;
535    f.write_all(credentials)?;
536    f.as_file().sync_all()?;
537    Ok(f)
538}
539
540impl PulsarCommon {
541    pub(crate) async fn build_client(
542        &self,
543        oauth: &Option<PulsarOauthCommon>,
544        aws_auth_props: &AwsAuthProps,
545    ) -> ConnectorResult<Pulsar<TokioExecutor>> {
546        let mut pulsar_builder = Pulsar::builder(&self.service_url, TokioExecutor);
547        let mut temp_file = None;
548        if let Some(oauth) = oauth.as_ref() {
549            let url = Url::parse(&oauth.credentials_url)?;
550            match url.scheme() {
551                "s3" => {
552                    let credentials = load_file_descriptor_from_s3(&url, aws_auth_props).await?;
553                    temp_file = Some(
554                        create_credential_temp_file(&credentials)
555                            .context("failed to create temp file for pulsar credentials")?,
556                    );
557                }
558                "file" => {}
559                _ => {
560                    bail!("invalid credentials_url, only file url and s3 url are supported",);
561                }
562            }
563
564            let auth_params = OAuth2Params {
565                issuer_url: oauth.issuer_url.clone(),
566                credentials_url: if temp_file.is_none() {
567                    oauth.credentials_url.clone()
568                } else {
569                    let mut raw_path = temp_file
570                        .as_ref()
571                        .unwrap()
572                        .path()
573                        .to_str()
574                        .unwrap()
575                        .to_owned();
576                    raw_path.insert_str(0, "file://");
577                    raw_path
578                },
579                audience: Some(oauth.audience.clone()),
580                scope: oauth.scope.clone(),
581            };
582
583            pulsar_builder = pulsar_builder
584                .with_auth_provider(OAuth2Authentication::client_credentials(auth_params));
585        } else if let Some(auth_token) = &self.auth_token {
586            pulsar_builder = pulsar_builder.with_auth(Authentication {
587                name: "token".to_owned(),
588                data: Vec::from(auth_token.as_str()),
589            });
590        }
591
592        let res = pulsar_builder.build().await.map_err(|e| anyhow!(e))?;
593        drop(temp_file);
594        Ok(res)
595    }
596}
597
598#[derive(Deserialize, Debug, Clone, WithOptions)]
599pub struct KinesisCommon {
600    #[serde(rename = "stream", alias = "kinesis.stream.name")]
601    pub stream_name: String,
602    #[serde(rename = "aws.region", alias = "kinesis.stream.region")]
603    pub stream_region: String,
604    #[serde(rename = "endpoint", alias = "kinesis.endpoint")]
605    pub endpoint: Option<String>,
606    #[serde(
607        rename = "aws.credentials.access_key_id",
608        alias = "kinesis.credentials.access"
609    )]
610    pub credentials_access_key: Option<String>,
611    #[serde(
612        rename = "aws.credentials.secret_access_key",
613        alias = "kinesis.credentials.secret"
614    )]
615    pub credentials_secret_access_key: Option<String>,
616    #[serde(
617        rename = "aws.credentials.session_token",
618        alias = "kinesis.credentials.session_token"
619    )]
620    pub session_token: Option<String>,
621    #[serde(rename = "aws.credentials.role.arn", alias = "kinesis.assumerole.arn")]
622    pub assume_role_arn: Option<String>,
623    #[serde(
624        rename = "aws.credentials.role.external_id",
625        alias = "kinesis.assumerole.external_id"
626    )]
627    pub assume_role_external_id: Option<String>,
628
629    #[serde(flatten)]
630    pub sdk_options: KinesisSdkOptions,
631}
632
633#[derive(Debug)]
634pub struct KinesisAsyncSleepImpl;
635
636impl AsyncSleep for KinesisAsyncSleepImpl {
637    fn sleep(&self, duration: Duration) -> Sleep {
638        Sleep::new(async move { tokio::time::sleep(duration).await })
639    }
640}
641
/// Default SDK connect timeout: 10 s.
const fn kinesis_default_connect_timeout_ms() -> u64 {
    10_000
}

/// Default SDK read timeout: 10 s.
const fn kinesis_default_read_timeout_ms() -> u64 {
    10_000
}

/// Default timeout for a whole SDK operation, retries included: 10 s.
const fn kinesis_default_operation_timeout_ms() -> u64 {
    10_000
}

/// Default timeout for a single SDK operation attempt: 10 s.
const fn kinesis_default_operation_attempt_timeout_ms() -> u64 {
    10_000
}

/// Default initial retry backoff: 1 s.
const fn kinesis_default_init_backoff_ms() -> u64 {
    1_000
}

/// Default maximum retry backoff: 20 s.
const fn kinesis_default_max_backoff_ms() -> u64 {
    20_000
}

/// Default maximum number of attempts passed to the SDK retry config.
const fn kinesis_default_max_retry_limit() -> u32 {
    3
}
669
670#[serde_as]
671#[derive(Clone, Debug, Deserialize, WithOptions)]
672pub struct KinesisSdkOptions {
673    #[serde(
674        rename = "kinesis.sdk.connect_timeout_ms",
675        default = "kinesis_default_connect_timeout_ms"
676    )]
677    #[serde_as(as = "DisplayFromStr")]
678    pub sdk_connect_timeout_ms: u64,
679
680    #[serde(
681        rename = "kinesis.sdk.read_timeout_ms",
682        default = "kinesis_default_read_timeout_ms"
683    )]
684    #[serde_as(as = "DisplayFromStr")]
685    pub sdk_read_timeout_ms: u64,
686
687    #[serde(
688        rename = "kinesis.sdk.operation_timeout_ms",
689        default = "kinesis_default_operation_timeout_ms"
690    )]
691    #[serde_as(as = "DisplayFromStr")]
692    pub sdk_operation_timeout_ms: u64,
693
694    #[serde(
695        rename = "kinesis.sdk.operation_attempt_timeout_ms",
696        default = "kinesis_default_operation_attempt_timeout_ms"
697    )]
698    #[serde_as(as = "DisplayFromStr")]
699    pub sdk_operation_attempt_timeout_ms: u64,
700
701    #[serde(
702        rename = "kinesis.sdk.max_retry_limit",
703        default = "kinesis_default_max_retry_limit"
704    )]
705    #[serde_as(as = "DisplayFromStr")]
706    pub sdk_max_retry_limit: u32,
707
708    #[serde(
709        rename = "kinesis.sdk.init_backoff_ms",
710        default = "kinesis_default_init_backoff_ms"
711    )]
712    #[serde_as(as = "DisplayFromStr")]
713    pub sdk_init_backoff_ms: u64,
714
715    #[serde(
716        rename = "kinesis.sdk.max_backoff_ms",
717        default = "kinesis_default_max_backoff_ms"
718    )]
719    #[serde_as(as = "DisplayFromStr")]
720    pub sdk_max_backoff_ms: u64,
721}
722
723impl Default for KinesisSdkOptions {
724    fn default() -> Self {
725        Self {
726            sdk_connect_timeout_ms: kinesis_default_connect_timeout_ms(),
727            sdk_read_timeout_ms: kinesis_default_read_timeout_ms(),
728            sdk_operation_timeout_ms: kinesis_default_operation_timeout_ms(),
729            sdk_operation_attempt_timeout_ms: kinesis_default_operation_attempt_timeout_ms(),
730            sdk_max_retry_limit: kinesis_default_max_retry_limit(),
731            sdk_init_backoff_ms: kinesis_default_init_backoff_ms(),
732            sdk_max_backoff_ms: kinesis_default_max_backoff_ms(),
733        }
734    }
735}
736
737impl EnforceSecret for KinesisCommon {
738    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
739        "kinesis.credentials.access",
740        "kinesis.credentials.secret",
741        "kinesis.credentials.session_token",
742    };
743}
744
745impl KinesisCommon {
746    pub(crate) async fn build_client(&self) -> ConnectorResult<KinesisClient> {
747        let config = AwsAuthProps {
748            region: Some(self.stream_region.clone()),
749            endpoint: self.endpoint.clone(),
750            access_key: self.credentials_access_key.clone(),
751            secret_key: self.credentials_secret_access_key.clone(),
752            session_token: self.session_token.clone(),
753            arn: self.assume_role_arn.clone(),
754            external_id: self.assume_role_external_id.clone(),
755            profile: Default::default(),
756            msk_signer_timeout_sec: Default::default(),
757        };
758        let aws_config = config.build_config().await?;
759        let mut builder = aws_sdk_kinesis::config::Builder::from(&aws_config);
760        {
761            // for timeout and retry config
762            let sleep_impl = SharedAsyncSleep::new(KinesisAsyncSleepImpl);
763            builder.set_sleep_impl(Some(sleep_impl));
764            let timeout_config = aws_smithy_types::timeout::TimeoutConfig::builder()
765                .connect_timeout(Duration::from_millis(
766                    self.sdk_options.sdk_connect_timeout_ms,
767                ))
768                .read_timeout(Duration::from_millis(self.sdk_options.sdk_read_timeout_ms))
769                .operation_timeout(Duration::from_millis(
770                    self.sdk_options.sdk_operation_timeout_ms,
771                ))
772                .operation_attempt_timeout(Duration::from_millis(
773                    self.sdk_options.sdk_operation_attempt_timeout_ms,
774                ))
775                .build();
776            builder.set_timeout_config(Some(timeout_config));
777
778            let retry_config = aws_smithy_types::retry::RetryConfig::standard()
779                .with_initial_backoff(Duration::from_millis(self.sdk_options.sdk_init_backoff_ms))
780                .with_max_backoff(Duration::from_millis(self.sdk_options.sdk_max_backoff_ms))
781                .with_max_attempts(self.sdk_options.sdk_max_retry_limit);
782            builder.set_retry_config(Some(retry_config));
783        }
784        if let Some(endpoint) = &config.endpoint {
785            builder = builder.endpoint_url(endpoint);
786        }
787        Ok(KinesisClient::from_conf(builder.build()))
788    }
789}
790
// NATS connection options, plus limits used when this connector has to create a JetStream
// stream itself (see `build_or_get_stream`).
// NOTE(review): `///` field docs may be picked up by the `WithOptions` option-doc tooling, so
// plain `//` comments are used here to avoid changing generated option documentation.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct NatsCommon {
    // Comma-separated list of server addresses; each entry is parsed as a server address.
    #[serde(rename = "server_url")]
    pub server_url: String,
    // Comma-separated list of subjects. Used as the stream subjects and consumer filter
    // subjects, and (sanitized) as part of the consumer name.
    #[serde(rename = "subject")]
    pub subject: String,
    // One of `user_and_password`, `credential` or `plain`; any other value is rejected.
    #[serde(rename = "connect_mode")]
    pub connect_mode: String,
    // Required (together with `password`) when `connect_mode` is `user_and_password`.
    #[serde(rename = "username")]
    pub user: Option<String>,
    // Required (together with `username`) when `connect_mode` is `user_and_password`.
    #[serde(rename = "password")]
    pub password: Option<String>,
    // Required (together with `nkey`) when `connect_mode` is `credential`.
    #[serde(rename = "jwt")]
    pub jwt: Option<String>,
    // NKEY seed; required (together with `jwt`) when `connect_mode` is `credential`.
    #[serde(rename = "nkey")]
    pub nkey: Option<String>,
    // The following limits are parsed from their string form (`DisplayFromStr`) and are only
    // applied when a new stream is created; an existing stream is left unchanged.
    #[serde(rename = "max_bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_bytes: Option<i64>,
    #[serde(rename = "max_messages")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_messages: Option<i64>,
    #[serde(rename = "max_messages_per_subject")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_messages_per_subject: Option<i64>,
    #[serde(rename = "max_consumers")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_consumers: Option<i32>,
    #[serde(rename = "max_message_size")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_message_size: Option<i32>,
}
824
825impl EnforceSecret for NatsCommon {
826    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
827        "password",
828        "jwt",
829        "nkey",
830    };
831}
832
833impl NatsCommon {
834    pub(crate) async fn build_client(&self) -> ConnectorResult<async_nats::Client> {
835        let mut connect_options = async_nats::ConnectOptions::new();
836        match self.connect_mode.as_str() {
837            "user_and_password" => {
838                if let (Some(v_user), Some(v_password)) =
839                    (self.user.as_ref(), self.password.as_ref())
840                {
841                    connect_options =
842                        connect_options.user_and_password(v_user.into(), v_password.into())
843                } else {
844                    bail!("nats connect mode is user_and_password, but user or password is empty");
845                }
846            }
847
848            "credential" => {
849                if let (Some(v_nkey), Some(v_jwt)) = (self.nkey.as_ref(), self.jwt.as_ref()) {
850                    connect_options = connect_options
851                        .credentials(&self.create_credential(v_nkey, v_jwt)?)
852                        .expect("failed to parse static creds")
853                } else {
854                    bail!("nats connect mode is credential, but nkey or jwt is empty");
855                }
856            }
857            "plain" => {}
858            _ => {
859                bail!("nats connect mode only accepts user_and_password/credential/plain");
860            }
861        };
862
863        let servers = self.server_url.split(',').collect::<Vec<&str>>();
864        let client = connect_options
865            .connect(
866                servers
867                    .iter()
868                    .map(|url| url.parse())
869                    .collect::<Result<Vec<async_nats::ServerAddr>, _>>()?,
870            )
871            .await
872            .context("build nats client error")
873            .map_err(SinkError::Nats)?;
874        Ok(client)
875    }
876
877    pub(crate) async fn build_context(&self) -> ConnectorResult<jetstream::Context> {
878        let client = self.build_client().await?;
879        let jetstream = async_nats::jetstream::new(client);
880        Ok(jetstream)
881    }
882
883    pub(crate) async fn build_consumer(
884        &self,
885        stream: String,
886        durable_consumer_name: String,
887        split_id: String,
888        start_sequence: NatsOffset,
889        mut config: jetstream::consumer::pull::Config,
890    ) -> ConnectorResult<
891        async_nats::jetstream::consumer::Consumer<async_nats::jetstream::consumer::pull::Config>,
892    > {
893        let context = self.build_context().await?;
894        let stream = self.build_or_get_stream(context.clone(), stream).await?;
895        let subject_name = self
896            .subject
897            .replace(',', "-")
898            .replace(['.', '>', '*', ' ', '\t'], "_");
899        let name = format!("risingwave-consumer-{}-{}", subject_name, split_id);
900
901        let deliver_policy = match start_sequence {
902            NatsOffset::Earliest => DeliverPolicy::All,
903            NatsOffset::Latest => DeliverPolicy::New,
904            NatsOffset::SequenceNumber(v) => {
905                // for compatibility, we do not write to any state table now
906                let parsed = v
907                    .parse::<u64>()
908                    .context("failed to parse nats offset as sequence number")?;
909                DeliverPolicy::ByStartSequence {
910                    start_sequence: 1 + parsed,
911                }
912            }
913            NatsOffset::Timestamp(v) => DeliverPolicy::ByStartTime {
914                start_time: OffsetDateTime::from_unix_timestamp_nanos(v as i128 * 1_000_000)
915                    .context("invalid timestamp for nats offset")?,
916            },
917            NatsOffset::None => DeliverPolicy::All,
918        };
919
920        let consumer = match stream.get_consumer(&name).await {
921            Ok(consumer) => consumer,
922            _ => {
923                stream
924                    .get_or_create_consumer(&name, {
925                        config.deliver_policy = deliver_policy;
926                        config.durable_name = Some(durable_consumer_name);
927                        config.filter_subjects =
928                            self.subject.split(',').map(|s| s.to_owned()).collect();
929                        config
930                    })
931                    .await?
932            }
933        };
934        Ok(consumer)
935    }
936
937    pub(crate) async fn build_or_get_stream(
938        &self,
939        jetstream: jetstream::Context,
940        stream: String,
941    ) -> ConnectorResult<jetstream::stream::Stream> {
942        let subjects: Vec<String> = self.subject.split(',').map(|s| s.to_owned()).collect();
943        if let Ok(mut stream_instance) = jetstream.get_stream(&stream).await {
944            tracing::info!(
945                "load existing nats stream ({:?}) with config {:?}",
946                stream,
947                stream_instance.info().await?
948            );
949            return Ok(stream_instance);
950        }
951
952        let mut config = jetstream::stream::Config {
953            name: stream.clone(),
954            max_bytes: 1000000,
955            subjects,
956            ..Default::default()
957        };
958        if let Some(v) = self.max_bytes {
959            config.max_bytes = v;
960        }
961        if let Some(v) = self.max_messages {
962            config.max_messages = v;
963        }
964        if let Some(v) = self.max_messages_per_subject {
965            config.max_messages_per_subject = v;
966        }
967        if let Some(v) = self.max_consumers {
968            config.max_consumers = v;
969        }
970        if let Some(v) = self.max_message_size {
971            config.max_message_size = v;
972        }
973        tracing::info!(
974            "create nats stream ({:?}) with config {:?}",
975            &stream,
976            config
977        );
978        let stream = jetstream.get_or_create_stream(config).await?;
979        Ok(stream)
980    }
981
982    pub(crate) fn create_credential(&self, seed: &str, jwt: &str) -> ConnectorResult<String> {
983        let creds = format!(
984            "-----BEGIN NATS USER JWT-----\n{}\n------END NATS USER JWT------\n\n\
985                         ************************* IMPORTANT *************************\n\
986                         NKEY Seed printed below can be used to sign and prove identity.\n\
987                         NKEYs are sensitive and should be treated as secrets.\n\n\
988                         -----BEGIN USER NKEY SEED-----\n{}\n------END USER NKEY SEED------\n\n\
989                         *************************************************************",
990            jwt, seed
991        );
992        Ok(creds)
993    }
994}
995
996pub(crate) fn load_certs(
997    certificates: &str,
998) -> ConnectorResult<Vec<rustls_pki_types::CertificateDer<'static>>> {
999    let cert_bytes = if let Some(path) = certificates.strip_prefix("fs://") {
1000        std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())?
1001    } else {
1002        certificates.as_bytes().to_owned()
1003    };
1004
1005    rustls_pemfile::certs(&mut cert_bytes.as_slice())
1006        .map(|cert| Ok(cert?))
1007        .collect()
1008}
1009
1010pub(crate) fn load_private_key(
1011    certificate: &str,
1012) -> ConnectorResult<rustls_pki_types::PrivateKeyDer<'static>> {
1013    let cert_bytes = if let Some(path) = certificate.strip_prefix("fs://") {
1014        std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())?
1015    } else {
1016        certificate.as_bytes().to_owned()
1017    };
1018
1019    let cert = rustls_pemfile::pkcs8_private_keys(&mut cert_bytes.as_slice())
1020        .next()
1021        .ok_or_else(|| anyhow!("No private key found"))?;
1022    Ok(cert?.into())
1023}
1024
// MongoDB connection options used for both reading and writing collections.
// NOTE(review): `///` field docs may be consumed by the `WithOptions` option-doc tooling, so
// they are left untouched and extra notes use plain `//` comments.
// NOTE(review): no field uses a `serde_as` adapter; `#[serde_as]` looks unnecessary here.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct MongodbCommon {
    // Listed in `ENFORCE_SECRET_PROPERTIES`, presumably because the URL may embed credentials.
    /// The URL of MongoDB
    #[serde(rename = "mongodb.url")]
    pub connect_uri: String,
    /// The collection name where data should be written to or read from. For sinks, the format is
    /// `db_name.collection_name`. Data can also be written to dynamic collections, see `collection.name.field`
    /// for more information.
    #[serde(rename = "collection.name")]
    pub collection_name: String,
}
1037
1038impl EnforceSecret for MongodbCommon {
1039    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
1040        "mongodb.url"
1041    };
1042}
1043
1044impl MongodbCommon {
1045    pub(crate) async fn build_client(&self) -> ConnectorResult<mongodb::Client> {
1046        let client = mongodb::Client::with_uri_str(&self.connect_uri).await?;
1047
1048        Ok(client)
1049    }
1050}