risingwave_connector/connector_common/
common.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::hash::Hash;
17use std::io::Write;
18use std::path::Path;
19use std::sync::{Arc, LazyLock, Weak};
20use std::time::Duration;
21
22use anyhow::{Context, anyhow};
23use async_nats::jetstream::consumer::DeliverPolicy;
24use async_nats::jetstream::{self};
25use aws_sdk_kinesis::Client as KinesisClient;
26use aws_sdk_kinesis::config::{AsyncSleep, SharedAsyncSleep, Sleep};
27use moka::future::Cache as MokaCache;
28use moka::ops::compute::Op;
29use phf::{Set, phf_set};
30use pulsar::authentication::oauth2::{OAuth2Authentication, OAuth2Params};
31use pulsar::{Authentication, Pulsar, TokioExecutor};
32use rdkafka::ClientConfig;
33use risingwave_common::bail;
34use serde::Deserialize;
35use serde_with::json::JsonString;
36use serde_with::{DisplayFromStr, serde_as};
37use tempfile::NamedTempFile;
38use time::OffsetDateTime;
39use url::Url;
40use with_options::WithOptions;
41
42use crate::aws_utils::load_file_descriptor_from_s3;
43use crate::deserialize_duration_from_string;
44use crate::enforce_secret::EnforceSecret;
45use crate::error::ConnectorResult;
46use crate::sink::SinkError;
47use crate::source::nats::source::NatsOffset;
48
/// Option key under which the broker endpoint rewrite map is passed
/// (see `KafkaPrivateLinkCommon::broker_rewrite_map`).
pub const PRIVATE_LINK_BROKER_REWRITE_MAP_KEY: &str = "broker.rewrite.endpoints";
/// Option key for the PrivateLink targets.
// NOTE(review): presumably consumed by the frontend to generate the rewrite map — confirm.
pub const PRIVATE_LINK_TARGETS_KEY: &str = "privatelink.targets";

/// Value of `properties.sasl.mechanism` that selects AWS MSK IAM authentication.
const AWS_MSK_IAM_AUTH: &str = "AWS_MSK_IAM";

/// Name of the environment variable that disables the use of default credentials from the
/// environment.
/// It's recommended to set this variable to `true` in cloud hosting environments.
pub const DISABLE_DEFAULT_CREDENTIAL: &str = "DISABLE_DEFAULT_CREDENTIAL";
57
/// A single PrivateLink target: a port plus an optional availability-zone id.
// NOTE(review): presumably one element of the `privatelink.targets` option — confirm against the caller.
#[derive(Debug, Clone, Deserialize)]
pub struct AwsPrivateLinkItem {
    /// Availability-zone id of the target, when given.
    pub az_id: Option<String>,
    /// Port of the target.
    pub port: u16,
}
63
64use aws_config::default_provider::region::DefaultRegionChain;
65use aws_config::sts::AssumeRoleProvider;
66use aws_credential_types::provider::SharedCredentialsProvider;
67use aws_types::SdkConfig;
68use aws_types::region::Region;
69use risingwave_common::util::env_var::env_var_is_true;
70
/// A flattened config map for AWS authentication.
///
/// Every field is optional and is read from its canonical `aws.*` key or one of the listed
/// aliases. The struct carries the region, a custom endpoint, static credentials
/// (`access_key` + `secret_key`, optionally with `session_token`), an assumable IAM role
/// (`arn` + `external_id`) and a named profile.
#[derive(Deserialize, Debug, Clone, WithOptions, PartialEq)]
pub struct AwsAuthProps {
    #[serde(rename = "aws.region", alias = "region", alias = "s3.region")]
    pub region: Option<String>,

    #[serde(
        rename = "aws.endpoint_url",
        alias = "endpoint_url",
        alias = "endpoint",
        alias = "s3.endpoint"
    )]
    pub endpoint: Option<String>,
    #[serde(
        rename = "aws.credentials.access_key_id",
        alias = "access_key",
        alias = "s3.access.key"
    )]
    pub access_key: Option<String>,
    #[serde(
        rename = "aws.credentials.secret_access_key",
        alias = "secret_key",
        alias = "s3.secret.key"
    )]
    pub secret_key: Option<String>,
    #[serde(rename = "aws.credentials.session_token", alias = "session_token")]
    pub session_token: Option<String>,
    /// IAM role
    #[serde(rename = "aws.credentials.role.arn", alias = "arn")]
    pub arn: Option<String>,
    /// external ID in IAM role trust policy
    #[serde(rename = "aws.credentials.role.external_id", alias = "external_id")]
    pub external_id: Option<String>,
    #[serde(rename = "aws.profile", alias = "profile")]
    pub profile: Option<String>,
    // NOTE(review): not read anywhere in this section; presumably the timeout (in seconds)
    // of the MSK IAM token signer — confirm at the use site.
    #[serde(rename = "aws.msk.signer_timeout_sec")]
    pub msk_signer_timeout_sec: Option<u64>,
}
109
110impl EnforceSecret for AwsAuthProps {
111    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
112        "access_key",
113        "aws.credentials.access_key_id",
114        "s3.access.key",
115        "secret_key",
116        "aws.credentials.secret_access_key",
117        "s3.secret.key",
118        "session_token",
119        "aws.credentials.session_token",
120    };
121}
122
123impl AwsAuthProps {
124    async fn build_region(&self) -> ConnectorResult<Region> {
125        if let Some(region_name) = &self.region {
126            Ok(Region::new(region_name.clone()))
127        } else {
128            let mut region_chain = DefaultRegionChain::builder();
129            if let Some(profile_name) = &self.profile {
130                region_chain = region_chain.profile_name(profile_name);
131            }
132
133            Ok(region_chain
134                .build()
135                .region()
136                .await
137                .context("region should be provided")?)
138        }
139    }
140
141    async fn build_credential_provider(&self) -> ConnectorResult<SharedCredentialsProvider> {
142        if self.access_key.is_some() && self.secret_key.is_some() {
143            Ok(SharedCredentialsProvider::new(
144                aws_credential_types::Credentials::from_keys(
145                    self.access_key.as_ref().unwrap(),
146                    self.secret_key.as_ref().unwrap(),
147                    self.session_token.clone(),
148                ),
149            ))
150        } else if !env_var_is_true(DISABLE_DEFAULT_CREDENTIAL) {
151            Ok(SharedCredentialsProvider::new(
152                aws_config::default_provider::credentials::default_provider().await,
153            ))
154        } else {
155            bail!("Both \"access_key\" and \"secret_key\" are required.")
156        }
157    }
158
159    async fn with_role_provider(
160        &self,
161        credential: SharedCredentialsProvider,
162    ) -> ConnectorResult<SharedCredentialsProvider> {
163        if let Some(role_name) = &self.arn {
164            let region = self.build_region().await?;
165            let mut role = AssumeRoleProvider::builder(role_name)
166                .session_name("RisingWave")
167                .region(region);
168            if let Some(id) = &self.external_id {
169                role = role.external_id(id);
170            }
171            let provider = role.build_from_provider(credential).await;
172            Ok(SharedCredentialsProvider::new(provider))
173        } else {
174            Ok(credential)
175        }
176    }
177
178    pub async fn build_config(&self) -> ConnectorResult<SdkConfig> {
179        let region = self.build_region().await?;
180        let credentials_provider = self
181            .with_role_provider(self.build_credential_provider().await?)
182            .await?;
183        let mut config_loader = aws_config::from_env()
184            .region(region)
185            .credentials_provider(credentials_provider);
186
187        if let Some(endpoint) = self.endpoint.as_ref() {
188            config_loader = config_loader.endpoint_url(endpoint);
189        }
190
191        Ok(config_loader.load().await)
192    }
193}
194
/// Connection-level Kafka options: the bootstrap brokers plus the security (SSL / SASL)
/// settings. Apart from `brokers`, every field is optional; the set ones are forwarded to the
/// rdkafka client configuration by `set_security_properties`.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
pub struct KafkaConnectionProps {
    #[serde(rename = "properties.bootstrap.server", alias = "kafka.brokers")]
    pub brokers: String,

    /// Security protocol used for RisingWave to communicate with Kafka brokers. Could be
    /// PLAINTEXT, SSL, `SASL_PLAINTEXT` or `SASL_SSL`.
    #[serde(rename = "properties.security.protocol")]
    #[with_option(allow_alter_on_fly)]
    security_protocol: Option<String>,

    #[serde(rename = "properties.ssl.endpoint.identification.algorithm")]
    #[with_option(allow_alter_on_fly)]
    ssl_endpoint_identification_algorithm: Option<String>,

    // For the properties below, please refer to [librdkafka](https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md) for more information.
    /// Path to CA certificate file for verifying the broker's key.
    #[serde(rename = "properties.ssl.ca.location")]
    ssl_ca_location: Option<String>,

    /// CA certificate string (PEM format) for verifying the broker's key.
    #[serde(rename = "properties.ssl.ca.pem")]
    ssl_ca_pem: Option<String>,

    /// Path to client's certificate file (PEM).
    #[serde(rename = "properties.ssl.certificate.location")]
    ssl_certificate_location: Option<String>,

    /// Client's public key string (PEM format) used for authentication.
    #[serde(rename = "properties.ssl.certificate.pem")]
    ssl_certificate_pem: Option<String>,

    /// Path to client's private key file (PEM).
    #[serde(rename = "properties.ssl.key.location")]
    ssl_key_location: Option<String>,

    /// Client's private key string (PEM format) used for authentication.
    #[serde(rename = "properties.ssl.key.pem")]
    ssl_key_pem: Option<String>,

    /// Passphrase of client's private key.
    #[serde(rename = "properties.ssl.key.password")]
    ssl_key_password: Option<String>,

    /// SASL mechanism if SASL is enabled. Currently support PLAIN, SCRAM, GSSAPI, and `AWS_MSK_IAM`.
    #[serde(rename = "properties.sasl.mechanism")]
    #[with_option(allow_alter_on_fly)]
    sasl_mechanism: Option<String>,

    /// SASL username for SASL/PLAIN and SASL/SCRAM.
    #[serde(rename = "properties.sasl.username")]
    #[with_option(allow_alter_on_fly)]
    sasl_username: Option<String>,

    /// SASL password for SASL/PLAIN and SASL/SCRAM.
    #[serde(rename = "properties.sasl.password")]
    #[with_option(allow_alter_on_fly)]
    sasl_password: Option<String>,

    /// Kafka server's Kerberos principal name under SASL/GSSAPI, not including /hostname@REALM.
    #[serde(rename = "properties.sasl.kerberos.service.name")]
    sasl_kerberos_service_name: Option<String>,

    /// Path to client's Kerberos keytab file under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.keytab")]
    sasl_kerberos_keytab: Option<String>,

    /// Client's Kerberos principal name under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.principal")]
    sasl_kerberos_principal: Option<String>,

    /// Shell command to refresh or acquire the client's Kerberos ticket under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.kinit.cmd")]
    sasl_kerberos_kinit_cmd: Option<String>,

    /// Minimum time in milliseconds between key refresh attempts under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.min.time.before.relogin")]
    sasl_kerberos_min_time_before_relogin: Option<String>,

    /// Configurations for SASL/OAUTHBEARER.
    // NOTE(review): the field name spells "oathbearer"; the option key itself is spelled
    // correctly, so only the Rust identifier is affected.
    #[serde(rename = "properties.sasl.oauthbearer.config")]
    sasl_oathbearer_config: Option<String>,
}
279
280impl EnforceSecret for KafkaConnectionProps {
281    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
282        "properties.ssl.key.pem",
283        "properties.ssl.key.password",
284        "properties.sasl.password",
285    };
286}
287
/// Kafka options that are not connection related: the topic and the timeout applied to
/// synchronous client calls.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct KafkaCommon {
    // connection related props are moved to `KafkaConnection`
    #[serde(rename = "topic", alias = "kafka.topic")]
    pub topic: String,

    // Parsed from a duration string; falls back to `default_kafka_sync_call_timeout()`
    // when the option is absent.
    #[serde(
        rename = "properties.sync.call.timeout",
        deserialize_with = "deserialize_duration_from_string",
        default = "default_kafka_sync_call_timeout"
    )]
    #[with_option(allow_alter_on_fly)]
    pub sync_call_timeout: Duration,
}
303
/// PrivateLink-related Kafka options: an optional map used to rewrite broker endpoints.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
pub struct KafkaPrivateLinkCommon {
    /// This is generated from `private_link_targets` and `private_link_endpoint` in frontend, instead of given by users.
    // The option value is a JSON string that is decoded into the map (`JsonString`).
    #[serde(rename = "broker.rewrite.endpoints")]
    #[serde_as(as = "Option<JsonString>")]
    pub broker_rewrite_map: Option<BTreeMap<String, String>>,
}
312
/// Default for `properties.sync.call.timeout`: five seconds.
const fn default_kafka_sync_call_timeout() -> Duration {
    const DEFAULT_TIMEOUT_SECS: u64 = 5;
    Duration::from_secs(DEFAULT_TIMEOUT_SECS)
}

/// Default for `properties.socket.keepalive.enable`: keepalive is on unless overridden.
const fn default_socket_keepalive_enable() -> bool {
    true
}
320
/// librdkafka client properties shared by Kafka sources and sinks.
///
/// All values arrive as strings and are parsed with `DisplayFromStr`. Optional ones are only
/// forwarded to the client when set; `socket_keepalive_enable` always has a value
/// (defaulting to `default_socket_keepalive_enable()`).
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct RdKafkaPropertiesCommon {
    /// Maximum Kafka protocol request message size. Due to differing framing overhead between
    /// protocol versions the producer is unable to reliably enforce a strict max message limit at
    /// produce time and may exceed the maximum size by one message in protocol `ProduceRequests`,
    /// the broker will enforce the topic's max.message.bytes limit
    #[serde(rename = "properties.message.max.bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub message_max_bytes: Option<usize>,

    /// Maximum Kafka protocol response message size. This serves as a safety precaution to avoid
    /// memory exhaustion in case of protocol hickups. This value must be at least fetch.max.bytes
    /// + 512 to allow for protocol overhead; the value is adjusted automatically unless the
    /// configuration property is explicitly set.
    #[serde(rename = "properties.receive.message.max.bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub receive_message_max_bytes: Option<usize>,

    // Forwarded to librdkafka as `statistics.interval.ms`.
    #[serde(rename = "properties.statistics.interval.ms")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub statistics_interval_ms: Option<usize>,

    /// Client identifier
    #[serde(rename = "properties.client.id")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub client_id: Option<String>,

    // Forwarded to librdkafka as `enable.ssl.certificate.verification`.
    #[serde(rename = "properties.enable.ssl.certificate.verification")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    #[with_option(allow_alter_on_fly)]
    pub enable_ssl_certificate_verification: Option<bool>,

    // Forwarded to librdkafka as `socket.keepalive.enable`; always set, `true` by default.
    #[serde(
        rename = "properties.socket.keepalive.enable",
        default = "default_socket_keepalive_enable"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub socket_keepalive_enable: bool,
}
365
366impl RdKafkaPropertiesCommon {
367    pub(crate) fn set_client(&self, c: &mut rdkafka::ClientConfig) {
368        if let Some(v) = self.statistics_interval_ms {
369            c.set("statistics.interval.ms", v.to_string());
370        }
371        if let Some(v) = self.message_max_bytes {
372            c.set("message.max.bytes", v.to_string());
373        }
374        if let Some(v) = self.receive_message_max_bytes {
375            c.set("receive.message.max.bytes", v.to_string());
376        }
377        if let Some(v) = self.client_id.as_ref() {
378            c.set("client.id", v);
379        }
380        if let Some(v) = self.enable_ssl_certificate_verification {
381            c.set("enable.ssl.certificate.verification", v.to_string());
382        }
383        c.set(
384            "socket.keepalive.enable",
385            self.socket_keepalive_enable.to_string(),
386        );
387    }
388}
389
390impl KafkaConnectionProps {
391    #[cfg(test)]
392    pub fn test_default() -> Self {
393        Self {
394            brokers: "localhost:9092".to_owned(),
395            security_protocol: None,
396            ssl_ca_location: None,
397            ssl_certificate_location: None,
398            ssl_key_location: None,
399            ssl_ca_pem: None,
400            ssl_certificate_pem: None,
401            ssl_key_pem: None,
402            ssl_key_password: None,
403            ssl_endpoint_identification_algorithm: None,
404            sasl_mechanism: None,
405            sasl_username: None,
406            sasl_password: None,
407            sasl_kerberos_service_name: None,
408            sasl_kerberos_keytab: None,
409            sasl_kerberos_principal: None,
410            sasl_kerberos_kinit_cmd: None,
411            sasl_kerberos_min_time_before_relogin: None,
412            sasl_oathbearer_config: None,
413        }
414    }
415
416    pub(crate) fn set_security_properties(&self, config: &mut ClientConfig) {
417        // AWS_MSK_IAM
418        if self.is_aws_msk_iam() {
419            config.set("security.protocol", "SASL_SSL");
420            config.set("sasl.mechanism", "OAUTHBEARER");
421            return;
422        }
423
424        // Security protocol
425        if let Some(security_protocol) = self.security_protocol.as_ref() {
426            config.set("security.protocol", security_protocol);
427        }
428
429        // SSL
430        if let Some(ssl_ca_location) = self.ssl_ca_location.as_ref() {
431            config.set("ssl.ca.location", ssl_ca_location);
432        }
433        if let Some(ssl_ca_pem) = self.ssl_ca_pem.as_ref() {
434            config.set("ssl.ca.pem", ssl_ca_pem);
435        }
436        if let Some(ssl_certificate_location) = self.ssl_certificate_location.as_ref() {
437            config.set("ssl.certificate.location", ssl_certificate_location);
438        }
439        if let Some(ssl_certificate_pem) = self.ssl_certificate_pem.as_ref() {
440            config.set("ssl.certificate.pem", ssl_certificate_pem);
441        }
442        if let Some(ssl_key_location) = self.ssl_key_location.as_ref() {
443            config.set("ssl.key.location", ssl_key_location);
444        }
445        if let Some(ssl_key_pem) = self.ssl_key_pem.as_ref() {
446            config.set("ssl.key.pem", ssl_key_pem);
447        }
448        if let Some(ssl_key_password) = self.ssl_key_password.as_ref() {
449            config.set("ssl.key.password", ssl_key_password);
450        }
451        if let Some(ssl_endpoint_identification_algorithm) =
452            self.ssl_endpoint_identification_algorithm.as_ref()
453        {
454            // accept only `none` and `http` here, let the sdk do the check
455            config.set(
456                "ssl.endpoint.identification.algorithm",
457                ssl_endpoint_identification_algorithm,
458            );
459        }
460
461        // SASL mechanism
462        if let Some(sasl_mechanism) = self.sasl_mechanism.as_ref() {
463            config.set("sasl.mechanism", sasl_mechanism);
464        }
465
466        // SASL/PLAIN & SASL/SCRAM
467        if let Some(sasl_username) = self.sasl_username.as_ref() {
468            config.set("sasl.username", sasl_username);
469        }
470        if let Some(sasl_password) = self.sasl_password.as_ref() {
471            config.set("sasl.password", sasl_password);
472        }
473
474        // SASL/GSSAPI
475        if let Some(sasl_kerberos_service_name) = self.sasl_kerberos_service_name.as_ref() {
476            config.set("sasl.kerberos.service.name", sasl_kerberos_service_name);
477        }
478        if let Some(sasl_kerberos_keytab) = self.sasl_kerberos_keytab.as_ref() {
479            config.set("sasl.kerberos.keytab", sasl_kerberos_keytab);
480        }
481        if let Some(sasl_kerberos_principal) = self.sasl_kerberos_principal.as_ref() {
482            config.set("sasl.kerberos.principal", sasl_kerberos_principal);
483        }
484        if let Some(sasl_kerberos_kinit_cmd) = self.sasl_kerberos_kinit_cmd.as_ref() {
485            config.set("sasl.kerberos.kinit.cmd", sasl_kerberos_kinit_cmd);
486        }
487        if let Some(sasl_kerberos_min_time_before_relogin) =
488            self.sasl_kerberos_min_time_before_relogin.as_ref()
489        {
490            config.set(
491                "sasl.kerberos.min.time.before.relogin",
492                sasl_kerberos_min_time_before_relogin,
493            );
494        }
495
496        // SASL/OAUTHBEARER
497        if let Some(sasl_oathbearer_config) = self.sasl_oathbearer_config.as_ref() {
498            config.set("sasl.oauthbearer.config", sasl_oathbearer_config);
499        }
500        // Currently, we only support unsecured OAUTH.
501        config.set("enable.sasl.oauthbearer.unsecure.jwt", "true");
502    }
503
504    pub(crate) fn is_aws_msk_iam(&self) -> bool {
505        if let Some(sasl_mechanism) = self.sasl_mechanism.as_ref()
506            && sasl_mechanism == AWS_MSK_IAM_AUTH
507        {
508            true
509        } else {
510            false
511        }
512    }
513}
514
/// Common Pulsar options: the topic, the service URL and an optional authentication token.
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct PulsarCommon {
    #[serde(rename = "topic", alias = "pulsar.topic")]
    pub topic: String,

    #[serde(rename = "service.url", alias = "pulsar.service.url")]
    pub service_url: String,

    // Sent as the `token` authentication data by `build_client` when no OAuth config is given.
    #[serde(rename = "auth.token")]
    pub auth_token: Option<String>,
}
526
527impl EnforceSecret for PulsarCommon {
528    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
529        "pulsar.auth.token",
530    };
531}
532
/// OAuth2 options for Pulsar.
///
/// `credentials_url` may be an `s3://` or `file://` URL, or an absolute file path; it is
/// resolved by `PulsarCommon::resolve_pulsar_credentials_url`.
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct PulsarOauthCommon {
    #[serde(rename = "oauth.issuer.url")]
    pub issuer_url: String,

    #[serde(rename = "oauth.credentials.url")]
    pub credentials_url: String,

    #[serde(rename = "oauth.audience")]
    pub audience: String,

    #[serde(rename = "oauth.scope")]
    pub scope: Option<String>,
}
547
548fn create_credential_temp_file(credentials: &[u8]) -> std::io::Result<NamedTempFile> {
549    let mut f = NamedTempFile::new()?;
550    f.write_all(credentials)?;
551    f.as_file().sync_all()?;
552    Ok(f)
553}
554
555impl PulsarCommon {
556    pub(crate) async fn build_client(
557        &self,
558        oauth: &Option<PulsarOauthCommon>,
559        aws_auth_props: &AwsAuthProps,
560    ) -> ConnectorResult<Pulsar<TokioExecutor>> {
561        let mut pulsar_builder = Pulsar::builder(&self.service_url, TokioExecutor);
562        let mut _temp_file = None; // Keep temp file alive
563
564        if let Some(oauth) = oauth.as_ref() {
565            let (credentials_url, temp_file) = self
566                .resolve_pulsar_credentials_url(oauth, aws_auth_props)
567                .await?;
568            _temp_file = temp_file;
569
570            let auth_params = OAuth2Params {
571                issuer_url: oauth.issuer_url.clone(),
572                credentials_url,
573                audience: Some(oauth.audience.clone()),
574                scope: oauth.scope.clone(),
575            };
576
577            pulsar_builder = pulsar_builder
578                .with_auth_provider(OAuth2Authentication::client_credentials(auth_params));
579        } else if let Some(auth_token) = &self.auth_token {
580            pulsar_builder = pulsar_builder.with_auth(Authentication {
581                name: "token".to_owned(),
582                data: Vec::from(auth_token.as_str()),
583            });
584        }
585
586        let res = pulsar_builder.build().await.map_err(|e| anyhow!(e))?;
587        drop(_temp_file); // Explicitly drop temp file after client is built
588        Ok(res)
589    }
590
591    pub(crate) async fn resolve_pulsar_credentials_url(
592        &self,
593        oauth: &PulsarOauthCommon,
594        aws_auth_props: &AwsAuthProps,
595    ) -> ConnectorResult<(String, Option<NamedTempFile>)> {
596        // Try parsing as URL first
597        if let Ok(url) = Url::parse(&oauth.credentials_url) {
598            return self
599                .handle_pulsar_credentials_url(&url, aws_auth_props)
600                .await;
601        }
602
603        // If not a valid URL, check if it's an absolute file path
604        let path = Path::new(&oauth.credentials_url);
605        if !path.is_absolute() {
606            bail!("credentials_url must be a valid URL (s3://, file://) or an absolute file path");
607        }
608
609        // Verify the file exists
610        if !tokio::fs::try_exists(&oauth.credentials_url)
611            .await
612            .unwrap_or(false)
613        {
614            bail!("credentials file does not exist: {}", oauth.credentials_url);
615        }
616
617        // Return absolute path with file:// prefix
618        Ok((format!("file://{}", oauth.credentials_url), None))
619    }
620
621    pub(crate) async fn handle_pulsar_credentials_url(
622        &self,
623        url: &Url,
624        aws_auth_props: &AwsAuthProps,
625    ) -> ConnectorResult<(String, Option<NamedTempFile>)> {
626        match url.scheme() {
627            "s3" => {
628                let credentials = load_file_descriptor_from_s3(url, aws_auth_props).await?;
629                let temp_file = create_credential_temp_file(&credentials)
630                    .context("failed to create temp file for pulsar credentials")?;
631
632                let temp_path = temp_file
633                    .path()
634                    .to_str()
635                    .context("temp file path is not valid UTF-8")?;
636
637                Ok((format!("file://{}", temp_path), Some(temp_file)))
638            }
639            "file" => Ok((url.to_string(), None)),
640            _ => bail!(
641                "invalid credentials_url scheme '{}', only file://, s3://, and absolute file paths are supported",
642                url.scheme()
643            ),
644        }
645    }
646}
647
/// Common Kinesis options: the stream and its region, an optional endpoint, AWS credentials
/// and assume-role settings, plus AWS SDK timeout and retry tuning.
///
/// The `kinesis.sdk.*` values are parsed from strings (`DisplayFromStr`) and default to the
/// corresponding `kinesis_default_*` functions.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct KinesisCommon {
    #[serde(rename = "stream", alias = "kinesis.stream.name")]
    pub stream_name: String,
    #[serde(rename = "aws.region", alias = "kinesis.stream.region")]
    pub stream_region: String,
    #[serde(rename = "endpoint", alias = "kinesis.endpoint")]
    pub endpoint: Option<String>,
    #[serde(
        rename = "aws.credentials.access_key_id",
        alias = "kinesis.credentials.access"
    )]
    pub credentials_access_key: Option<String>,
    #[serde(
        rename = "aws.credentials.secret_access_key",
        alias = "kinesis.credentials.secret"
    )]
    pub credentials_secret_access_key: Option<String>,
    #[serde(
        rename = "aws.credentials.session_token",
        alias = "kinesis.credentials.session_token"
    )]
    pub session_token: Option<String>,
    #[serde(rename = "aws.credentials.role.arn", alias = "kinesis.assumerole.arn")]
    pub assume_role_arn: Option<String>,
    #[serde(
        rename = "aws.credentials.role.external_id",
        alias = "kinesis.assumerole.external_id"
    )]
    pub assume_role_external_id: Option<String>,

    // sdk options
    #[serde(
        rename = "kinesis.sdk.connect_timeout_ms",
        default = "kinesis_default_connect_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_connect_timeout_ms: u64,

    #[serde(
        rename = "kinesis.sdk.read_timeout_ms",
        default = "kinesis_default_read_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_read_timeout_ms: u64,

    #[serde(
        rename = "kinesis.sdk.operation_timeout_ms",
        default = "kinesis_default_operation_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_operation_timeout_ms: u64,

    #[serde(
        rename = "kinesis.sdk.operation_attempt_timeout_ms",
        default = "kinesis_default_operation_attempt_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_operation_attempt_timeout_ms: u64,

    // Passed to the SDK retry config as its maximum number of attempts (see `build_client`).
    #[serde(
        rename = "kinesis.sdk.max_retry_limit",
        default = "kinesis_default_max_retry_limit"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_max_retry_limit: u32,

    #[serde(
        rename = "kinesis.sdk.init_backoff_ms",
        default = "kinesis_default_init_backoff_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_init_backoff_ms: u64,

    #[serde(
        rename = "kinesis.sdk.max_backoff_ms",
        default = "kinesis_default_max_backoff_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub sdk_max_backoff_ms: u64,
}
730
/// `AsyncSleep` implementation for the Kinesis SDK client, backed by `tokio::time::sleep`.
#[derive(Debug)]
pub struct KinesisAsyncSleepImpl;

impl AsyncSleep for KinesisAsyncSleepImpl {
    /// Returns a `Sleep` that completes after `duration`.
    fn sleep(&self, duration: Duration) -> Sleep {
        // The tokio sleep is created inside the async block, i.e. only once the returned
        // future is first polled, not when this method is called.
        Sleep::new(async move { tokio::time::sleep(duration).await })
    }
}
739
/// Default for `kinesis.sdk.connect_timeout_ms` (10 seconds).
const fn kinesis_default_connect_timeout_ms() -> u64 {
    10_000
}

/// Default for `kinesis.sdk.read_timeout_ms` (10 seconds).
const fn kinesis_default_read_timeout_ms() -> u64 {
    10_000
}

/// Default for `kinesis.sdk.operation_timeout_ms` (10 seconds).
const fn kinesis_default_operation_timeout_ms() -> u64 {
    10_000
}

/// Default for `kinesis.sdk.operation_attempt_timeout_ms` (10 seconds).
const fn kinesis_default_operation_attempt_timeout_ms() -> u64 {
    10_000
}

/// Default for `kinesis.sdk.init_backoff_ms` (1 second).
const fn kinesis_default_init_backoff_ms() -> u64 {
    1_000
}

/// Default for `kinesis.sdk.max_backoff_ms` (20 seconds).
const fn kinesis_default_max_backoff_ms() -> u64 {
    20_000
}

/// Default for `kinesis.sdk.max_retry_limit`.
const fn kinesis_default_max_retry_limit() -> u32 {
    3
}
767
768impl EnforceSecret for KinesisCommon {
769    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
770        "kinesis.credentials.access",
771        "kinesis.credentials.secret",
772        "kinesis.credentials.session_token",
773    };
774}
775
776impl KinesisCommon {
777    pub(crate) async fn build_client(&self) -> ConnectorResult<KinesisClient> {
778        let config = AwsAuthProps {
779            region: Some(self.stream_region.clone()),
780            endpoint: self.endpoint.clone(),
781            access_key: self.credentials_access_key.clone(),
782            secret_key: self.credentials_secret_access_key.clone(),
783            session_token: self.session_token.clone(),
784            arn: self.assume_role_arn.clone(),
785            external_id: self.assume_role_external_id.clone(),
786            profile: Default::default(),
787            msk_signer_timeout_sec: Default::default(),
788        };
789        let aws_config = config.build_config().await?;
790        let mut builder = aws_sdk_kinesis::config::Builder::from(&aws_config);
791        {
792            // for timeout and retry config
793            let sleep_impl = SharedAsyncSleep::new(KinesisAsyncSleepImpl);
794            builder.set_sleep_impl(Some(sleep_impl));
795            let timeout_config = aws_smithy_types::timeout::TimeoutConfig::builder()
796                .connect_timeout(Duration::from_millis(self.sdk_connect_timeout_ms))
797                .read_timeout(Duration::from_millis(self.sdk_read_timeout_ms))
798                .operation_timeout(Duration::from_millis(self.sdk_operation_timeout_ms))
799                .operation_attempt_timeout(Duration::from_millis(
800                    self.sdk_operation_attempt_timeout_ms,
801                ))
802                .build();
803            builder.set_timeout_config(Some(timeout_config));
804
805            let retry_config = aws_smithy_types::retry::RetryConfig::standard()
806                .with_initial_backoff(Duration::from_millis(self.sdk_init_backoff_ms))
807                .with_max_backoff(Duration::from_millis(self.sdk_max_backoff_ms))
808                .with_max_attempts(self.sdk_max_retry_limit);
809            builder.set_retry_config(Some(retry_config));
810        }
811        if let Some(endpoint) = &config.endpoint {
812            builder = builder.endpoint_url(endpoint);
813        }
814        Ok(KinesisClient::from_conf(builder.build()))
815    }
816}
817
/// Connection properties for NATS, used as a cache key for shared clients.
/// This includes all properties that affect the connection itself (not stream/subject specific).
///
/// The fields mirror the connection-related options of `NatsCommon`; two values compare
/// (and hash) equal only when every field matches.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NatsConnectionProps {
    /// NATS server URL.
    pub server_url: String,
    /// Connection mode selected by the user.
    // NOTE(review): the accepted values are not visible in this section — see the connect logic.
    pub connect_mode: String,
    /// User name, if any.
    pub user: Option<String>,
    /// Password, if any.
    pub password: Option<String>,
    /// JWT, if any.
    pub jwt: Option<String>,
    /// NKey, if any.
    pub nkey: Option<String>,
}
829
/// Shared NATS client cache.
/// Client connections are cached as `Weak` pointers in the cache.
/// NATS Connector can access this cache to reuse existing client connections,
/// and avoid exhausting host machine ports.
/// When reading from the cache, the connector should `upgrade` the weak pointer to an `Arc` reference.
/// After all strong (Arc) references are dropped, the client connection will be cleaned up.
///
/// NOTE(review): the cache is created with the builder's default settings (no capacity limit or
/// expiry is configured here). A dangling weak pointer is overwritten the next time
/// `NatsCommon::build_client` runs with the same connection properties; confirm whether entries
/// whose properties are never requested again need explicit eviction.
pub static SHARED_NATS_CLIENT: LazyLock<MokaCache<NatsConnectionProps, Weak<async_nats::Client>>> =
    LazyLock::new(|| MokaCache::builder().build());
839
/// Options shared by the NATS connectors: how to connect, which subject(s) to use, and the
/// limits applied when a `JetStream` stream has to be created.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct NatsCommon {
    /// Server address(es). Several addresses may be given as a comma-separated list; each
    /// entry is parsed as a separate server address when connecting.
    #[serde(rename = "server_url")]
    pub server_url: String,
    /// Subject(s), comma-separated. Each entry is used as a stream subject when a stream is
    /// created and as a filter subject when a consumer is created.
    #[serde(rename = "subject")]
    pub subject: String,
    /// Authentication mode. Accepted values are `user_and_password`, `credential` and `plain`.
    #[serde(rename = "connect_mode")]
    pub connect_mode: String,
    /// User name (option key `username`); required when `connect_mode` is `user_and_password`.
    #[serde(rename = "username")]
    pub user: Option<String>,
    /// Password; required when `connect_mode` is `user_and_password`.
    #[serde(rename = "password")]
    pub password: Option<String>,
    /// User JWT; required when `connect_mode` is `credential`.
    #[serde(rename = "jwt")]
    pub jwt: Option<String>,
    /// NKey seed; required when `connect_mode` is `credential`.
    #[serde(rename = "nkey")]
    pub nkey: Option<String>,
    /// `max_bytes` of a newly created stream; `1000000` is used when unset.
    #[serde(rename = "max_bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_bytes: Option<i64>,
    /// `max_messages` of a newly created stream; the library default is used when unset.
    #[serde(rename = "max_messages")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_messages: Option<i64>,
    /// `max_messages_per_subject` of a newly created stream; the library default is used when unset.
    #[serde(rename = "max_messages_per_subject")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_messages_per_subject: Option<i64>,
    /// `max_consumers` of a newly created stream; the library default is used when unset.
    #[serde(rename = "max_consumers")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_consumers: Option<i32>,
    /// `max_message_size` of a newly created stream; the library default is used when unset.
    #[serde(rename = "max_message_size")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_message_size: Option<i32>,
    /// Whether a missing stream may be created. Parsed from its string form; falls back to
    /// `bool::default()` (`false`) when the option is omitted.
    #[serde(rename = "allow_create_stream", default)]
    #[serde_as(as = "DisplayFromStr")]
    pub allow_create_stream: bool,
}
876
impl EnforceSecret for NatsCommon {
    /// Option keys of `NatsCommon` that carry credentials and are registered with
    /// `EnforceSecret` (see that trait for how the set is checked).
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "password",
        "jwt",
        "nkey",
    };
}
884
885impl NatsCommon {
886    /// Extract connection properties that can be used as a cache key.
887    pub fn connection_props(&self) -> NatsConnectionProps {
888        NatsConnectionProps {
889            server_url: self.server_url.clone(),
890            connect_mode: self.connect_mode.clone(),
891            user: self.user.clone(),
892            password: self.password.clone(),
893            jwt: self.jwt.clone(),
894            nkey: self.nkey.clone(),
895        }
896    }
897
898    /// Build a new NATS client without caching.
899    async fn build_client_inner(&self) -> ConnectorResult<async_nats::Client> {
900        let mut connect_options = async_nats::ConnectOptions::new();
901        match self.connect_mode.as_str() {
902            "user_and_password" => {
903                if let (Some(v_user), Some(v_password)) =
904                    (self.user.as_ref(), self.password.as_ref())
905                {
906                    connect_options =
907                        connect_options.user_and_password(v_user.into(), v_password.into())
908                } else {
909                    bail!("nats connect mode is user_and_password, but user or password is empty");
910                }
911            }
912
913            "credential" => {
914                if let (Some(v_nkey), Some(v_jwt)) = (self.nkey.as_ref(), self.jwt.as_ref()) {
915                    connect_options = connect_options
916                        .credentials(&self.create_credential(v_nkey, v_jwt)?)
917                        .expect("failed to parse static creds")
918                } else {
919                    bail!("nats connect mode is credential, but nkey or jwt is empty");
920                }
921            }
922            "plain" => {}
923            _ => {
924                bail!("nats connect mode only accepts user_and_password/credential/plain");
925            }
926        };
927
928        let servers = self.server_url.split(',').collect::<Vec<&str>>();
929        let client = connect_options
930            .connect(
931                servers
932                    .iter()
933                    .map(|url| url.parse())
934                    .collect::<Result<Vec<async_nats::ServerAddr>, _>>()?,
935            )
936            .await
937            .context("build nats client error")
938            .map_err(SinkError::Nats)?;
939        Ok(client)
940    }
941
942    /// Build a NATS client, attempting to reuse an existing cached client if available.
943    /// See `SHARED_NATS_CLIENT` for more details.
944    pub(crate) async fn build_client(&self) -> ConnectorResult<Arc<async_nats::Client>> {
945        let connection_props = self.connection_props();
946        let mut client: Option<Arc<async_nats::Client>> = None;
947
948        SHARED_NATS_CLIENT
949            .entry_by_ref(&connection_props)
950            .and_try_compute_with::<_, _, crate::error::ConnectorError>(|maybe_entry| async {
951                if let Some(entry) = maybe_entry
952                    && let entry_value = entry.into_value()
953                    && let Some(existing_client) = entry_value.upgrade()
954                {
955                    match existing_client.connection_state() {
956                        async_nats::connection::State::Connected => {
957                            tracing::info!("reuse existing nats client for {}", self.server_url);
958                            client = Some(existing_client);
959                            return Ok(Op::Nop);
960                        }
961                        _ => {
962                            tracing::warn!(
963                                server_url = self.server_url,
964                                "existing nats client is not connected",
965                            );
966                        }
967                    }
968                }
969                tracing::info!(
970                    server_url = self.server_url,
971                    "no cached client, or client disconnected, building new nats client"
972                );
973                let new_client = Arc::new(self.build_client_inner().await?);
974                client = Some(new_client.clone());
975                Ok(Op::Put(Arc::downgrade(&new_client)))
976            })
977            .await?;
978
979        Ok(client.expect("client should be set"))
980    }
981
982    pub(crate) async fn build_context(&self) -> ConnectorResult<jetstream::Context> {
983        let client = self.build_client().await?;
984        let jetstream = async_nats::jetstream::new((*client).clone());
985        Ok(jetstream)
986    }
987
988    /// Build a `JetStream` context using a pre-existing client.
989    pub(crate) fn build_context_from_client(
990        client: &Arc<async_nats::Client>,
991    ) -> jetstream::Context {
992        async_nats::jetstream::new((**client).clone())
993    }
994
995    /// Build a NATS `JetStream` consumer.
996    ///
997    /// If `existing_client` is provided, it will be used instead of creating/fetching a new one.
998    /// This allows callers to reuse a client they already hold.
999    pub(crate) async fn build_consumer(
1000        &self,
1001        stream: String,
1002        durable_consumer_name: String,
1003        split_id: String,
1004        start_sequence: NatsOffset,
1005        mut config: jetstream::consumer::pull::Config,
1006        existing_client: Option<Arc<async_nats::Client>>,
1007    ) -> ConnectorResult<(
1008        async_nats::jetstream::consumer::Consumer<async_nats::jetstream::consumer::pull::Config>,
1009        Arc<async_nats::Client>,
1010    )> {
1011        let client = match existing_client {
1012            Some(c) => c,
1013            None => self.build_client().await?,
1014        };
1015        let context = Self::build_context_from_client(&client);
1016        let stream = self.build_or_get_stream(context.clone(), stream).await?;
1017        let subject_name = self
1018            .subject
1019            .replace(',', "-")
1020            .replace(['.', '>', '*', ' ', '\t'], "_");
1021        let name = format!("risingwave-consumer-{}-{}", subject_name, split_id);
1022
1023        let deliver_policy = match start_sequence {
1024            NatsOffset::Earliest => DeliverPolicy::All,
1025            NatsOffset::Latest => DeliverPolicy::New,
1026            NatsOffset::SequenceNumber(v) => {
1027                // for compatibility, we do not write to any state table now
1028                let parsed = v
1029                    .parse::<u64>()
1030                    .context("failed to parse nats offset as sequence number")?;
1031                DeliverPolicy::ByStartSequence {
1032                    start_sequence: 1 + parsed,
1033                }
1034            }
1035            NatsOffset::Timestamp(v) => DeliverPolicy::ByStartTime {
1036                start_time: OffsetDateTime::from_unix_timestamp_nanos(v as i128 * 1_000_000)
1037                    .context("invalid timestamp for nats offset")?,
1038            },
1039            NatsOffset::None => DeliverPolicy::All,
1040        };
1041
1042        let consumer = match stream.get_consumer(&name).await {
1043            Ok(consumer) => consumer,
1044            _ => {
1045                stream
1046                    .get_or_create_consumer(&name, {
1047                        config.deliver_policy = deliver_policy;
1048                        config.durable_name = Some(durable_consumer_name);
1049                        config.filter_subjects =
1050                            self.subject.split(',').map(|s| s.to_owned()).collect();
1051                        config
1052                    })
1053                    .await?
1054            }
1055        };
1056        Ok((consumer, client))
1057    }
1058
1059    pub(crate) async fn build_or_get_stream(
1060        &self,
1061        jetstream: jetstream::Context,
1062        stream_str: String,
1063    ) -> ConnectorResult<jetstream::stream::Stream> {
1064        let subjects: Vec<String> = self.subject.split(',').map(|s| s.to_owned()).collect();
1065
1066        // In `SourceEnumerator`, we may create a stream
1067        // In `SourceReader`, the desired stream MUST exist
1068        if let Ok(mut stream_instance) = jetstream.get_stream(&stream_str).await {
1069            tracing::info!(
1070                "load existing nats stream ({:?}) with config {:?}",
1071                stream_str,
1072                stream_instance.info().await?
1073            );
1074            return Ok(stream_instance);
1075        }
1076
1077        if !self.allow_create_stream {
1078            return Err(anyhow!(
1079                "stream {} not found, set `allow_create_stream` to true to create a stream",
1080                stream_str
1081            )
1082            .into());
1083        }
1084
1085        let mut config = jetstream::stream::Config {
1086            name: stream_str.clone(),
1087            max_bytes: 1000000,
1088            subjects,
1089            ..Default::default()
1090        };
1091        if let Some(v) = self.max_bytes {
1092            config.max_bytes = v;
1093        }
1094        if let Some(v) = self.max_messages {
1095            config.max_messages = v;
1096        }
1097        if let Some(v) = self.max_messages_per_subject {
1098            config.max_messages_per_subject = v;
1099        }
1100        if let Some(v) = self.max_consumers {
1101            config.max_consumers = v;
1102        }
1103        if let Some(v) = self.max_message_size {
1104            config.max_message_size = v;
1105        }
1106        tracing::info!(
1107            "create nats stream ({:?}) with config {:?}",
1108            &stream_str,
1109            config
1110        );
1111        let stream = jetstream.get_or_create_stream(config).await?;
1112        Ok(stream)
1113    }
1114
1115    pub(crate) fn create_credential(&self, seed: &str, jwt: &str) -> ConnectorResult<String> {
1116        let creds = format!(
1117            "-----BEGIN NATS USER JWT-----\n{}\n------END NATS USER JWT------\n\n\
1118                         ************************* IMPORTANT *************************\n\
1119                         NKEY Seed printed below can be used to sign and prove identity.\n\
1120                         NKEYs are sensitive and should be treated as secrets.\n\n\
1121                         -----BEGIN USER NKEY SEED-----\n{}\n------END USER NKEY SEED------\n\n\
1122                         *************************************************************",
1123            jwt, seed
1124        );
1125        Ok(creds)
1126    }
1127}
1128
1129pub(crate) fn load_certs(
1130    certificates: &str,
1131) -> ConnectorResult<Vec<rustls_pki_types::CertificateDer<'static>>> {
1132    let cert_bytes = if let Some(path) = certificates.strip_prefix("fs://") {
1133        std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())?
1134    } else {
1135        certificates.as_bytes().to_owned()
1136    };
1137
1138    rustls_pemfile::certs(&mut cert_bytes.as_slice())
1139        .map(|cert| Ok(cert?))
1140        .collect()
1141}
1142
1143pub(crate) fn load_private_key(
1144    certificate: &str,
1145) -> ConnectorResult<rustls_pki_types::PrivateKeyDer<'static>> {
1146    let cert_bytes = if let Some(path) = certificate.strip_prefix("fs://") {
1147        std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())?
1148    } else {
1149        certificate.as_bytes().to_owned()
1150    };
1151
1152    let cert = rustls_pemfile::pkcs8_private_keys(&mut cert_bytes.as_slice())
1153        .next()
1154        .ok_or_else(|| anyhow!("No private key found"))?;
1155    Ok(cert?.into())
1156}
1157
/// Options shared by the `MongoDB` connectors.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct MongodbCommon {
    /// The connection URI of `MongoDB` (option key `mongodb.url`). It is passed unchanged to
    /// the driver when the client is built.
    #[serde(rename = "mongodb.url")]
    pub connect_uri: String,
    /// The collection name where data should be written to or read from. For sinks, the format is
    /// `db_name.collection_name`. Data can also be written to dynamic collections, see `collection.name.field`
    /// for more information.
    #[serde(rename = "collection.name")]
    pub collection_name: String,
}
1170
impl EnforceSecret for MongodbCommon {
    /// Option keys of `MongodbCommon` registered with `EnforceSecret` (see that trait for how
    /// the set is checked). Only the connection URI is listed.
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "mongodb.url"
    };
}
1176
1177impl MongodbCommon {
1178    pub(crate) async fn build_client(&self) -> ConnectorResult<mongodb::Client> {
1179        let client = mongodb::Client::with_uri_str(&self.connect_uri).await?;
1180
1181        Ok(client)
1182    }
1183}