risingwave_connector/source/kafka/
client_context.rs1use std::collections::BTreeMap;
16use std::sync::Arc;
17use std::thread;
18
19use anyhow::anyhow;
20use aws_config::Region;
21use aws_sdk_s3::config::SharedCredentialsProvider;
22use rdkafka::client::{BrokerAddr, OAuthToken};
23use rdkafka::consumer::ConsumerContext;
24use rdkafka::message::DeliveryResult;
25use rdkafka::producer::ProducerContext;
26use rdkafka::{ClientContext, Statistics};
27
28use super::private_link::{BrokerAddrRewriter, PrivateLinkContextRole};
29use super::stats::RdKafkaStats;
30use crate::connector_common::AwsAuthProps;
31use crate::error::ConnectorResult;
32
33struct IamAuthEnv {
34 credentials_provider: SharedCredentialsProvider,
35 region: Region,
36 #[cfg(not(madsim))]
38 rt: tokio::runtime::Handle,
39 signer_timeout_sec: u64,
40}
41
42pub struct KafkaContextCommon {
43 addr_rewriter: BrokerAddrRewriter,
45
46 identifier: Option<String>,
50 metrics: Option<Arc<RdKafkaStats>>,
51
52 auth: Option<IamAuthEnv>,
54}
55
56impl KafkaContextCommon {
57 pub async fn new(
58 broker_rewrite_map: Option<BTreeMap<String, String>>,
59 identifier: Option<String>,
60 metrics: Option<Arc<RdKafkaStats>>,
61 auth: AwsAuthProps,
62 is_aws_msk_iam: bool,
63 ) -> ConnectorResult<Self> {
64 let addr_rewriter =
65 BrokerAddrRewriter::new(PrivateLinkContextRole::Consumer, broker_rewrite_map)?;
66 let auth = if is_aws_msk_iam {
67 let config = auth.build_config().await?;
68 let credentials_provider = config
69 .credentials_provider()
70 .ok_or_else(|| anyhow!("missing aws credentials_provider"))?;
71 let region = config
72 .region()
73 .ok_or_else(|| anyhow!("missing aws region"))?
74 .clone();
75 Some(IamAuthEnv {
76 credentials_provider,
77 region,
78 #[cfg(not(madsim))]
79 rt: tokio::runtime::Handle::current(),
80 signer_timeout_sec: auth
81 .msk_signer_timeout_sec
82 .unwrap_or(Self::default_msk_signer_timeout_sec()),
83 })
84 } else {
85 None
86 };
87 Ok(Self {
88 addr_rewriter,
89 identifier,
90 metrics,
91 auth,
92 })
93 }
94
95 fn default_msk_signer_timeout_sec() -> u64 {
96 10
97 }
98}
99
100impl KafkaContextCommon {
101 fn stats(&self, statistics: Statistics) {
102 if let Some(metrics) = &self.metrics
103 && let Some(id) = &self.identifier
104 {
105 metrics.report(id.as_str(), &statistics);
106 }
107 }
108
109 fn rewrite_broker_addr(&self, addr: BrokerAddr) -> BrokerAddr {
110 self.addr_rewriter.rewrite_broker_addr(addr)
111 }
112
113 #[cfg_or_panic::cfg_or_panic(not(madsim))]
115 fn generate_oauth_token(
116 &self,
117 _oauthbearer_config: Option<&str>,
118 ) -> Result<OAuthToken, Box<dyn std::error::Error>> {
119 use aws_msk_iam_sasl_signer::generate_auth_token_from_credentials_provider;
120 use tokio::time::{Duration, timeout};
121
122 if let Some(IamAuthEnv {
123 credentials_provider,
124 region,
125 rt,
126 signer_timeout_sec,
127 }) = &self.auth
128 {
129 let region = region.clone();
130 let credentials_provider = credentials_provider.clone();
131 let rt = rt.clone();
132 let signer_timeout_sec = *signer_timeout_sec;
133 let (token, expiration_time_ms) = {
134 let handle = thread::spawn(move || {
135 rt.block_on(async {
136 timeout(
137 Duration::from_secs(signer_timeout_sec),
138 generate_auth_token_from_credentials_provider(
139 region,
140 credentials_provider,
141 ),
142 )
143 .await
144 })
145 });
146 handle.join().unwrap()??
147 };
148 Ok(OAuthToken {
149 token,
150 principal_name: "".to_owned(),
151 lifetime_ms: expiration_time_ms,
152 })
153 } else {
154 Err("must provide AWS IAM credential".into())
155 }
156 }
157
158 fn enable_refresh_oauth_token(&self) -> bool {
159 self.auth.is_some()
160 }
161}
162
163pub type BoxConsumerContext = Box<dyn ConsumerContext>;
164
165pub struct RwConsumerContext {
167 common: KafkaContextCommon,
168}
169
170impl RwConsumerContext {
171 pub fn new(common: KafkaContextCommon) -> Self {
172 Self { common }
173 }
174}
175
176impl ClientContext for RwConsumerContext {
177 fn stats(&self, statistics: Statistics) {
179 self.common.stats(statistics);
180 }
181
182 fn rewrite_broker_addr(&self, addr: BrokerAddr) -> BrokerAddr {
183 self.common.rewrite_broker_addr(addr)
184 }
185
186 fn generate_oauth_token(
187 &self,
188 oauthbearer_config: Option<&str>,
189 ) -> Result<OAuthToken, Box<dyn std::error::Error>> {
190 self.common.generate_oauth_token(oauthbearer_config)
191 }
192
193 fn enable_refresh_oauth_token(&self) -> bool {
194 self.common.enable_refresh_oauth_token()
195 }
196}
197
198impl ConsumerContext for RwConsumerContext {}
200
201pub struct RwProducerContext {
203 common: KafkaContextCommon,
204}
205
206impl RwProducerContext {
207 pub fn new(common: KafkaContextCommon) -> Self {
208 Self { common }
209 }
210}
211
212impl ClientContext for RwProducerContext {
213 fn stats(&self, statistics: Statistics) {
214 self.common.stats(statistics);
215 }
216
217 fn rewrite_broker_addr(&self, addr: BrokerAddr) -> BrokerAddr {
218 self.common.rewrite_broker_addr(addr)
219 }
220
221 fn generate_oauth_token(
222 &self,
223 oauthbearer_config: Option<&str>,
224 ) -> Result<OAuthToken, Box<dyn std::error::Error>> {
225 self.common.generate_oauth_token(oauthbearer_config)
226 }
227
228 fn enable_refresh_oauth_token(&self) -> bool {
229 self.common.enable_refresh_oauth_token()
230 }
231}
232
233impl ProducerContext for RwProducerContext {
234 type DeliveryOpaque = ();
235
236 fn delivery(&self, _: &DeliveryResult<'_>, _: Self::DeliveryOpaque) {}
237}