risingwave_connector/source/kafka/
client_context.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use std::thread;
18
19use anyhow::anyhow;
20use aws_config::Region;
21use aws_sdk_s3::config::SharedCredentialsProvider;
22use rdkafka::client::{BrokerAddr, OAuthToken};
23use rdkafka::consumer::ConsumerContext;
24use rdkafka::message::DeliveryResult;
25use rdkafka::producer::ProducerContext;
26use rdkafka::{ClientContext, Statistics};
27
28use super::private_link::{BrokerAddrRewriter, PrivateLinkContextRole};
29use super::stats::RdKafkaStats;
30use crate::connector_common::AwsAuthProps;
31use crate::error::ConnectorResult;
32
33struct IamAuthEnv {
34    credentials_provider: SharedCredentialsProvider,
35    region: Region,
36    // XXX(runji): madsim does not support `Handle` for now
37    #[cfg(not(madsim))]
38    rt: tokio::runtime::Handle,
39    signer_timeout_sec: u64,
40}
41
42pub struct KafkaContextCommon {
43    // For VPC PrivateLink support
44    addr_rewriter: BrokerAddrRewriter,
45
46    // identifier is required when reporting metrics as a label, usually it is compose by connector
47    // format (source or sink) and corresponding id (source_id or sink_id)
48    // identifier and metrics should be set at the same time
49    identifier: Option<String>,
50    metrics: Option<Arc<RdKafkaStats>>,
51
52    /// Credential and region for AWS MSK
53    auth: Option<IamAuthEnv>,
54}
55
56impl KafkaContextCommon {
57    pub async fn new(
58        broker_rewrite_map: Option<BTreeMap<String, String>>,
59        identifier: Option<String>,
60        metrics: Option<Arc<RdKafkaStats>>,
61        auth: AwsAuthProps,
62        is_aws_msk_iam: bool,
63    ) -> ConnectorResult<Self> {
64        let addr_rewriter =
65            BrokerAddrRewriter::new(PrivateLinkContextRole::Consumer, broker_rewrite_map)?;
66        let auth = if is_aws_msk_iam {
67            let config = auth.build_config().await?;
68            let credentials_provider = config
69                .credentials_provider()
70                .ok_or_else(|| anyhow!("missing aws credentials_provider"))?;
71            let region = config
72                .region()
73                .ok_or_else(|| anyhow!("missing aws region"))?
74                .clone();
75            Some(IamAuthEnv {
76                credentials_provider,
77                region,
78                #[cfg(not(madsim))]
79                rt: tokio::runtime::Handle::current(),
80                signer_timeout_sec: auth
81                    .msk_signer_timeout_sec
82                    .unwrap_or(Self::default_msk_signer_timeout_sec()),
83            })
84        } else {
85            None
86        };
87        Ok(Self {
88            addr_rewriter,
89            identifier,
90            metrics,
91            auth,
92        })
93    }
94
95    fn default_msk_signer_timeout_sec() -> u64 {
96        10
97    }
98}
99
100impl KafkaContextCommon {
101    fn stats(&self, statistics: Statistics) {
102        if let Some(metrics) = &self.metrics
103            && let Some(id) = &self.identifier
104        {
105            metrics.report(id.as_str(), &statistics);
106        }
107    }
108
109    fn rewrite_broker_addr(&self, addr: BrokerAddr) -> BrokerAddr {
110        self.addr_rewriter.rewrite_broker_addr(addr)
111    }
112
113    // XXX(runji): oauth is ignored in simulation
114    #[cfg_or_panic::cfg_or_panic(not(madsim))]
115    fn generate_oauth_token(
116        &self,
117        _oauthbearer_config: Option<&str>,
118    ) -> Result<OAuthToken, Box<dyn std::error::Error>> {
119        use aws_msk_iam_sasl_signer::generate_auth_token_from_credentials_provider;
120        use tokio::time::{Duration, timeout};
121
122        if let Some(IamAuthEnv {
123            credentials_provider,
124            region,
125            rt,
126            signer_timeout_sec,
127        }) = &self.auth
128        {
129            let region = region.clone();
130            let credentials_provider = credentials_provider.clone();
131            let rt = rt.clone();
132            let signer_timeout_sec = *signer_timeout_sec;
133            let (token, expiration_time_ms) = {
134                let handle = thread::spawn(move || {
135                    rt.block_on(async {
136                        timeout(
137                            Duration::from_secs(signer_timeout_sec),
138                            generate_auth_token_from_credentials_provider(
139                                region,
140                                credentials_provider,
141                            ),
142                        )
143                        .await
144                    })
145                });
146                handle.join().unwrap()??
147            };
148            Ok(OAuthToken {
149                token,
150                principal_name: "".to_owned(),
151                lifetime_ms: expiration_time_ms,
152            })
153        } else {
154            Err("must provide AWS IAM credential".into())
155        }
156    }
157
158    fn enable_refresh_oauth_token(&self) -> bool {
159        self.auth.is_some()
160    }
161}
162
163pub type BoxConsumerContext = Box<dyn ConsumerContext>;
164
165/// Kafka consumer context used for private link, IAM auth, and metrics
166pub struct RwConsumerContext {
167    common: KafkaContextCommon,
168}
169
170impl RwConsumerContext {
171    pub fn new(common: KafkaContextCommon) -> Self {
172        Self { common }
173    }
174}
175
176impl ClientContext for RwConsumerContext {
177    /// this func serves as a callback when `poll` is completed.
178    fn stats(&self, statistics: Statistics) {
179        self.common.stats(statistics);
180    }
181
182    fn rewrite_broker_addr(&self, addr: BrokerAddr) -> BrokerAddr {
183        self.common.rewrite_broker_addr(addr)
184    }
185
186    fn generate_oauth_token(
187        &self,
188        oauthbearer_config: Option<&str>,
189    ) -> Result<OAuthToken, Box<dyn std::error::Error>> {
190        self.common.generate_oauth_token(oauthbearer_config)
191    }
192
193    fn enable_refresh_oauth_token(&self) -> bool {
194        self.common.enable_refresh_oauth_token()
195    }
196}
197
198// required by the trait bound of BaseConsumer
199impl ConsumerContext for RwConsumerContext {}
200
201/// Kafka producer context used for private link, IAM auth, and metrics
202pub struct RwProducerContext {
203    common: KafkaContextCommon,
204}
205
206impl RwProducerContext {
207    pub fn new(common: KafkaContextCommon) -> Self {
208        Self { common }
209    }
210}
211
212impl ClientContext for RwProducerContext {
213    fn stats(&self, statistics: Statistics) {
214        self.common.stats(statistics);
215    }
216
217    fn rewrite_broker_addr(&self, addr: BrokerAddr) -> BrokerAddr {
218        self.common.rewrite_broker_addr(addr)
219    }
220
221    fn generate_oauth_token(
222        &self,
223        oauthbearer_config: Option<&str>,
224    ) -> Result<OAuthToken, Box<dyn std::error::Error>> {
225        self.common.generate_oauth_token(oauthbearer_config)
226    }
227
228    fn enable_refresh_oauth_token(&self) -> bool {
229        self.common.enable_refresh_oauth_token()
230    }
231}
232
233impl ProducerContext for RwProducerContext {
234    type DeliveryOpaque = ();
235
236    fn delivery(&self, _: &DeliveryResult<'_>, _: Self::DeliveryOpaque) {}
237}