risingwave_connector/source/kafka/
client_context.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::sync::{Arc, LazyLock};
17
18use anyhow::{Context, anyhow};
19use aws_config::Region;
20use aws_sdk_s3::config::SharedCredentialsProvider;
21use rdkafka::client::{BrokerAddr, OAuthToken};
22use rdkafka::consumer::ConsumerContext;
23use rdkafka::message::DeliveryResult;
24use rdkafka::producer::ProducerContext;
25use rdkafka::{ClientContext, Statistics};
26use tokio::runtime::Runtime;
27
28use super::private_link::{BrokerAddrRewriter, PrivateLinkContextRole};
29use super::stats::RdKafkaStats;
30use crate::connector_common::AwsAuthProps;
31use crate::error::ConnectorResult;
32
/// AWS settings captured at construction time when AWS MSK IAM auth is enabled.
struct IamAuthEnv {
    /// Credentials provider handed to the MSK IAM token signer.
    credentials_provider: SharedCredentialsProvider,
    /// AWS region handed to the MSK IAM token signer.
    region: Region,
    /// Upper bound, in seconds, on a single token-signing attempt.
    signer_timeout_sec: u64,
}
38
/// State shared by the Kafka consumer and producer client contexts: broker address
/// rewriting, statistics reporting, and optional AWS MSK IAM authentication.
pub struct KafkaContextCommon {
    // For VPC PrivateLink support
    addr_rewriter: BrokerAddrRewriter,

    // The identifier is used as a label when reporting metrics. It is usually composed of the
    // connector kind (source or sink) and the corresponding id (source_id or sink_id).
    // `identifier` and `metrics` should be set at the same time: statistics are only reported
    // when both are present.
    identifier: Option<String>,
    metrics: Option<Arc<RdKafkaStats>>,

    /// Credentials, region, and signer timeout for AWS MSK IAM auth.
    /// `None` when IAM auth was not requested at construction time.
    auth: Option<IamAuthEnv>,
}
52
53impl KafkaContextCommon {
54    pub async fn new(
55        broker_rewrite_map: Option<BTreeMap<String, String>>,
56        identifier: Option<String>,
57        metrics: Option<Arc<RdKafkaStats>>,
58        auth: AwsAuthProps,
59        is_aws_msk_iam: bool,
60    ) -> ConnectorResult<Self> {
61        let addr_rewriter =
62            BrokerAddrRewriter::new(PrivateLinkContextRole::Consumer, broker_rewrite_map)?;
63        let auth = if is_aws_msk_iam {
64            let config = auth.build_config().await?;
65            let credentials_provider = config
66                .credentials_provider()
67                .ok_or_else(|| anyhow!("missing aws credentials_provider"))?;
68            let region = config
69                .region()
70                .ok_or_else(|| anyhow!("missing aws region"))?
71                .clone();
72            Some(IamAuthEnv {
73                credentials_provider,
74                region,
75                signer_timeout_sec: auth
76                    .msk_signer_timeout_sec
77                    .unwrap_or(Self::default_msk_signer_timeout_sec()),
78            })
79        } else {
80            None
81        };
82        Ok(Self {
83            addr_rewriter,
84            identifier,
85            metrics,
86            auth,
87        })
88    }
89
90    fn default_msk_signer_timeout_sec() -> u64 {
91        10
92    }
93}
94
95pub static KAFKA_SOURCE_RUNTIME: LazyLock<Runtime> = LazyLock::new(|| {
96    tokio::runtime::Builder::new_multi_thread()
97        .thread_name("rw-frontend")
98        .enable_all()
99        .build()
100        .expect("failed to build frontend runtime")
101});
102
103impl KafkaContextCommon {
104    fn stats(&self, statistics: Statistics) {
105        if let Some(metrics) = &self.metrics
106            && let Some(id) = &self.identifier
107        {
108            metrics.report(id.as_str(), &statistics);
109        }
110    }
111
112    fn rewrite_broker_addr(&self, addr: BrokerAddr) -> BrokerAddr {
113        self.addr_rewriter.rewrite_broker_addr(addr)
114    }
115
116    // XXX(runji): oauth is ignored in simulation
117    #[cfg_or_panic::cfg_or_panic(not(madsim))]
118    fn generate_oauth_token(
119        &self,
120        _oauthbearer_config: Option<&str>,
121    ) -> Result<OAuthToken, Box<dyn std::error::Error>> {
122        use aws_msk_iam_sasl_signer::generate_auth_token_from_credentials_provider;
123        use tokio::time::{Duration, timeout};
124
125        if let Some(IamAuthEnv {
126            credentials_provider,
127            region,
128            signer_timeout_sec,
129            ..
130        }) = &self.auth
131        {
132            let region = region.clone();
133            let credentials_provider = credentials_provider.clone();
134            let signer_timeout_sec = *signer_timeout_sec;
135            let (token, expiration_time_ms) = {
136                let result = tokio::task::block_in_place(move || {
137                    KAFKA_SOURCE_RUNTIME.block_on(async {
138                        timeout(
139                            Duration::from_secs(signer_timeout_sec),
140                            generate_auth_token_from_credentials_provider(
141                                region,
142                                credentials_provider,
143                            ),
144                        )
145                        .await
146                    })
147                });
148                result
149                    .map_err(|_e| "generating AWS MSK IAM token timeout".to_owned())?
150                    .map_err(|e| anyhow!(e))
151                    .context("failed to generate AWS MSK IAM token")?
152            };
153            Ok(OAuthToken {
154                token,
155                principal_name: "".to_owned(),
156                lifetime_ms: expiration_time_ms,
157            })
158        } else {
159            Err("must provide AWS IAM credential".into())
160        }
161    }
162
163    fn enable_refresh_oauth_token(&self) -> bool {
164        self.auth.is_some()
165    }
166}
167
/// A boxed, dynamically dispatched [`ConsumerContext`].
pub type BoxConsumerContext = Box<dyn ConsumerContext>;
169
/// Kafka consumer context used for private link, IAM auth, and metrics.
///
/// Its [`ClientContext`] callbacks are forwarded to the wrapped [`KafkaContextCommon`].
pub struct RwConsumerContext {
    /// Shared state that implements the callback logic.
    common: KafkaContextCommon,
}
174
175impl RwConsumerContext {
176    pub fn new(common: KafkaContextCommon) -> Self {
177        Self { common }
178    }
179}
180
impl ClientContext for RwConsumerContext {
    /// This function serves as a callback when `poll` is completed.
    /// The statistics are forwarded to the shared context for reporting.
    fn stats(&self, statistics: Statistics) {
        self.common.stats(statistics);
    }

    /// Rewrites a broker address by delegating to the shared context.
    fn rewrite_broker_addr(&self, addr: BrokerAddr) -> BrokerAddr {
        self.common.rewrite_broker_addr(addr)
    }

    /// Generates an OAuth token by delegating to the shared context.
    fn generate_oauth_token(
        &self,
        oauthbearer_config: Option<&str>,
    ) -> Result<OAuthToken, Box<dyn std::error::Error>> {
        self.common.generate_oauth_token(oauthbearer_config)
    }

    /// Reports whether OAuth token refresh is enabled, as decided by the shared context.
    fn enable_refresh_oauth_token(&self) -> bool {
        self.common.enable_refresh_oauth_token()
    }
}
202
// Required by the trait bound of `BaseConsumer`. No `ConsumerContext` methods are overridden
// here, so the trait's default implementations apply.
impl ConsumerContext for RwConsumerContext {}
205
/// Kafka producer context used for private link, IAM auth, and metrics.
///
/// Its [`ClientContext`] callbacks are forwarded to the wrapped [`KafkaContextCommon`].
pub struct RwProducerContext {
    /// Shared state that implements the callback logic.
    common: KafkaContextCommon,
}
210
211impl RwProducerContext {
212    pub fn new(common: KafkaContextCommon) -> Self {
213        Self { common }
214    }
215}
216
impl ClientContext for RwProducerContext {
    /// Forwards client statistics to the shared context for reporting.
    fn stats(&self, statistics: Statistics) {
        self.common.stats(statistics);
    }

    /// Rewrites a broker address by delegating to the shared context.
    fn rewrite_broker_addr(&self, addr: BrokerAddr) -> BrokerAddr {
        self.common.rewrite_broker_addr(addr)
    }

    /// Generates an OAuth token by delegating to the shared context.
    fn generate_oauth_token(
        &self,
        oauthbearer_config: Option<&str>,
    ) -> Result<OAuthToken, Box<dyn std::error::Error>> {
        self.common.generate_oauth_token(oauthbearer_config)
    }

    /// Reports whether OAuth token refresh is enabled, as decided by the shared context.
    fn enable_refresh_oauth_token(&self) -> bool {
        self.common.enable_refresh_oauth_token()
    }
}
237
impl ProducerContext for RwProducerContext {
    /// No per-message opaque value is carried: the unit type is used.
    type DeliveryOpaque = ();

    /// Delivery callback. This implementation ignores both the delivery result and the
    /// opaque value and does nothing.
    fn delivery(&self, _: &DeliveryResult<'_>, _: Self::DeliveryOpaque) {}
}