risingwave_connector/schema/schema_registry/
client.rs1use std::collections::HashSet;
16use std::fmt::Debug;
17use std::sync::Arc;
18use std::time::Duration;
19
20use futures::future::select_all;
21use itertools::Itertools;
22use reqwest::{Method, Url};
23use serde::Deserialize;
24use serde::de::DeserializeOwned;
25use thiserror_ext::AsReport as _;
26use tokio_retry::Retry;
27use tokio_retry::strategy::{ExponentialBackoff, jitter};
28
29use super::util::*;
30use crate::connector_common::ConfluentSchemaRegistryConnection;
31use crate::schema::{InvalidOptionError, invalid_option_error};
32use crate::with_options::Get;
33
/// Option key holding the schema registry username.
pub const SCHEMA_REGISTRY_USERNAME: &str = "schema.registry.username";
/// Option key holding the schema registry password.
pub const SCHEMA_REGISTRY_PASSWORD: &str = "schema.registry.password";
/// Option key holding the path of a PEM-encoded CA certificate. The value
/// `ignore` (case-insensitive) makes the client accept invalid certificates.
pub const SCHEMA_REGISTRY_CA_PEM_PATH: &str = "schema.registry.ca_pem_path";

/// Option key for the maximum delay between retries, in seconds.
pub const SCHEMA_REGISTRY_MAX_DELAY_KEY: &str = "schema.registry.max.delay.sec";
/// Option key for the millisecond value passed to `ExponentialBackoff::from_millis`.
pub const SCHEMA_REGISTRY_BACKOFF_DURATION_KEY: &str = "schema.registry.backoff.duration.ms";
/// Option key for the multiplier passed to `ExponentialBackoff::factor`.
pub const SCHEMA_REGISTRY_BACKOFF_FACTOR_KEY: &str = "schema.registry.backoff.factor";
/// Option key for the number of delays taken from the retry strategy.
pub const SCHEMA_REGISTRY_RETRIES_MAX_KEY: &str = "schema.registry.retries.max";

// Fallbacks used when the corresponding option is absent or unparsable.
const DEFAULT_MAX_DELAY_SEC: u32 = 3;
const DEFAULT_BACKOFF_DURATION_MS: u64 = 100;
const DEFAULT_BACKOFF_FACTOR: u64 = 2;
const DEFAULT_RETRIES_MAX: usize = 3;
47
/// Retry/backoff settings used when sending requests to the schema registry.
#[derive(Debug, Clone)]
struct SchemaRegistryRetryConfig {
    /// Upper bound for a single backoff delay, in seconds.
    pub max_delay_sec: u32,
    /// Millisecond value handed to `ExponentialBackoff::from_millis`.
    pub backoff_duration_ms: u64,
    /// Multiplier handed to `ExponentialBackoff::factor`.
    pub backoff_factor: u64,
    /// Number of delays taken from the backoff strategy (`.take(..)`).
    pub retries_max: usize,
}
55
56impl Default for SchemaRegistryRetryConfig {
57 fn default() -> Self {
58 Self {
59 max_delay_sec: DEFAULT_MAX_DELAY_SEC,
60 backoff_duration_ms: DEFAULT_BACKOFF_DURATION_MS,
61 backoff_factor: DEFAULT_BACKOFF_FACTOR,
62 retries_max: DEFAULT_RETRIES_MAX,
63 }
64 }
65}
66
67#[derive(Debug, Clone, Default)]
68pub struct SchemaRegistryConfig {
69 username: Option<String>,
70 password: Option<String>,
71 ca_pem_path: Option<String>,
72
73 retry_config: SchemaRegistryRetryConfig,
74}
75
76impl<T: Get> From<&T> for SchemaRegistryConfig {
77 fn from(props: &T) -> Self {
78 SchemaRegistryConfig {
79 username: props.get(SCHEMA_REGISTRY_USERNAME).cloned(),
80 password: props.get(SCHEMA_REGISTRY_PASSWORD).cloned(),
81 ca_pem_path: props.get(SCHEMA_REGISTRY_CA_PEM_PATH).cloned(),
82
83 retry_config: SchemaRegistryRetryConfig {
84 max_delay_sec: props
85 .get(SCHEMA_REGISTRY_MAX_DELAY_KEY)
86 .and_then(|v| v.parse::<u32>().ok())
87 .unwrap_or(DEFAULT_MAX_DELAY_SEC),
88 backoff_duration_ms: props
89 .get(SCHEMA_REGISTRY_BACKOFF_DURATION_KEY)
90 .and_then(|v| v.parse::<u64>().ok())
91 .unwrap_or(DEFAULT_BACKOFF_DURATION_MS),
92 backoff_factor: props
93 .get(SCHEMA_REGISTRY_BACKOFF_FACTOR_KEY)
94 .and_then(|v| v.parse::<u64>().ok())
95 .unwrap_or(DEFAULT_BACKOFF_FACTOR),
96 retries_max: props
97 .get(SCHEMA_REGISTRY_RETRIES_MAX_KEY)
98 .and_then(|v| v.parse::<usize>().ok())
99 .unwrap_or(DEFAULT_RETRIES_MAX),
100 },
101 }
102 }
103}
104
105#[derive(Debug)]
107pub struct Client {
108 inner: reqwest::Client,
109 url: Vec<Url>,
110 username: Option<String>,
111 password: Option<String>,
112
113 retry_config: SchemaRegistryRetryConfig,
114}
115
116#[derive(Debug, thiserror::Error)]
117#[error("all request confluent registry all timeout, {context}\n{}", errs.iter().map(|e| format!("\t{}", e.as_report())).join("\n"))]
118pub struct ConcurrentRequestError {
119 errs: Vec<itertools::Either<RequestError, tokio::task::JoinError>>,
120 context: String,
121}
122
123type SrResult<T> = Result<T, ConcurrentRequestError>;
124
125#[derive(thiserror::Error, Debug)]
126pub enum SchemaRegistryClientError {
127 #[error(transparent)]
128 InvalidOption(#[from] InvalidOptionError),
129 #[error("read ca file error: {0}")]
130 ReadFile(#[source] std::io::Error),
131 #[error("parse ca file error: {0}")]
132 ParsePem(#[source] reqwest::Error),
133 #[error("build schema registry client error: {0}")]
134 Build(#[source] reqwest::Error),
135}
136
137impl TryFrom<&ConfluentSchemaRegistryConnection> for Client {
138 type Error = InvalidOptionError;
139
140 fn try_from(value: &ConfluentSchemaRegistryConnection) -> Result<Self, Self::Error> {
141 let urls = handle_sr_list(value.url.as_str())?;
142
143 Client::new(
144 urls,
145 &SchemaRegistryConfig {
146 username: value.username.clone(),
147 password: value.password.clone(),
148 ..Default::default()
149 },
150 )
151 .map_err(|e| match e {
152 SchemaRegistryClientError::InvalidOption(e) => e,
153 e => {
154 invalid_option_error!("failed to create schema registry client: {}", e.as_report())
155 }
156 })
157 }
158}
159
160impl Client {
161 pub(crate) fn new(
162 url: Vec<Url>,
163 client_config: &SchemaRegistryConfig,
164 ) -> Result<Self, SchemaRegistryClientError> {
165 let valid_urls = url
166 .iter()
167 .map(|url| (url.cannot_be_a_base(), url))
168 .filter(|(x, _)| !*x)
169 .map(|(_, url)| url.clone())
170 .collect_vec();
171 if valid_urls.is_empty() {
172 return Err(SchemaRegistryClientError::InvalidOption(
173 invalid_option_error!("non-base: {}", url.iter().join(" ")),
174 ));
175 } else {
176 tracing::debug!(
177 "schema registry client will use url {:?} to connect",
178 valid_urls
179 );
180 }
181
182 let mut client_builder = reqwest::Client::builder();
183 if let Some(ca_path) = client_config.ca_pem_path.as_ref() {
184 if ca_path.eq_ignore_ascii_case("ignore") {
185 client_builder = client_builder.danger_accept_invalid_certs(true);
186 } else {
187 client_builder = client_builder.add_root_certificate(
188 reqwest::Certificate::from_pem(
189 &std::fs::read(ca_path).map_err(SchemaRegistryClientError::ReadFile)?,
190 )
191 .map_err(SchemaRegistryClientError::ParsePem)?,
192 );
193 }
194 }
195
196 let inner = client_builder
197 .build()
198 .map_err(SchemaRegistryClientError::Build)?;
199
200 Ok(Client {
201 inner,
202 url: valid_urls,
203 username: client_config.username.clone(),
204 password: client_config.password.clone(),
205 retry_config: client_config.retry_config.clone(),
206 })
207 }
208
209 async fn concurrent_req<'a, T>(
210 &'a self,
211 method: Method,
212 path: &'a [&'a (impl AsRef<str> + ?Sized + Debug + ToString)],
213 ) -> SrResult<T>
214 where
215 T: DeserializeOwned + Send + Sync + 'static,
216 {
217 let mut fut_req = Vec::with_capacity(self.url.len());
218 let mut errs = Vec::with_capacity(self.url.len());
219 let ctx = Arc::new(SchemaRegistryCtx {
220 username: self.username.clone(),
221 password: self.password.clone(),
222 client: self.inner.clone(),
223 path: path.iter().map(|p| p.to_string()).collect_vec(),
224 });
225 tracing::debug!("retry config: {:?}", self.retry_config);
226
227 let retry_strategy = ExponentialBackoff::from_millis(self.retry_config.backoff_duration_ms)
228 .factor(self.retry_config.backoff_factor)
229 .max_delay(Duration::from_secs(self.retry_config.max_delay_sec as u64))
230 .take(self.retry_config.retries_max)
231 .map(jitter);
232
233 for url in &self.url {
234 let url_clone = url.clone();
235 let ctx_clone = ctx.clone();
236 let method_clone = method.clone();
237
238 let retry_future = Retry::spawn(retry_strategy.clone(), move || {
239 let ctx = ctx_clone.clone();
240 let url = url_clone.clone();
241 let method = method_clone.clone();
242 async move { req_inner(ctx, url, method).await }
243 });
244
245 fut_req.push(tokio::spawn(retry_future));
246 }
247
248 while !fut_req.is_empty() {
249 let (result, _index, remaining) = select_all(fut_req).await;
250 match result {
251 Ok(Ok(res)) => {
252 let _ = remaining.iter().map(|ele| ele.abort());
253 return Ok(res);
254 }
255 Ok(Err(e)) => errs.push(itertools::Either::Left(e)),
256 Err(e) => errs.push(itertools::Either::Right(e)),
257 }
258 fut_req = remaining;
259 }
260
261 Err(ConcurrentRequestError {
262 errs,
263 context: format!("req path {:?}, urls {}", path, self.url.iter().join(" ")),
264 })
265 }
266
267 pub async fn get_schema_by_id(&self, id: i32) -> SrResult<ConfluentSchema> {
269 let res: GetByIdResp = self
270 .concurrent_req(Method::GET, &["schemas", "ids", &id.to_string()])
271 .await?;
272 Ok(ConfluentSchema {
273 id,
274 content: res.schema,
275 })
276 }
277
278 pub async fn get_schema_by_subject(&self, subject: &str) -> SrResult<ConfluentSchema> {
280 self.get_subject(subject).await.map(|s| s.schema)
281 }
282
283 pub async fn validate_connection(&self) -> SrResult<()> {
285 #[derive(Debug, Deserialize)]
286 struct GetConfigResp {
287 #[serde(rename = "compatibilityLevel")]
288 _compatibility_level: String,
289 }
290
291 let _: GetConfigResp = self.concurrent_req(Method::GET, &["config"]).await?;
292 Ok(())
293 }
294
295 pub async fn get_subject(&self, subject: &str) -> SrResult<Subject> {
297 let res: GetBySubjectResp = self
298 .concurrent_req(Method::GET, &["subjects", subject, "versions", "latest"])
299 .await?;
300 tracing::debug!("update schema: {:?}", res);
301 Ok(Subject {
302 schema: ConfluentSchema {
303 id: res.id,
304 content: res.schema,
305 },
306 version: res.version,
307 name: res.subject,
308 })
309 }
310
311 pub async fn get_subject_and_references(
313 &self,
314 subject: &str,
315 ) -> SrResult<(Subject, Vec<Subject>)> {
316 let mut subjects = vec![];
317 let mut visited = HashSet::new();
318 let mut queue = vec![(subject.to_owned(), "latest".to_owned())];
319 while let Some((subject, version)) = queue.pop() {
321 let res: GetBySubjectResp = self
322 .concurrent_req(Method::GET, &["subjects", &subject, "versions", &version])
323 .await?;
324 let ref_subject = Subject {
325 schema: ConfluentSchema {
326 id: res.id,
327 content: res.schema,
328 },
329 version: res.version,
330 name: res.subject.clone(),
331 };
332 subjects.push(ref_subject);
333 visited.insert(res.subject);
334 queue.extend(
335 res.references
336 .into_iter()
337 .filter(|r| !visited.contains(&r.subject))
338 .map(|r| (r.subject, r.version.to_string())),
339 );
340 }
341 let origin_subject = subjects.remove(0);
342
343 Ok((origin_subject, subjects))
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Manual smoke test: needs a schema registry listening on
    /// `localhost:8081`, hence `#[ignore]`.
    #[tokio::test]
    #[ignore]
    async fn test_get_subject() {
        let registry = Url::parse("http://localhost:8081").unwrap();
        let config = SchemaRegistryConfig {
            username: None,
            password: None,
            ca_pem_path: None,
            retry_config: SchemaRegistryRetryConfig::default(),
        };
        let client = Client::new(vec![registry], &config).unwrap();

        let subject = client
            .get_subject_and_references("proto_c_bin-value")
            .await
            .unwrap();
        println!("{:?}", subject);
    }
}