risingwave_connector/connector_common/
common.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::hash::Hash;
17use std::io::Write;
18use std::time::Duration;
19
20use anyhow::{Context, anyhow};
21use async_nats::jetstream::consumer::DeliverPolicy;
22use async_nats::jetstream::{self};
23use aws_sdk_kinesis::Client as KinesisClient;
24use phf::{Set, phf_set};
25use pulsar::authentication::oauth2::{OAuth2Authentication, OAuth2Params};
26use pulsar::{Authentication, Pulsar, TokioExecutor};
27use rdkafka::ClientConfig;
28use risingwave_common::bail;
29use serde_derive::Deserialize;
30use serde_with::json::JsonString;
31use serde_with::{DisplayFromStr, serde_as};
32use tempfile::NamedTempFile;
33use time::OffsetDateTime;
34use url::Url;
35use with_options::WithOptions;
36
37use crate::aws_utils::load_file_descriptor_from_s3;
38use crate::deserialize_duration_from_string;
39use crate::enforce_secret::EnforceSecret;
40use crate::error::ConnectorResult;
41use crate::sink::SinkError;
42use crate::source::nats::source::NatsOffset;
43
/// Option key under which the broker endpoint rewrite map is stored
/// (the same key `KafkaPrivateLinkCommon::broker_rewrite_map` deserializes from).
pub const PRIVATE_LINK_BROKER_REWRITE_MAP_KEY: &str = "broker.rewrite.endpoints";
/// Option key for the private link targets.
pub const PRIVATE_LINK_TARGETS_KEY: &str = "privatelink.targets";

/// Value of `properties.sasl.mechanism` that selects AWS MSK IAM authentication
/// (checked by `KafkaConnectionProps::is_aws_msk_iam`).
const AWS_MSK_IAM_AUTH: &str = "AWS_MSK_IAM";

/// The environment variable to disable using default credential from environment.
///
/// When it evaluates to true and no explicit access key / secret key pair is configured,
/// `AwsAuthProps::build_credential_provider` returns an error instead of falling back to the
/// default AWS credential provider chain.
///
/// NOTE(review): this comment used to recommend setting the variable to `false` in cloud
/// hosting environments. Since the default chain is only skipped when the variable is true,
/// `true` looks like the intended recommendation — confirm.
const DISABLE_DEFAULT_CREDENTIAL: &str = "DISABLE_DEFAULT_CREDENTIAL";
52
/// A single AWS private link target, deserialized from user-provided configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct AwsPrivateLinkItem {
    // Presumably the availability-zone id of the target; optional — confirm against callers.
    pub az_id: Option<String>,
    // Port of the target.
    pub port: u16,
}
58
59use aws_config::default_provider::region::DefaultRegionChain;
60use aws_config::sts::AssumeRoleProvider;
61use aws_credential_types::provider::SharedCredentialsProvider;
62use aws_types::SdkConfig;
63use aws_types::region::Region;
64use risingwave_common::util::env_var::env_var_is_true;
65
/// A flattened config map for AWS authentication.
///
/// Every field is optional. Each option has a canonical `aws.*` key plus the aliases
/// listed in its `serde` attribute. The values are consumed by [`AwsAuthProps::build_config`].
#[derive(Deserialize, Debug, Clone, WithOptions, PartialEq)]
pub struct AwsAuthProps {
    #[serde(rename = "aws.region", alias = "region", alias = "s3.region")]
    pub region: Option<String>,

    #[serde(
        rename = "aws.endpoint_url",
        alias = "endpoint_url",
        alias = "endpoint",
        alias = "s3.endpoint"
    )]
    pub endpoint: Option<String>,
    #[serde(
        rename = "aws.credentials.access_key_id",
        alias = "access_key",
        alias = "s3.access.key"
    )]
    pub access_key: Option<String>,
    #[serde(
        rename = "aws.credentials.secret_access_key",
        alias = "secret_key",
        alias = "s3.secret.key"
    )]
    pub secret_key: Option<String>,
    #[serde(rename = "aws.credentials.session_token", alias = "session_token")]
    pub session_token: Option<String>,
    /// IAM role
    #[serde(rename = "aws.credentials.role.arn", alias = "arn")]
    pub arn: Option<String>,
    /// external ID in IAM role trust policy
    #[serde(rename = "aws.credentials.role.external_id", alias = "external_id")]
    pub external_id: Option<String>,
    #[serde(rename = "aws.profile", alias = "profile")]
    pub profile: Option<String>,
    // Not read anywhere in this part of the file; presumably a timeout (seconds) for the
    // MSK IAM token signer — confirm at the use site.
    #[serde(rename = "aws.msk.signer_timeout_sec")]
    pub msk_signer_timeout_sec: Option<u64>,
}
104
105impl EnforceSecret for AwsAuthProps {
106    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
107        "access_key",
108        "aws.credentials.access_key_id",
109        "s3.access.key",
110        "secret_key",
111        "aws.credentials.secret_access_key",
112        "s3.secret.key",
113        "session_token",
114        "aws.credentials.session_token",
115    };
116}
117
118impl AwsAuthProps {
119    async fn build_region(&self) -> ConnectorResult<Region> {
120        if let Some(region_name) = &self.region {
121            Ok(Region::new(region_name.clone()))
122        } else {
123            let mut region_chain = DefaultRegionChain::builder();
124            if let Some(profile_name) = &self.profile {
125                region_chain = region_chain.profile_name(profile_name);
126            }
127
128            Ok(region_chain
129                .build()
130                .region()
131                .await
132                .context("region should be provided")?)
133        }
134    }
135
136    async fn build_credential_provider(&self) -> ConnectorResult<SharedCredentialsProvider> {
137        if self.access_key.is_some() && self.secret_key.is_some() {
138            Ok(SharedCredentialsProvider::new(
139                aws_credential_types::Credentials::from_keys(
140                    self.access_key.as_ref().unwrap(),
141                    self.secret_key.as_ref().unwrap(),
142                    self.session_token.clone(),
143                ),
144            ))
145        } else if !env_var_is_true(DISABLE_DEFAULT_CREDENTIAL) {
146            Ok(SharedCredentialsProvider::new(
147                aws_config::default_provider::credentials::default_provider().await,
148            ))
149        } else {
150            bail!("Both \"access_key\" and \"secret_key\" are required.")
151        }
152    }
153
154    async fn with_role_provider(
155        &self,
156        credential: SharedCredentialsProvider,
157    ) -> ConnectorResult<SharedCredentialsProvider> {
158        if let Some(role_name) = &self.arn {
159            let region = self.build_region().await?;
160            let mut role = AssumeRoleProvider::builder(role_name)
161                .session_name("RisingWave")
162                .region(region);
163            if let Some(id) = &self.external_id {
164                role = role.external_id(id);
165            }
166            let provider = role.build_from_provider(credential).await;
167            Ok(SharedCredentialsProvider::new(provider))
168        } else {
169            Ok(credential)
170        }
171    }
172
173    pub async fn build_config(&self) -> ConnectorResult<SdkConfig> {
174        let region = self.build_region().await?;
175        let credentials_provider = self
176            .with_role_provider(self.build_credential_provider().await?)
177            .await?;
178        let mut config_loader = aws_config::from_env()
179            .region(region)
180            .credentials_provider(credentials_provider);
181
182        if let Some(endpoint) = self.endpoint.as_ref() {
183            config_loader = config_loader.endpoint_url(endpoint);
184        }
185
186        Ok(config_loader.load().await)
187    }
188}
189
/// Connection-level Kafka options: the broker list plus SSL and SASL settings.
///
/// Apart from `brokers`, every field is optional and private. They are forwarded to the
/// librdkafka client configuration by `KafkaConnectionProps::set_security_properties`.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
pub struct KafkaConnectionProps {
    #[serde(rename = "properties.bootstrap.server", alias = "kafka.brokers")]
    pub brokers: String,

    /// Security protocol used for RisingWave to communicate with Kafka brokers. Could be
    /// PLAINTEXT, SSL, SASL_PLAINTEXT or SASL_SSL.
    #[serde(rename = "properties.security.protocol")]
    security_protocol: Option<String>,

    #[serde(rename = "properties.ssl.endpoint.identification.algorithm")]
    ssl_endpoint_identification_algorithm: Option<String>,

    // For the properties below, please refer to [librdkafka](https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md) for more information.
    /// Path to CA certificate file for verifying the broker's key.
    #[serde(rename = "properties.ssl.ca.location")]
    ssl_ca_location: Option<String>,

    /// CA certificate string (PEM format) for verifying the broker's key.
    #[serde(rename = "properties.ssl.ca.pem")]
    ssl_ca_pem: Option<String>,

    /// Path to client's certificate file (PEM).
    #[serde(rename = "properties.ssl.certificate.location")]
    ssl_certificate_location: Option<String>,

    /// Client's public key string (PEM format) used for authentication.
    #[serde(rename = "properties.ssl.certificate.pem")]
    ssl_certificate_pem: Option<String>,

    /// Path to client's private key file (PEM).
    #[serde(rename = "properties.ssl.key.location")]
    ssl_key_location: Option<String>,

    /// Client's private key string (PEM format) used for authentication.
    #[serde(rename = "properties.ssl.key.pem")]
    ssl_key_pem: Option<String>,

    /// Passphrase of client's private key.
    #[serde(rename = "properties.ssl.key.password")]
    ssl_key_password: Option<String>,

    /// SASL mechanism if SASL is enabled. Currently support PLAIN, SCRAM, GSSAPI, and AWS_MSK_IAM.
    #[serde(rename = "properties.sasl.mechanism")]
    sasl_mechanism: Option<String>,

    /// SASL username for SASL/PLAIN and SASL/SCRAM.
    #[serde(rename = "properties.sasl.username")]
    sasl_username: Option<String>,

    /// SASL password for SASL/PLAIN and SASL/SCRAM.
    #[serde(rename = "properties.sasl.password")]
    sasl_password: Option<String>,

    /// Kafka server's Kerberos principal name under SASL/GSSAPI, not including /hostname@REALM.
    #[serde(rename = "properties.sasl.kerberos.service.name")]
    sasl_kerberos_service_name: Option<String>,

    /// Path to client's Kerberos keytab file under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.keytab")]
    sasl_kerberos_keytab: Option<String>,

    /// Client's Kerberos principal name under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.principal")]
    sasl_kerberos_principal: Option<String>,

    /// Shell command to refresh or acquire the client's Kerberos ticket under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.kinit.cmd")]
    sasl_kerberos_kinit_cmd: Option<String>,

    /// Minimum time in milliseconds between key refresh attempts under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.min.time.before.relogin")]
    sasl_kerberos_min_time_before_relogin: Option<String>,

    /// Configurations for SASL/OAUTHBEARER.
    // NOTE(review): the field name is spelled `oathbearer` (missing `u`); the option key
    // itself is spelled correctly. Renaming must be done together with the impl block.
    #[serde(rename = "properties.sasl.oauthbearer.config")]
    sasl_oathbearer_config: Option<String>,
}
269
270impl EnforceSecret for KafkaConnectionProps {
271    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
272        "properties.ssl.key.pem",
273        "properties.ssl.key.password",
274        "properties.sasl.password",
275    };
276}
277
/// Kafka options that are not connection related: the topic and the timeout applied to
/// synchronous calls.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct KafkaCommon {
    // connection related props are moved to `KafkaConnection`
    #[serde(rename = "topic", alias = "kafka.topic")]
    pub topic: String,

    // Parsed from a string by `deserialize_duration_from_string`; when the option is absent
    // `default_kafka_sync_call_timeout` supplies 5 seconds.
    #[serde(
        rename = "properties.sync.call.timeout",
        deserialize_with = "deserialize_duration_from_string",
        default = "default_kafka_sync_call_timeout"
    )]
    pub sync_call_timeout: Duration,
}
292
/// Private-link related Kafka options.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
pub struct KafkaPrivateLinkCommon {
    /// This is generated from `private_link_targets` and `private_link_endpoint` in frontend, instead of given by users.
    // The option value is a JSON string that is decoded into the map (`JsonString`).
    #[serde(rename = "broker.rewrite.endpoints")]
    #[serde_as(as = "Option<JsonString>")]
    pub broker_rewrite_map: Option<BTreeMap<String, String>>,
}
301
/// Serde default for `properties.sync.call.timeout`: five seconds.
const fn default_kafka_sync_call_timeout() -> Duration {
    const DEFAULT_TIMEOUT_MILLIS: u64 = 5_000;
    Duration::from_millis(DEFAULT_TIMEOUT_MILLIS)
}
305
/// Serde default for `properties.socket.keepalive.enable`: keepalive is on unless the
/// option says otherwise.
const fn default_socket_keepalive_enable() -> bool {
    const KEEPALIVE_ENABLED_BY_DEFAULT: bool = true;
    KEEPALIVE_ENABLED_BY_DEFAULT
}
309
/// librdkafka client options shared by Kafka clients.
///
/// All values arrive as strings and are parsed with `DisplayFromStr`. They are written to
/// the client configuration by `RdKafkaPropertiesCommon::set_client`.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct RdKafkaPropertiesCommon {
    /// Maximum Kafka protocol request message size. Due to differing framing overhead between
    /// protocol versions the producer is unable to reliably enforce a strict max message limit at
    /// produce time and may exceed the maximum size by one message in protocol ProduceRequests,
    /// the broker will enforce the topic's max.message.bytes limit
    #[serde(rename = "properties.message.max.bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub message_max_bytes: Option<usize>,

    /// Maximum Kafka protocol response message size. This serves as a safety precaution to avoid
    /// memory exhaustion in case of protocol hickups. This value must be at least fetch.max.bytes
    /// + 512 to allow for protocol overhead; the value is adjusted automatically unless the
    /// configuration property is explicitly set.
    #[serde(rename = "properties.receive.message.max.bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub receive_message_max_bytes: Option<usize>,

    #[serde(rename = "properties.statistics.interval.ms")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub statistics_interval_ms: Option<usize>,

    /// Client identifier
    #[serde(rename = "properties.client.id")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub client_id: Option<String>,

    #[serde(rename = "properties.enable.ssl.certificate.verification")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub enable_ssl_certificate_verification: Option<bool>,

    // Unlike the options above this one is never absent: it defaults to `true` via
    // `default_socket_keepalive_enable` and is always written by `set_client`.
    #[serde(
        rename = "properties.socket.keepalive.enable",
        default = "default_socket_keepalive_enable"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub socket_keepalive_enable: bool,
}
349
350impl RdKafkaPropertiesCommon {
351    pub(crate) fn set_client(&self, c: &mut rdkafka::ClientConfig) {
352        if let Some(v) = self.statistics_interval_ms {
353            c.set("statistics.interval.ms", v.to_string());
354        }
355        if let Some(v) = self.message_max_bytes {
356            c.set("message.max.bytes", v.to_string());
357        }
358        if let Some(v) = self.receive_message_max_bytes {
359            c.set("receive.message.max.bytes", v.to_string());
360        }
361        if let Some(v) = self.client_id.as_ref() {
362            c.set("client.id", v);
363        }
364        if let Some(v) = self.enable_ssl_certificate_verification {
365            c.set("enable.ssl.certificate.verification", v.to_string());
366        }
367        c.set(
368            "socket.keepalive.enable",
369            self.socket_keepalive_enable.to_string(),
370        );
371    }
372}
373
374impl KafkaConnectionProps {
375    #[cfg(test)]
376    pub fn test_default() -> Self {
377        Self {
378            brokers: "localhost:9092".to_owned(),
379            security_protocol: None,
380            ssl_ca_location: None,
381            ssl_certificate_location: None,
382            ssl_key_location: None,
383            ssl_ca_pem: None,
384            ssl_certificate_pem: None,
385            ssl_key_pem: None,
386            ssl_key_password: None,
387            ssl_endpoint_identification_algorithm: None,
388            sasl_mechanism: None,
389            sasl_username: None,
390            sasl_password: None,
391            sasl_kerberos_service_name: None,
392            sasl_kerberos_keytab: None,
393            sasl_kerberos_principal: None,
394            sasl_kerberos_kinit_cmd: None,
395            sasl_kerberos_min_time_before_relogin: None,
396            sasl_oathbearer_config: None,
397        }
398    }
399
400    pub(crate) fn set_security_properties(&self, config: &mut ClientConfig) {
401        // AWS_MSK_IAM
402        if self.is_aws_msk_iam() {
403            config.set("security.protocol", "SASL_SSL");
404            config.set("sasl.mechanism", "OAUTHBEARER");
405            return;
406        }
407
408        // Security protocol
409        if let Some(security_protocol) = self.security_protocol.as_ref() {
410            config.set("security.protocol", security_protocol);
411        }
412
413        // SSL
414        if let Some(ssl_ca_location) = self.ssl_ca_location.as_ref() {
415            config.set("ssl.ca.location", ssl_ca_location);
416        }
417        if let Some(ssl_ca_pem) = self.ssl_ca_pem.as_ref() {
418            config.set("ssl.ca.pem", ssl_ca_pem);
419        }
420        if let Some(ssl_certificate_location) = self.ssl_certificate_location.as_ref() {
421            config.set("ssl.certificate.location", ssl_certificate_location);
422        }
423        if let Some(ssl_certificate_pem) = self.ssl_certificate_pem.as_ref() {
424            config.set("ssl.certificate.pem", ssl_certificate_pem);
425        }
426        if let Some(ssl_key_location) = self.ssl_key_location.as_ref() {
427            config.set("ssl.key.location", ssl_key_location);
428        }
429        if let Some(ssl_key_pem) = self.ssl_key_pem.as_ref() {
430            config.set("ssl.key.pem", ssl_key_pem);
431        }
432        if let Some(ssl_key_password) = self.ssl_key_password.as_ref() {
433            config.set("ssl.key.password", ssl_key_password);
434        }
435        if let Some(ssl_endpoint_identification_algorithm) =
436            self.ssl_endpoint_identification_algorithm.as_ref()
437        {
438            // accept only `none` and `http` here, let the sdk do the check
439            config.set(
440                "ssl.endpoint.identification.algorithm",
441                ssl_endpoint_identification_algorithm,
442            );
443        }
444
445        // SASL mechanism
446        if let Some(sasl_mechanism) = self.sasl_mechanism.as_ref() {
447            config.set("sasl.mechanism", sasl_mechanism);
448        }
449
450        // SASL/PLAIN & SASL/SCRAM
451        if let Some(sasl_username) = self.sasl_username.as_ref() {
452            config.set("sasl.username", sasl_username);
453        }
454        if let Some(sasl_password) = self.sasl_password.as_ref() {
455            config.set("sasl.password", sasl_password);
456        }
457
458        // SASL/GSSAPI
459        if let Some(sasl_kerberos_service_name) = self.sasl_kerberos_service_name.as_ref() {
460            config.set("sasl.kerberos.service.name", sasl_kerberos_service_name);
461        }
462        if let Some(sasl_kerberos_keytab) = self.sasl_kerberos_keytab.as_ref() {
463            config.set("sasl.kerberos.keytab", sasl_kerberos_keytab);
464        }
465        if let Some(sasl_kerberos_principal) = self.sasl_kerberos_principal.as_ref() {
466            config.set("sasl.kerberos.principal", sasl_kerberos_principal);
467        }
468        if let Some(sasl_kerberos_kinit_cmd) = self.sasl_kerberos_kinit_cmd.as_ref() {
469            config.set("sasl.kerberos.kinit.cmd", sasl_kerberos_kinit_cmd);
470        }
471        if let Some(sasl_kerberos_min_time_before_relogin) =
472            self.sasl_kerberos_min_time_before_relogin.as_ref()
473        {
474            config.set(
475                "sasl.kerberos.min.time.before.relogin",
476                sasl_kerberos_min_time_before_relogin,
477            );
478        }
479
480        // SASL/OAUTHBEARER
481        if let Some(sasl_oathbearer_config) = self.sasl_oathbearer_config.as_ref() {
482            config.set("sasl.oauthbearer.config", sasl_oathbearer_config);
483        }
484        // Currently, we only support unsecured OAUTH.
485        config.set("enable.sasl.oauthbearer.unsecure.jwt", "true");
486    }
487
488    pub(crate) fn is_aws_msk_iam(&self) -> bool {
489        if let Some(sasl_mechanism) = self.sasl_mechanism.as_ref()
490            && sasl_mechanism == AWS_MSK_IAM_AUTH
491        {
492            true
493        } else {
494            false
495        }
496    }
497}
498
/// Options shared by Pulsar clients: topic, service URL and an optional auth token.
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct PulsarCommon {
    #[serde(rename = "topic", alias = "pulsar.topic")]
    pub topic: String,

    #[serde(rename = "service.url", alias = "pulsar.service.url")]
    pub service_url: String,

    // Used for token authentication by `build_client`, but only when no OAuth
    // configuration is supplied.
    #[serde(rename = "auth.token")]
    pub auth_token: Option<String>,
}
510
511impl EnforceSecret for PulsarCommon {
512    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
513        "pulsar.auth.token",
514    };
515}
516
/// OAuth2 client-credentials options for Pulsar, consumed by `PulsarCommon::build_client`.
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct PulsarOauthCommon {
    #[serde(rename = "oauth.issuer.url")]
    pub issuer_url: String,

    // Must be a `file` or `s3` URL; `build_client` rejects any other scheme. An `s3`
    // URL is downloaded into a local temp file first.
    #[serde(rename = "oauth.credentials.url")]
    pub credentials_url: String,

    #[serde(rename = "oauth.audience")]
    pub audience: String,

    #[serde(rename = "oauth.scope")]
    pub scope: Option<String>,
}
531
532fn create_credential_temp_file(credentials: &[u8]) -> std::io::Result<NamedTempFile> {
533    let mut f = NamedTempFile::new()?;
534    f.write_all(credentials)?;
535    f.as_file().sync_all()?;
536    Ok(f)
537}
538
539impl PulsarCommon {
540    pub(crate) async fn build_client(
541        &self,
542        oauth: &Option<PulsarOauthCommon>,
543        aws_auth_props: &AwsAuthProps,
544    ) -> ConnectorResult<Pulsar<TokioExecutor>> {
545        let mut pulsar_builder = Pulsar::builder(&self.service_url, TokioExecutor);
546        let mut temp_file = None;
547        if let Some(oauth) = oauth.as_ref() {
548            let url = Url::parse(&oauth.credentials_url)?;
549            match url.scheme() {
550                "s3" => {
551                    let credentials = load_file_descriptor_from_s3(&url, aws_auth_props).await?;
552                    temp_file = Some(
553                        create_credential_temp_file(&credentials)
554                            .context("failed to create temp file for pulsar credentials")?,
555                    );
556                }
557                "file" => {}
558                _ => {
559                    bail!("invalid credentials_url, only file url and s3 url are supported",);
560                }
561            }
562
563            let auth_params = OAuth2Params {
564                issuer_url: oauth.issuer_url.clone(),
565                credentials_url: if temp_file.is_none() {
566                    oauth.credentials_url.clone()
567                } else {
568                    let mut raw_path = temp_file
569                        .as_ref()
570                        .unwrap()
571                        .path()
572                        .to_str()
573                        .unwrap()
574                        .to_owned();
575                    raw_path.insert_str(0, "file://");
576                    raw_path
577                },
578                audience: Some(oauth.audience.clone()),
579                scope: oauth.scope.clone(),
580            };
581
582            pulsar_builder = pulsar_builder
583                .with_auth_provider(OAuth2Authentication::client_credentials(auth_params));
584        } else if let Some(auth_token) = &self.auth_token {
585            pulsar_builder = pulsar_builder.with_auth(Authentication {
586                name: "token".to_owned(),
587                data: Vec::from(auth_token.as_str()),
588            });
589        }
590
591        let res = pulsar_builder.build().await.map_err(|e| anyhow!(e))?;
592        drop(temp_file);
593        Ok(res)
594    }
595}
596
/// Options shared by Kinesis clients: the stream, its region and endpoint, and AWS
/// credentials. `KinesisCommon::build_client` copies the credential and region fields into
/// an `AwsAuthProps` to build the SDK configuration.
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct KinesisCommon {
    #[serde(rename = "stream", alias = "kinesis.stream.name")]
    pub stream_name: String,
    // Unlike `AwsAuthProps::region`, the region is mandatory here.
    #[serde(rename = "aws.region", alias = "kinesis.stream.region")]
    pub stream_region: String,
    #[serde(rename = "endpoint", alias = "kinesis.endpoint")]
    pub endpoint: Option<String>,
    #[serde(
        rename = "aws.credentials.access_key_id",
        alias = "kinesis.credentials.access"
    )]
    pub credentials_access_key: Option<String>,
    #[serde(
        rename = "aws.credentials.secret_access_key",
        alias = "kinesis.credentials.secret"
    )]
    pub credentials_secret_access_key: Option<String>,
    #[serde(
        rename = "aws.credentials.session_token",
        alias = "kinesis.credentials.session_token"
    )]
    pub session_token: Option<String>,
    #[serde(rename = "aws.credentials.role.arn", alias = "kinesis.assumerole.arn")]
    pub assume_role_arn: Option<String>,
    #[serde(
        rename = "aws.credentials.role.external_id",
        alias = "kinesis.assumerole.external_id"
    )]
    pub assume_role_external_id: Option<String>,
}
628
629impl EnforceSecret for KinesisCommon {
630    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
631        "kinesis.credentials.access",
632        "kinesis.credentials.secret",
633        "kinesis.credentials.session_token",
634    };
635}
636
637impl KinesisCommon {
638    pub(crate) async fn build_client(&self) -> ConnectorResult<KinesisClient> {
639        let config = AwsAuthProps {
640            region: Some(self.stream_region.clone()),
641            endpoint: self.endpoint.clone(),
642            access_key: self.credentials_access_key.clone(),
643            secret_key: self.credentials_secret_access_key.clone(),
644            session_token: self.session_token.clone(),
645            arn: self.assume_role_arn.clone(),
646            external_id: self.assume_role_external_id.clone(),
647            profile: Default::default(),
648            msk_signer_timeout_sec: Default::default(),
649        };
650        let aws_config = config.build_config().await?;
651        let mut builder = aws_sdk_kinesis::config::Builder::from(&aws_config);
652        if let Some(endpoint) = &config.endpoint {
653            builder = builder.endpoint_url(endpoint);
654        }
655        Ok(KinesisClient::from_conf(builder.build()))
656    }
657}
658
/// Options shared by NATS clients.
///
/// `connect_mode` must be one of `user_and_password`, `credential` or `plain` (checked in
/// `build_client`). `server_url` and `subject` may hold several comma-separated entries.
/// The `max_*` limits are applied only when `build_or_get_stream` has to create the
/// stream.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct NatsCommon {
    #[serde(rename = "server_url")]
    pub server_url: String,
    #[serde(rename = "subject")]
    pub subject: String,
    #[serde(rename = "connect_mode")]
    pub connect_mode: String,
    // Required together with `password` in `user_and_password` mode.
    #[serde(rename = "username")]
    pub user: Option<String>,
    #[serde(rename = "password")]
    pub password: Option<String>,
    // Required together with `nkey` in `credential` mode.
    #[serde(rename = "jwt")]
    pub jwt: Option<String>,
    #[serde(rename = "nkey")]
    pub nkey: Option<String>,
    #[serde(rename = "max_bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_bytes: Option<i64>,
    #[serde(rename = "max_messages")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_messages: Option<i64>,
    #[serde(rename = "max_messages_per_subject")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_messages_per_subject: Option<i64>,
    #[serde(rename = "max_consumers")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_consumers: Option<i32>,
    #[serde(rename = "max_message_size")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_message_size: Option<i32>,
}
692
693impl EnforceSecret for NatsCommon {
694    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
695        "password",
696        "jwt",
697        "nkey",
698    };
699}
700
701impl NatsCommon {
702    pub(crate) async fn build_client(&self) -> ConnectorResult<async_nats::Client> {
703        let mut connect_options = async_nats::ConnectOptions::new();
704        match self.connect_mode.as_str() {
705            "user_and_password" => {
706                if let (Some(v_user), Some(v_password)) =
707                    (self.user.as_ref(), self.password.as_ref())
708                {
709                    connect_options =
710                        connect_options.user_and_password(v_user.into(), v_password.into())
711                } else {
712                    bail!("nats connect mode is user_and_password, but user or password is empty");
713                }
714            }
715
716            "credential" => {
717                if let (Some(v_nkey), Some(v_jwt)) = (self.nkey.as_ref(), self.jwt.as_ref()) {
718                    connect_options = connect_options
719                        .credentials(&self.create_credential(v_nkey, v_jwt)?)
720                        .expect("failed to parse static creds")
721                } else {
722                    bail!("nats connect mode is credential, but nkey or jwt is empty");
723                }
724            }
725            "plain" => {}
726            _ => {
727                bail!("nats connect mode only accepts user_and_password/credential/plain");
728            }
729        };
730
731        let servers = self.server_url.split(',').collect::<Vec<&str>>();
732        let client = connect_options
733            .connect(
734                servers
735                    .iter()
736                    .map(|url| url.parse())
737                    .collect::<Result<Vec<async_nats::ServerAddr>, _>>()?,
738            )
739            .await
740            .context("build nats client error")
741            .map_err(SinkError::Nats)?;
742        Ok(client)
743    }
744
745    pub(crate) async fn build_context(&self) -> ConnectorResult<jetstream::Context> {
746        let client = self.build_client().await?;
747        let jetstream = async_nats::jetstream::new(client);
748        Ok(jetstream)
749    }
750
751    pub(crate) async fn build_consumer(
752        &self,
753        stream: String,
754        durable_consumer_name: String,
755        split_id: String,
756        start_sequence: NatsOffset,
757        mut config: jetstream::consumer::pull::Config,
758    ) -> ConnectorResult<
759        async_nats::jetstream::consumer::Consumer<async_nats::jetstream::consumer::pull::Config>,
760    > {
761        let context = self.build_context().await?;
762        let stream = self.build_or_get_stream(context.clone(), stream).await?;
763        let subject_name = self
764            .subject
765            .replace(',', "-")
766            .replace(['.', '>', '*', ' ', '\t'], "_");
767        let name = format!("risingwave-consumer-{}-{}", subject_name, split_id);
768
769        let deliver_policy = match start_sequence {
770            NatsOffset::Earliest => DeliverPolicy::All,
771            NatsOffset::Latest => DeliverPolicy::New,
772            NatsOffset::SequenceNumber(v) => {
773                // for compatibility, we do not write to any state table now
774                let parsed = v
775                    .parse::<u64>()
776                    .context("failed to parse nats offset as sequence number")?;
777                DeliverPolicy::ByStartSequence {
778                    start_sequence: 1 + parsed,
779                }
780            }
781            NatsOffset::Timestamp(v) => DeliverPolicy::ByStartTime {
782                start_time: OffsetDateTime::from_unix_timestamp_nanos(v as i128 * 1_000_000)
783                    .context("invalid timestamp for nats offset")?,
784            },
785            NatsOffset::None => DeliverPolicy::All,
786        };
787
788        let consumer = match stream.get_consumer(&name).await {
789            Ok(consumer) => consumer,
790            _ => {
791                stream
792                    .get_or_create_consumer(&name, {
793                        config.deliver_policy = deliver_policy;
794                        config.durable_name = Some(durable_consumer_name);
795                        config.filter_subjects =
796                            self.subject.split(',').map(|s| s.to_owned()).collect();
797                        config
798                    })
799                    .await?
800            }
801        };
802        Ok(consumer)
803    }
804
805    pub(crate) async fn build_or_get_stream(
806        &self,
807        jetstream: jetstream::Context,
808        stream: String,
809    ) -> ConnectorResult<jetstream::stream::Stream> {
810        let subjects: Vec<String> = self.subject.split(',').map(|s| s.to_owned()).collect();
811        if let Ok(mut stream_instance) = jetstream.get_stream(&stream).await {
812            tracing::info!(
813                "load existing nats stream ({:?}) with config {:?}",
814                stream,
815                stream_instance.info().await?
816            );
817            return Ok(stream_instance);
818        }
819
820        let mut config = jetstream::stream::Config {
821            name: stream.clone(),
822            max_bytes: 1000000,
823            subjects,
824            ..Default::default()
825        };
826        if let Some(v) = self.max_bytes {
827            config.max_bytes = v;
828        }
829        if let Some(v) = self.max_messages {
830            config.max_messages = v;
831        }
832        if let Some(v) = self.max_messages_per_subject {
833            config.max_messages_per_subject = v;
834        }
835        if let Some(v) = self.max_consumers {
836            config.max_consumers = v;
837        }
838        if let Some(v) = self.max_message_size {
839            config.max_message_size = v;
840        }
841        tracing::info!(
842            "create nats stream ({:?}) with config {:?}",
843            &stream,
844            config
845        );
846        let stream = jetstream.get_or_create_stream(config).await?;
847        Ok(stream)
848    }
849
850    pub(crate) fn create_credential(&self, seed: &str, jwt: &str) -> ConnectorResult<String> {
851        let creds = format!(
852            "-----BEGIN NATS USER JWT-----\n{}\n------END NATS USER JWT------\n\n\
853                         ************************* IMPORTANT *************************\n\
854                         NKEY Seed printed below can be used to sign and prove identity.\n\
855                         NKEYs are sensitive and should be treated as secrets.\n\n\
856                         -----BEGIN USER NKEY SEED-----\n{}\n------END USER NKEY SEED------\n\n\
857                         *************************************************************",
858            jwt, seed
859        );
860        Ok(creds)
861    }
862}
863
864pub(crate) fn load_certs(
865    certificates: &str,
866) -> ConnectorResult<Vec<rustls_pki_types::CertificateDer<'static>>> {
867    let cert_bytes = if let Some(path) = certificates.strip_prefix("fs://") {
868        std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())?
869    } else {
870        certificates.as_bytes().to_owned()
871    };
872
873    rustls_pemfile::certs(&mut cert_bytes.as_slice())
874        .map(|cert| Ok(cert?))
875        .collect()
876}
877
878pub(crate) fn load_private_key(
879    certificate: &str,
880) -> ConnectorResult<rustls_pki_types::PrivateKeyDer<'static>> {
881    let cert_bytes = if let Some(path) = certificate.strip_prefix("fs://") {
882        std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())?
883    } else {
884        certificate.as_bytes().to_owned()
885    };
886
887    let cert = rustls_pemfile::pkcs8_private_keys(&mut cert_bytes.as_slice())
888        .next()
889        .ok_or_else(|| anyhow!("No private key found"))?;
890    Ok(cert?.into())
891}
892
/// MongoDB connection options, deserialized from the `mongodb.url` and `collection.name`
/// properties.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct MongodbCommon {
    /// The URL of MongoDB
    #[serde(rename = "mongodb.url")]
    pub connect_uri: String,
    /// The collection name where data should be written to or read from. For sinks, the format is
    /// `db_name.collection_name`. Data can also be written to dynamic collections, see `collection.name.field`
    /// for more information.
    #[serde(rename = "collection.name")]
    pub collection_name: String,
}
905
// Lists `mongodb.url` as the only secret-enforced property of `MongodbCommon`.
// NOTE(review): presumably this forces the URL to be supplied as a secret because a connection
// string can embed credentials; confirm against the `EnforceSecret` trait definition.
impl EnforceSecret for MongodbCommon {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "mongodb.url"
    };
}
911
912impl MongodbCommon {
913    pub(crate) async fn build_client(&self) -> ConnectorResult<mongodb::Client> {
914        let client = mongodb::Client::with_uri_str(&self.connect_uri).await?;
915
916        Ok(client)
917    }
918}