risingwave_connector/connector_common/
common.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::hash::Hash;
17use std::io::Write;
18use std::time::Duration;
19
20use anyhow::{Context, anyhow};
21use async_nats::jetstream::consumer::DeliverPolicy;
22use async_nats::jetstream::{self};
23use aws_sdk_kinesis::Client as KinesisClient;
24use pulsar::authentication::oauth2::{OAuth2Authentication, OAuth2Params};
25use pulsar::{Authentication, Pulsar, TokioExecutor};
26use rdkafka::ClientConfig;
27use risingwave_common::bail;
28use serde_derive::Deserialize;
29use serde_with::json::JsonString;
30use serde_with::{DisplayFromStr, serde_as};
31use tempfile::NamedTempFile;
32use time::OffsetDateTime;
33use url::Url;
34use with_options::WithOptions;
35
36use crate::aws_utils::load_file_descriptor_from_s3;
37use crate::deserialize_duration_from_string;
38use crate::error::ConnectorResult;
39use crate::sink::SinkError;
40use crate::source::nats::source::NatsOffset;
41
/// Option key `broker.rewrite.endpoints`. The same key is the serde name of
/// `KafkaPrivateLinkCommon::broker_rewrite_map`.
pub const PRIVATE_LINK_BROKER_REWRITE_MAP_KEY: &str = "broker.rewrite.endpoints";
/// Option key `privatelink.targets`.
// NOTE(review): presumably the user-facing private link target list from which the frontend
// derives the broker rewrite map -- confirm against the frontend code.
pub const PRIVATE_LINK_TARGETS_KEY: &str = "privatelink.targets";

/// Value of `properties.sasl.mechanism` that `KafkaConnectionProps::is_aws_msk_iam` compares
/// against to detect AWS MSK IAM authentication.
const AWS_MSK_IAM_AUTH: &str = "AWS_MSK_IAM";
46
/// The environment variable to disable using default credential from environment.
/// When it is set to a true value and no static access key pair is configured, building the
/// credential provider fails instead of falling back to the default provider chain.
/// It's recommended to set this variable to `true` in cloud hosting environment.
const DISABLE_DEFAULT_CREDENTIAL: &str = "DISABLE_DEFAULT_CREDENTIAL";
50
/// One AWS PrivateLink entry, deserializable with serde.
// NOTE(review): presumably an element of the `privatelink.targets` option -- confirm against
// the code that parses that option.
#[derive(Debug, Clone, Deserialize)]
pub struct AwsPrivateLinkItem {
    // Optional; presumably the availability-zone id of the target -- TODO confirm.
    pub az_id: Option<String>,
    // Port of the entry.
    pub port: u16,
}
56
57use aws_config::default_provider::region::DefaultRegionChain;
58use aws_config::sts::AssumeRoleProvider;
59use aws_credential_types::provider::SharedCredentialsProvider;
60use aws_types::SdkConfig;
61use aws_types::region::Region;
62use risingwave_common::util::env_var::env_var_is_true;
63
/// A flattened config map for AWS auth.
///
/// Every field is optional. [`AwsAuthProps::build_config`] turns the populated fields into an
/// AWS `SdkConfig`: region, endpoint, static credentials (or the default provider chain) and an
/// optional assumed role. Each field accepts its `aws.*` key plus the listed aliases.
#[derive(Deserialize, Debug, Clone, WithOptions, PartialEq)]
pub struct AwsAuthProps {
    #[serde(rename = "aws.region", alias = "region", alias = "s3.region")]
    pub region: Option<String>,

    #[serde(
        rename = "aws.endpoint_url",
        alias = "endpoint_url",
        alias = "endpoint",
        alias = "s3.endpoint"
    )]
    pub endpoint: Option<String>,
    #[serde(
        rename = "aws.credentials.access_key_id",
        alias = "access_key",
        alias = "s3.access.key"
    )]
    pub access_key: Option<String>,
    #[serde(
        rename = "aws.credentials.secret_access_key",
        alias = "secret_key",
        alias = "s3.secret.key"
    )]
    pub secret_key: Option<String>,
    #[serde(rename = "aws.credentials.session_token", alias = "session_token")]
    pub session_token: Option<String>,
    /// IAM role
    #[serde(rename = "aws.credentials.role.arn", alias = "arn")]
    pub arn: Option<String>,
    /// external ID in IAM role trust policy
    #[serde(rename = "aws.credentials.role.external_id", alias = "external_id")]
    pub external_id: Option<String>,
    #[serde(rename = "aws.profile", alias = "profile")]
    pub profile: Option<String>,
    // Not read anywhere in this impl block; presumably consumed by the MSK IAM signer -- TODO confirm.
    #[serde(rename = "aws.msk.signer_timeout_sec")]
    pub msk_signer_timeout_sec: Option<u64>,
}
102
103impl AwsAuthProps {
104    async fn build_region(&self) -> ConnectorResult<Region> {
105        if let Some(region_name) = &self.region {
106            Ok(Region::new(region_name.clone()))
107        } else {
108            let mut region_chain = DefaultRegionChain::builder();
109            if let Some(profile_name) = &self.profile {
110                region_chain = region_chain.profile_name(profile_name);
111            }
112
113            Ok(region_chain
114                .build()
115                .region()
116                .await
117                .context("region should be provided")?)
118        }
119    }
120
121    async fn build_credential_provider(&self) -> ConnectorResult<SharedCredentialsProvider> {
122        if self.access_key.is_some() && self.secret_key.is_some() {
123            Ok(SharedCredentialsProvider::new(
124                aws_credential_types::Credentials::from_keys(
125                    self.access_key.as_ref().unwrap(),
126                    self.secret_key.as_ref().unwrap(),
127                    self.session_token.clone(),
128                ),
129            ))
130        } else if !env_var_is_true(DISABLE_DEFAULT_CREDENTIAL) {
131            Ok(SharedCredentialsProvider::new(
132                aws_config::default_provider::credentials::default_provider().await,
133            ))
134        } else {
135            bail!("Both \"access_key\" and \"secret_key\" are required.")
136        }
137    }
138
139    async fn with_role_provider(
140        &self,
141        credential: SharedCredentialsProvider,
142    ) -> ConnectorResult<SharedCredentialsProvider> {
143        if let Some(role_name) = &self.arn {
144            let region = self.build_region().await?;
145            let mut role = AssumeRoleProvider::builder(role_name)
146                .session_name("RisingWave")
147                .region(region);
148            if let Some(id) = &self.external_id {
149                role = role.external_id(id);
150            }
151            let provider = role.build_from_provider(credential).await;
152            Ok(SharedCredentialsProvider::new(provider))
153        } else {
154            Ok(credential)
155        }
156    }
157
158    pub async fn build_config(&self) -> ConnectorResult<SdkConfig> {
159        let region = self.build_region().await?;
160        let credentials_provider = self
161            .with_role_provider(self.build_credential_provider().await?)
162            .await?;
163        let mut config_loader = aws_config::from_env()
164            .region(region)
165            .credentials_provider(credentials_provider);
166
167        if let Some(endpoint) = self.endpoint.as_ref() {
168            config_loader = config_loader.endpoint_url(endpoint);
169        }
170
171        Ok(config_loader.load().await)
172    }
173}
174
/// Connection-level Kafka options: the broker list plus the security (SSL / SASL) settings.
///
/// The security fields are forwarded to the librdkafka client configuration by
/// `KafkaConnectionProps::set_security_properties`.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
pub struct KafkaConnectionProps {
    #[serde(rename = "properties.bootstrap.server", alias = "kafka.brokers")]
    pub brokers: String,

    /// Security protocol used for RisingWave to communicate with Kafka brokers. Could be
    /// PLAINTEXT, SSL, SASL_PLAINTEXT or SASL_SSL.
    #[serde(rename = "properties.security.protocol")]
    security_protocol: Option<String>,

    #[serde(rename = "properties.ssl.endpoint.identification.algorithm")]
    ssl_endpoint_identification_algorithm: Option<String>,

    // For the properties below, please refer to [librdkafka](https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md) for more information.
    /// Path to CA certificate file for verifying the broker's key.
    #[serde(rename = "properties.ssl.ca.location")]
    ssl_ca_location: Option<String>,

    /// CA certificate string (PEM format) for verifying the broker's key.
    #[serde(rename = "properties.ssl.ca.pem")]
    ssl_ca_pem: Option<String>,

    /// Path to client's certificate file (PEM).
    #[serde(rename = "properties.ssl.certificate.location")]
    ssl_certificate_location: Option<String>,

    /// Client's public key string (PEM format) used for authentication.
    #[serde(rename = "properties.ssl.certificate.pem")]
    ssl_certificate_pem: Option<String>,

    /// Path to client's private key file (PEM).
    #[serde(rename = "properties.ssl.key.location")]
    ssl_key_location: Option<String>,

    /// Client's private key string (PEM format) used for authentication.
    #[serde(rename = "properties.ssl.key.pem")]
    ssl_key_pem: Option<String>,

    /// Passphrase of client's private key.
    #[serde(rename = "properties.ssl.key.password")]
    ssl_key_password: Option<String>,

    /// SASL mechanism if SASL is enabled. Currently support PLAIN, SCRAM, GSSAPI, and AWS_MSK_IAM.
    #[serde(rename = "properties.sasl.mechanism")]
    sasl_mechanism: Option<String>,

    /// SASL username for SASL/PLAIN and SASL/SCRAM.
    #[serde(rename = "properties.sasl.username")]
    sasl_username: Option<String>,

    /// SASL password for SASL/PLAIN and SASL/SCRAM.
    #[serde(rename = "properties.sasl.password")]
    sasl_password: Option<String>,

    /// Kafka server's Kerberos principal name under SASL/GSSAPI, not including /hostname@REALM.
    #[serde(rename = "properties.sasl.kerberos.service.name")]
    sasl_kerberos_service_name: Option<String>,

    /// Path to client's Kerberos keytab file under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.keytab")]
    sasl_kerberos_keytab: Option<String>,

    /// Client's Kerberos principal name under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.principal")]
    sasl_kerberos_principal: Option<String>,

    /// Shell command to refresh or acquire the client's Kerberos ticket under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.kinit.cmd")]
    sasl_kerberos_kinit_cmd: Option<String>,

    /// Minimum time in milliseconds between key refresh attempts under SASL/GSSAPI.
    #[serde(rename = "properties.sasl.kerberos.min.time.before.relogin")]
    sasl_kerberos_min_time_before_relogin: Option<String>,

    // NOTE: the field name is misspelled ("oath" instead of "oauth"); the option key below is
    // spelled correctly, so only the Rust identifier is affected.
    /// Configurations for SASL/OAUTHBEARER.
    #[serde(rename = "properties.sasl.oauthbearer.config")]
    sasl_oathbearer_config: Option<String>,
}
254
/// Kafka options that are not connection related: the topic name and the timeout for
/// synchronous calls (defaults to `default_kafka_sync_call_timeout()` when the option is absent).
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct KafkaCommon {
    // connection related props are moved to `KafkaConnection`
    #[serde(rename = "topic", alias = "kafka.topic")]
    pub topic: String,

    // Parsed from a string by `deserialize_duration_from_string`.
    #[serde(
        rename = "properties.sync.call.timeout",
        deserialize_with = "deserialize_duration_from_string",
        default = "default_kafka_sync_call_timeout"
    )]
    pub sync_call_timeout: Duration,
}
269
/// Private-link related Kafka options.
///
/// Holds the optional broker rewrite map, which is carried as a JSON string under the
/// `broker.rewrite.endpoints` key and decoded into a `BTreeMap<String, String>`.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions, PartialEq, Hash, Eq)]
pub struct KafkaPrivateLinkCommon {
    /// This is generated from `private_link_targets` and `private_link_endpoint` in frontend, instead of given by users.
    #[serde(rename = "broker.rewrite.endpoints")]
    #[serde_as(as = "Option<JsonString>")]
    pub broker_rewrite_map: Option<BTreeMap<String, String>>,
}
278
/// Serde default for `properties.sync.call.timeout`: five seconds.
const fn default_kafka_sync_call_timeout() -> Duration {
    const DEFAULT_TIMEOUT_SECS: u64 = 5;
    Duration::from_secs(DEFAULT_TIMEOUT_SECS)
}
282
/// Serde default for `properties.socket.keepalive.enable`: keepalive is on.
const fn default_socket_keepalive_enable() -> bool {
    const KEEPALIVE_ENABLED: bool = true;
    KEEPALIVE_ENABLED
}
286
/// librdkafka client properties shared by Kafka connectors.
///
/// All values are parsed from their string form. `RdKafkaPropertiesCommon::set_client` copies
/// the populated ones onto an `rdkafka::ClientConfig`, using the key without the
/// `properties.` prefix.
#[serde_as]
#[derive(Debug, Clone, Deserialize, WithOptions)]
pub struct RdKafkaPropertiesCommon {
    /// Maximum Kafka protocol request message size. Due to differing framing overhead between
    /// protocol versions the producer is unable to reliably enforce a strict max message limit at
    /// produce time and may exceed the maximum size by one message in protocol ProduceRequests,
    /// the broker will enforce the topic's max.message.bytes limit
    #[serde(rename = "properties.message.max.bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub message_max_bytes: Option<usize>,

    /// Maximum Kafka protocol response message size. This serves as a safety precaution to avoid
    /// memory exhaustion in case of protocol hickups. This value must be at least fetch.max.bytes
    /// + 512 to allow for protocol overhead; the value is adjusted automatically unless the
    /// configuration property is explicitly set.
    #[serde(rename = "properties.receive.message.max.bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub receive_message_max_bytes: Option<usize>,

    #[serde(rename = "properties.statistics.interval.ms")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub statistics_interval_ms: Option<usize>,

    /// Client identifier
    #[serde(rename = "properties.client.id")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub client_id: Option<String>,

    #[serde(rename = "properties.enable.ssl.certificate.verification")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub enable_ssl_certificate_verification: Option<bool>,

    // Unlike the other fields this one is not optional: it falls back to
    // `default_socket_keepalive_enable()` and is therefore always written by `set_client`.
    #[serde(
        rename = "properties.socket.keepalive.enable",
        default = "default_socket_keepalive_enable"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub socket_keepalive_enable: bool,
}
326
327impl RdKafkaPropertiesCommon {
328    pub(crate) fn set_client(&self, c: &mut rdkafka::ClientConfig) {
329        if let Some(v) = self.statistics_interval_ms {
330            c.set("statistics.interval.ms", v.to_string());
331        }
332        if let Some(v) = self.message_max_bytes {
333            c.set("message.max.bytes", v.to_string());
334        }
335        if let Some(v) = self.receive_message_max_bytes {
336            c.set("receive.message.max.bytes", v.to_string());
337        }
338        if let Some(v) = self.client_id.as_ref() {
339            c.set("client.id", v);
340        }
341        if let Some(v) = self.enable_ssl_certificate_verification {
342            c.set("enable.ssl.certificate.verification", v.to_string());
343        }
344        c.set(
345            "socket.keepalive.enable",
346            self.socket_keepalive_enable.to_string(),
347        );
348    }
349}
350
351impl KafkaConnectionProps {
352    #[cfg(test)]
353    pub fn test_default() -> Self {
354        Self {
355            brokers: "localhost:9092".to_owned(),
356            security_protocol: None,
357            ssl_ca_location: None,
358            ssl_certificate_location: None,
359            ssl_key_location: None,
360            ssl_ca_pem: None,
361            ssl_certificate_pem: None,
362            ssl_key_pem: None,
363            ssl_key_password: None,
364            ssl_endpoint_identification_algorithm: None,
365            sasl_mechanism: None,
366            sasl_username: None,
367            sasl_password: None,
368            sasl_kerberos_service_name: None,
369            sasl_kerberos_keytab: None,
370            sasl_kerberos_principal: None,
371            sasl_kerberos_kinit_cmd: None,
372            sasl_kerberos_min_time_before_relogin: None,
373            sasl_oathbearer_config: None,
374        }
375    }
376
377    pub(crate) fn set_security_properties(&self, config: &mut ClientConfig) {
378        // AWS_MSK_IAM
379        if self.is_aws_msk_iam() {
380            config.set("security.protocol", "SASL_SSL");
381            config.set("sasl.mechanism", "OAUTHBEARER");
382            return;
383        }
384
385        // Security protocol
386        if let Some(security_protocol) = self.security_protocol.as_ref() {
387            config.set("security.protocol", security_protocol);
388        }
389
390        // SSL
391        if let Some(ssl_ca_location) = self.ssl_ca_location.as_ref() {
392            config.set("ssl.ca.location", ssl_ca_location);
393        }
394        if let Some(ssl_ca_pem) = self.ssl_ca_pem.as_ref() {
395            config.set("ssl.ca.pem", ssl_ca_pem);
396        }
397        if let Some(ssl_certificate_location) = self.ssl_certificate_location.as_ref() {
398            config.set("ssl.certificate.location", ssl_certificate_location);
399        }
400        if let Some(ssl_certificate_pem) = self.ssl_certificate_pem.as_ref() {
401            config.set("ssl.certificate.pem", ssl_certificate_pem);
402        }
403        if let Some(ssl_key_location) = self.ssl_key_location.as_ref() {
404            config.set("ssl.key.location", ssl_key_location);
405        }
406        if let Some(ssl_key_pem) = self.ssl_key_pem.as_ref() {
407            config.set("ssl.key.pem", ssl_key_pem);
408        }
409        if let Some(ssl_key_password) = self.ssl_key_password.as_ref() {
410            config.set("ssl.key.password", ssl_key_password);
411        }
412        if let Some(ssl_endpoint_identification_algorithm) =
413            self.ssl_endpoint_identification_algorithm.as_ref()
414        {
415            // accept only `none` and `http` here, let the sdk do the check
416            config.set(
417                "ssl.endpoint.identification.algorithm",
418                ssl_endpoint_identification_algorithm,
419            );
420        }
421
422        // SASL mechanism
423        if let Some(sasl_mechanism) = self.sasl_mechanism.as_ref() {
424            config.set("sasl.mechanism", sasl_mechanism);
425        }
426
427        // SASL/PLAIN & SASL/SCRAM
428        if let Some(sasl_username) = self.sasl_username.as_ref() {
429            config.set("sasl.username", sasl_username);
430        }
431        if let Some(sasl_password) = self.sasl_password.as_ref() {
432            config.set("sasl.password", sasl_password);
433        }
434
435        // SASL/GSSAPI
436        if let Some(sasl_kerberos_service_name) = self.sasl_kerberos_service_name.as_ref() {
437            config.set("sasl.kerberos.service.name", sasl_kerberos_service_name);
438        }
439        if let Some(sasl_kerberos_keytab) = self.sasl_kerberos_keytab.as_ref() {
440            config.set("sasl.kerberos.keytab", sasl_kerberos_keytab);
441        }
442        if let Some(sasl_kerberos_principal) = self.sasl_kerberos_principal.as_ref() {
443            config.set("sasl.kerberos.principal", sasl_kerberos_principal);
444        }
445        if let Some(sasl_kerberos_kinit_cmd) = self.sasl_kerberos_kinit_cmd.as_ref() {
446            config.set("sasl.kerberos.kinit.cmd", sasl_kerberos_kinit_cmd);
447        }
448        if let Some(sasl_kerberos_min_time_before_relogin) =
449            self.sasl_kerberos_min_time_before_relogin.as_ref()
450        {
451            config.set(
452                "sasl.kerberos.min.time.before.relogin",
453                sasl_kerberos_min_time_before_relogin,
454            );
455        }
456
457        // SASL/OAUTHBEARER
458        if let Some(sasl_oathbearer_config) = self.sasl_oathbearer_config.as_ref() {
459            config.set("sasl.oauthbearer.config", sasl_oathbearer_config);
460        }
461        // Currently, we only support unsecured OAUTH.
462        config.set("enable.sasl.oauthbearer.unsecure.jwt", "true");
463    }
464
465    pub(crate) fn is_aws_msk_iam(&self) -> bool {
466        if let Some(sasl_mechanism) = self.sasl_mechanism.as_ref()
467            && sasl_mechanism == AWS_MSK_IAM_AUTH
468        {
469            true
470        } else {
471            false
472        }
473    }
474}
475
/// Common Pulsar options: the topic, the service URL the client connects to, and an optional
/// authentication token.
///
/// `PulsarCommon::build_client` only uses `auth_token` when no OAuth configuration is supplied.
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct PulsarCommon {
    #[serde(rename = "topic", alias = "pulsar.topic")]
    pub topic: String,

    #[serde(rename = "service.url", alias = "pulsar.service.url")]
    pub service_url: String,

    #[serde(rename = "auth.token")]
    pub auth_token: Option<String>,
}
487
/// OAuth2 options for Pulsar, passed to the client as `OAuth2Params` by
/// `PulsarCommon::build_client`.
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct PulsarOauthCommon {
    #[serde(rename = "oauth.issuer.url")]
    pub issuer_url: String,

    // Must be a `file` or `s3` URL; `PulsarCommon::build_client` rejects any other scheme.
    #[serde(rename = "oauth.credentials.url")]
    pub credentials_url: String,

    #[serde(rename = "oauth.audience")]
    pub audience: String,

    #[serde(rename = "oauth.scope")]
    pub scope: Option<String>,
}
502
503fn create_credential_temp_file(credentials: &[u8]) -> std::io::Result<NamedTempFile> {
504    let mut f = NamedTempFile::new()?;
505    f.write_all(credentials)?;
506    f.as_file().sync_all()?;
507    Ok(f)
508}
509
510impl PulsarCommon {
511    pub(crate) async fn build_client(
512        &self,
513        oauth: &Option<PulsarOauthCommon>,
514        aws_auth_props: &AwsAuthProps,
515    ) -> ConnectorResult<Pulsar<TokioExecutor>> {
516        let mut pulsar_builder = Pulsar::builder(&self.service_url, TokioExecutor);
517        let mut temp_file = None;
518        if let Some(oauth) = oauth.as_ref() {
519            let url = Url::parse(&oauth.credentials_url)?;
520            match url.scheme() {
521                "s3" => {
522                    let credentials = load_file_descriptor_from_s3(&url, aws_auth_props).await?;
523                    temp_file = Some(
524                        create_credential_temp_file(&credentials)
525                            .context("failed to create temp file for pulsar credentials")?,
526                    );
527                }
528                "file" => {}
529                _ => {
530                    bail!("invalid credentials_url, only file url and s3 url are supported",);
531                }
532            }
533
534            let auth_params = OAuth2Params {
535                issuer_url: oauth.issuer_url.clone(),
536                credentials_url: if temp_file.is_none() {
537                    oauth.credentials_url.clone()
538                } else {
539                    let mut raw_path = temp_file
540                        .as_ref()
541                        .unwrap()
542                        .path()
543                        .to_str()
544                        .unwrap()
545                        .to_owned();
546                    raw_path.insert_str(0, "file://");
547                    raw_path
548                },
549                audience: Some(oauth.audience.clone()),
550                scope: oauth.scope.clone(),
551            };
552
553            pulsar_builder = pulsar_builder
554                .with_auth_provider(OAuth2Authentication::client_credentials(auth_params));
555        } else if let Some(auth_token) = &self.auth_token {
556            pulsar_builder = pulsar_builder.with_auth(Authentication {
557                name: "token".to_owned(),
558                data: Vec::from(auth_token.as_str()),
559            });
560        }
561
562        let res = pulsar_builder.build().await.map_err(|e| anyhow!(e))?;
563        drop(temp_file);
564        Ok(res)
565    }
566}
567
/// Common Kinesis options: the stream name and region plus optional endpoint, static
/// credentials and assume-role settings.
///
/// `KinesisCommon::build_client` maps these fields onto an `AwsAuthProps` to build the client.
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct KinesisCommon {
    #[serde(rename = "stream", alias = "kinesis.stream.name")]
    pub stream_name: String,
    #[serde(rename = "aws.region", alias = "kinesis.stream.region")]
    pub stream_region: String,
    #[serde(rename = "endpoint", alias = "kinesis.endpoint")]
    pub endpoint: Option<String>,
    #[serde(
        rename = "aws.credentials.access_key_id",
        alias = "kinesis.credentials.access"
    )]
    pub credentials_access_key: Option<String>,
    #[serde(
        rename = "aws.credentials.secret_access_key",
        alias = "kinesis.credentials.secret"
    )]
    pub credentials_secret_access_key: Option<String>,
    #[serde(
        rename = "aws.credentials.session_token",
        alias = "kinesis.credentials.session_token"
    )]
    pub session_token: Option<String>,
    #[serde(rename = "aws.credentials.role.arn", alias = "kinesis.assumerole.arn")]
    pub assume_role_arn: Option<String>,
    #[serde(
        rename = "aws.credentials.role.external_id",
        alias = "kinesis.assumerole.external_id"
    )]
    pub assume_role_external_id: Option<String>,
}
599
600impl KinesisCommon {
601    pub(crate) async fn build_client(&self) -> ConnectorResult<KinesisClient> {
602        let config = AwsAuthProps {
603            region: Some(self.stream_region.clone()),
604            endpoint: self.endpoint.clone(),
605            access_key: self.credentials_access_key.clone(),
606            secret_key: self.credentials_secret_access_key.clone(),
607            session_token: self.session_token.clone(),
608            arn: self.assume_role_arn.clone(),
609            external_id: self.assume_role_external_id.clone(),
610            profile: Default::default(),
611            msk_signer_timeout_sec: Default::default(),
612        };
613        let aws_config = config.build_config().await?;
614        let mut builder = aws_sdk_kinesis::config::Builder::from(&aws_config);
615        if let Some(endpoint) = &config.endpoint {
616            builder = builder.endpoint_url(endpoint);
617        }
618        Ok(KinesisClient::from_conf(builder.build()))
619    }
620}
621
/// Common NATS options.
///
/// * `server_url` and `subject` may each hold several comma-separated entries.
/// * `connect_mode` must be one of `user_and_password` (requires `username` and `password`),
///   `credential` (requires `nkey` and `jwt`) or `plain`.
/// * The `max_*` limits are parsed from strings and are applied only when
///   `NatsCommon::build_or_get_stream` creates a new stream.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct NatsCommon {
    #[serde(rename = "server_url")]
    pub server_url: String,
    #[serde(rename = "subject")]
    pub subject: String,
    #[serde(rename = "connect_mode")]
    pub connect_mode: String,
    #[serde(rename = "username")]
    pub user: Option<String>,
    #[serde(rename = "password")]
    pub password: Option<String>,
    #[serde(rename = "jwt")]
    pub jwt: Option<String>,
    #[serde(rename = "nkey")]
    pub nkey: Option<String>,
    #[serde(rename = "max_bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_bytes: Option<i64>,
    #[serde(rename = "max_messages")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_messages: Option<i64>,
    #[serde(rename = "max_messages_per_subject")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_messages_per_subject: Option<i64>,
    #[serde(rename = "max_consumers")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_consumers: Option<i32>,
    #[serde(rename = "max_message_size")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_message_size: Option<i32>,
}
655
656impl NatsCommon {
657    pub(crate) async fn build_client(&self) -> ConnectorResult<async_nats::Client> {
658        let mut connect_options = async_nats::ConnectOptions::new();
659        match self.connect_mode.as_str() {
660            "user_and_password" => {
661                if let (Some(v_user), Some(v_password)) =
662                    (self.user.as_ref(), self.password.as_ref())
663                {
664                    connect_options =
665                        connect_options.user_and_password(v_user.into(), v_password.into())
666                } else {
667                    bail!("nats connect mode is user_and_password, but user or password is empty");
668                }
669            }
670
671            "credential" => {
672                if let (Some(v_nkey), Some(v_jwt)) = (self.nkey.as_ref(), self.jwt.as_ref()) {
673                    connect_options = connect_options
674                        .credentials(&self.create_credential(v_nkey, v_jwt)?)
675                        .expect("failed to parse static creds")
676                } else {
677                    bail!("nats connect mode is credential, but nkey or jwt is empty");
678                }
679            }
680            "plain" => {}
681            _ => {
682                bail!("nats connect mode only accepts user_and_password/credential/plain");
683            }
684        };
685
686        let servers = self.server_url.split(',').collect::<Vec<&str>>();
687        let client = connect_options
688            .connect(
689                servers
690                    .iter()
691                    .map(|url| url.parse())
692                    .collect::<Result<Vec<async_nats::ServerAddr>, _>>()?,
693            )
694            .await
695            .context("build nats client error")
696            .map_err(SinkError::Nats)?;
697        Ok(client)
698    }
699
700    pub(crate) async fn build_context(&self) -> ConnectorResult<jetstream::Context> {
701        let client = self.build_client().await?;
702        let jetstream = async_nats::jetstream::new(client);
703        Ok(jetstream)
704    }
705
706    pub(crate) async fn build_consumer(
707        &self,
708        stream: String,
709        durable_consumer_name: String,
710        split_id: String,
711        start_sequence: NatsOffset,
712        mut config: jetstream::consumer::pull::Config,
713    ) -> ConnectorResult<
714        async_nats::jetstream::consumer::Consumer<async_nats::jetstream::consumer::pull::Config>,
715    > {
716        let context = self.build_context().await?;
717        let stream = self.build_or_get_stream(context.clone(), stream).await?;
718        let subject_name = self
719            .subject
720            .replace(',', "-")
721            .replace(['.', '>', '*', ' ', '\t'], "_");
722        let name = format!("risingwave-consumer-{}-{}", subject_name, split_id);
723
724        let deliver_policy = match start_sequence {
725            NatsOffset::Earliest => DeliverPolicy::All,
726            NatsOffset::Latest => DeliverPolicy::New,
727            NatsOffset::SequenceNumber(v) => {
728                // for compatibility, we do not write to any state table now
729                let parsed = v
730                    .parse::<u64>()
731                    .context("failed to parse nats offset as sequence number")?;
732                DeliverPolicy::ByStartSequence {
733                    start_sequence: 1 + parsed,
734                }
735            }
736            NatsOffset::Timestamp(v) => DeliverPolicy::ByStartTime {
737                start_time: OffsetDateTime::from_unix_timestamp_nanos(v as i128 * 1_000_000)
738                    .context("invalid timestamp for nats offset")?,
739            },
740            NatsOffset::None => DeliverPolicy::All,
741        };
742
743        let consumer = match stream.get_consumer(&name).await {
744            Ok(consumer) => consumer,
745            _ => {
746                stream
747                    .get_or_create_consumer(&name, {
748                        config.deliver_policy = deliver_policy;
749                        config.durable_name = Some(durable_consumer_name);
750                        config.filter_subjects =
751                            self.subject.split(',').map(|s| s.to_owned()).collect();
752                        config
753                    })
754                    .await?
755            }
756        };
757        Ok(consumer)
758    }
759
760    pub(crate) async fn build_or_get_stream(
761        &self,
762        jetstream: jetstream::Context,
763        stream: String,
764    ) -> ConnectorResult<jetstream::stream::Stream> {
765        let subjects: Vec<String> = self.subject.split(',').map(|s| s.to_owned()).collect();
766        if let Ok(mut stream_instance) = jetstream.get_stream(&stream).await {
767            tracing::info!(
768                "load existing nats stream ({:?}) with config {:?}",
769                stream,
770                stream_instance.info().await?
771            );
772            return Ok(stream_instance);
773        }
774
775        let mut config = jetstream::stream::Config {
776            name: stream.clone(),
777            max_bytes: 1000000,
778            subjects,
779            ..Default::default()
780        };
781        if let Some(v) = self.max_bytes {
782            config.max_bytes = v;
783        }
784        if let Some(v) = self.max_messages {
785            config.max_messages = v;
786        }
787        if let Some(v) = self.max_messages_per_subject {
788            config.max_messages_per_subject = v;
789        }
790        if let Some(v) = self.max_consumers {
791            config.max_consumers = v;
792        }
793        if let Some(v) = self.max_message_size {
794            config.max_message_size = v;
795        }
796        tracing::info!(
797            "create nats stream ({:?}) with config {:?}",
798            &stream,
799            config
800        );
801        let stream = jetstream.get_or_create_stream(config).await?;
802        Ok(stream)
803    }
804
805    pub(crate) fn create_credential(&self, seed: &str, jwt: &str) -> ConnectorResult<String> {
806        let creds = format!(
807            "-----BEGIN NATS USER JWT-----\n{}\n------END NATS USER JWT------\n\n\
808                         ************************* IMPORTANT *************************\n\
809                         NKEY Seed printed below can be used to sign and prove identity.\n\
810                         NKEYs are sensitive and should be treated as secrets.\n\n\
811                         -----BEGIN USER NKEY SEED-----\n{}\n------END USER NKEY SEED------\n\n\
812                         *************************************************************",
813            jwt, seed
814        );
815        Ok(creds)
816    }
817}
818
819pub(crate) fn load_certs(
820    certificates: &str,
821) -> ConnectorResult<Vec<rustls_pki_types::CertificateDer<'static>>> {
822    let cert_bytes = if let Some(path) = certificates.strip_prefix("fs://") {
823        std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())?
824    } else {
825        certificates.as_bytes().to_owned()
826    };
827
828    rustls_pemfile::certs(&mut cert_bytes.as_slice())
829        .map(|cert| Ok(cert?))
830        .collect()
831}
832
833pub(crate) fn load_private_key(
834    certificate: &str,
835) -> ConnectorResult<rustls_pki_types::PrivateKeyDer<'static>> {
836    let cert_bytes = if let Some(path) = certificate.strip_prefix("fs://") {
837        std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())?
838    } else {
839        certificate.as_bytes().to_owned()
840    };
841
842    let cert = rustls_pemfile::pkcs8_private_keys(&mut cert_bytes.as_slice())
843        .next()
844        .ok_or_else(|| anyhow!("No private key found"))?;
845    Ok(cert?.into())
846}
847
/// Common MongoDB options: the connection URL, used by `MongodbCommon::build_client`, and the
/// target collection name.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct MongodbCommon {
    /// The URL of MongoDB
    #[serde(rename = "mongodb.url")]
    pub connect_uri: String,
    /// The collection name where data should be written to or read from. For sinks, the format is
    /// `db_name.collection_name`. Data can also be written to dynamic collections, see `collection.name.field`
    /// for more information.
    #[serde(rename = "collection.name")]
    pub collection_name: String,
}
860
861impl MongodbCommon {
862    pub(crate) async fn build_client(&self) -> ConnectorResult<mongodb::Client> {
863        let client = mongodb::Client::with_uri_str(&self.connect_uri).await?;
864
865        Ok(client)
866    }
867}