risingwave_connector/schema/schema_registry/
client.rs
1use std::collections::HashSet;
16use std::fmt::Debug;
17use std::sync::Arc;
18use std::time::Duration;
19
20use futures::future::select_all;
21use itertools::Itertools;
22use reqwest::{Method, Url};
23use serde::Deserialize;
24use serde::de::DeserializeOwned;
25use thiserror_ext::AsReport as _;
26use tokio_retry::Retry;
27use tokio_retry::strategy::{ExponentialBackoff, jitter};
28
29use super::util::*;
30use crate::connector_common::ConfluentSchemaRegistryConnection;
31use crate::schema::{InvalidOptionError, invalid_option_error};
32use crate::with_options::Get;
33
/// Option key from which the schema registry user name is read.
pub const SCHEMA_REGISTRY_USERNAME: &str = "schema.registry.username";
/// Option key from which the schema registry password is read.
pub const SCHEMA_REGISTRY_PASSWORD: &str = "schema.registry.password";

/// Option key for the cap on a single retry delay, expressed in seconds.
pub const SCHEMA_REGISTRY_MAX_DELAY_KEY: &str = "schema.registry.max.delay.sec";
/// Option key for the millisecond value that seeds the exponential backoff.
pub const SCHEMA_REGISTRY_BACKOFF_DURATION_KEY: &str = "schema.registry.backoff.duration.ms";
/// Option key for the multiplication factor of the exponential backoff.
pub const SCHEMA_REGISTRY_BACKOFF_FACTOR_KEY: &str = "schema.registry.backoff.factor";
/// Option key for the maximum number of retry delays taken per request.
pub const SCHEMA_REGISTRY_RETRIES_MAX_KEY: &str = "schema.registry.retries.max";

/// Fallback for [`SCHEMA_REGISTRY_MAX_DELAY_KEY`] when the option is absent or unparsable.
const DEFAULT_MAX_DELAY_SEC: u32 = 3;
/// Fallback for [`SCHEMA_REGISTRY_BACKOFF_DURATION_KEY`] when the option is absent or unparsable.
const DEFAULT_BACKOFF_DURATION_MS: u64 = 100;
/// Fallback for [`SCHEMA_REGISTRY_BACKOFF_FACTOR_KEY`] when the option is absent or unparsable.
const DEFAULT_BACKOFF_FACTOR: u64 = 2;
/// Fallback for [`SCHEMA_REGISTRY_RETRIES_MAX_KEY`] when the option is absent or unparsable.
const DEFAULT_RETRIES_MAX: usize = 3;
46
/// Retry tuning used when issuing requests to the schema registry.
#[derive(Clone, Debug)]
struct SchemaRegistryRetryConfig {
    /// Cap on a single retry delay, in seconds.
    pub max_delay_sec: u32,
    /// Millisecond value that seeds the exponential backoff.
    pub backoff_duration_ms: u64,
    /// Multiplication factor applied by the exponential backoff.
    pub backoff_factor: u64,
    /// Maximum number of retry delays taken from the backoff strategy.
    pub retries_max: usize,
}
54
55impl Default for SchemaRegistryRetryConfig {
56 fn default() -> Self {
57 Self {
58 max_delay_sec: DEFAULT_MAX_DELAY_SEC,
59 backoff_duration_ms: DEFAULT_BACKOFF_DURATION_MS,
60 backoff_factor: DEFAULT_BACKOFF_FACTOR,
61 retries_max: DEFAULT_RETRIES_MAX,
62 }
63 }
64}
65
66#[derive(Debug, Clone, Default)]
67pub struct SchemaRegistryConfig {
68 username: Option<String>,
69 password: Option<String>,
70
71 retry_config: SchemaRegistryRetryConfig,
72}
73
74impl<T: Get> From<&T> for SchemaRegistryConfig {
75 fn from(props: &T) -> Self {
76 SchemaRegistryConfig {
77 username: props.get(SCHEMA_REGISTRY_USERNAME).cloned(),
78 password: props.get(SCHEMA_REGISTRY_PASSWORD).cloned(),
79
80 retry_config: SchemaRegistryRetryConfig {
81 max_delay_sec: props
82 .get(SCHEMA_REGISTRY_MAX_DELAY_KEY)
83 .and_then(|v| v.parse::<u32>().ok())
84 .unwrap_or(DEFAULT_MAX_DELAY_SEC),
85 backoff_duration_ms: props
86 .get(SCHEMA_REGISTRY_BACKOFF_DURATION_KEY)
87 .and_then(|v| v.parse::<u64>().ok())
88 .unwrap_or(DEFAULT_BACKOFF_DURATION_MS),
89 backoff_factor: props
90 .get(SCHEMA_REGISTRY_BACKOFF_FACTOR_KEY)
91 .and_then(|v| v.parse::<u64>().ok())
92 .unwrap_or(DEFAULT_BACKOFF_FACTOR),
93 retries_max: props
94 .get(SCHEMA_REGISTRY_RETRIES_MAX_KEY)
95 .and_then(|v| v.parse::<usize>().ok())
96 .unwrap_or(DEFAULT_RETRIES_MAX),
97 },
98 }
99 }
100}
101
102#[derive(Debug)]
104pub struct Client {
105 inner: reqwest::Client,
106 url: Vec<Url>,
107 username: Option<String>,
108 password: Option<String>,
109
110 retry_config: SchemaRegistryRetryConfig,
111}
112
113#[derive(Debug, thiserror::Error)]
114#[error("all request confluent registry all timeout, {context}\n{}", errs.iter().map(|e| format!("\t{}", e.as_report())).join("\n"))]
115pub struct ConcurrentRequestError {
116 errs: Vec<itertools::Either<RequestError, tokio::task::JoinError>>,
117 context: String,
118}
119
120type SrResult<T> = Result<T, ConcurrentRequestError>;
121
122impl TryFrom<&ConfluentSchemaRegistryConnection> for Client {
123 type Error = InvalidOptionError;
124
125 fn try_from(value: &ConfluentSchemaRegistryConnection) -> Result<Self, Self::Error> {
126 let urls = handle_sr_list(value.url.as_str())?;
127
128 Client::new(
129 urls,
130 &SchemaRegistryConfig {
131 username: value.username.clone(),
132 password: value.password.clone(),
133 ..Default::default()
134 },
135 )
136 }
137}
138
139impl Client {
140 pub(crate) fn new(
141 url: Vec<Url>,
142 client_config: &SchemaRegistryConfig,
143 ) -> Result<Self, InvalidOptionError> {
144 let valid_urls = url
145 .iter()
146 .map(|url| (url.cannot_be_a_base(), url))
147 .filter(|(x, _)| !*x)
148 .map(|(_, url)| url.clone())
149 .collect_vec();
150 if valid_urls.is_empty() {
151 return Err(invalid_option_error!("non-base: {}", url.iter().join(" ")));
152 } else {
153 tracing::debug!(
154 "schema registry client will use url {:?} to connect",
155 valid_urls
156 );
157 }
158
159 let inner = reqwest::Client::builder().build().unwrap();
161
162 Ok(Client {
163 inner,
164 url: valid_urls,
165 username: client_config.username.clone(),
166 password: client_config.password.clone(),
167 retry_config: client_config.retry_config.clone(),
168 })
169 }
170
171 async fn concurrent_req<'a, T>(
172 &'a self,
173 method: Method,
174 path: &'a [&'a (impl AsRef<str> + ?Sized + Debug + ToString)],
175 ) -> SrResult<T>
176 where
177 T: DeserializeOwned + Send + Sync + 'static,
178 {
179 let mut fut_req = Vec::with_capacity(self.url.len());
180 let mut errs = Vec::with_capacity(self.url.len());
181 let ctx = Arc::new(SchemaRegistryCtx {
182 username: self.username.clone(),
183 password: self.password.clone(),
184 client: self.inner.clone(),
185 path: path.iter().map(|p| p.to_string()).collect_vec(),
186 });
187 tracing::debug!("retry config: {:?}", self.retry_config);
188
189 let retry_strategy = ExponentialBackoff::from_millis(self.retry_config.backoff_duration_ms)
190 .factor(self.retry_config.backoff_factor)
191 .max_delay(Duration::from_secs(self.retry_config.max_delay_sec as u64))
192 .take(self.retry_config.retries_max)
193 .map(jitter);
194
195 for url in &self.url {
196 let url_clone = url.clone();
197 let ctx_clone = ctx.clone();
198 let method_clone = method.clone();
199
200 let retry_future = Retry::spawn(retry_strategy.clone(), move || {
201 let ctx = ctx_clone.clone();
202 let url = url_clone.clone();
203 let method = method_clone.clone();
204 async move { req_inner(ctx, url, method).await }
205 });
206
207 fut_req.push(tokio::spawn(retry_future));
208 }
209
210 while !fut_req.is_empty() {
211 let (result, _index, remaining) = select_all(fut_req).await;
212 match result {
213 Ok(Ok(res)) => {
214 let _ = remaining.iter().map(|ele| ele.abort());
215 return Ok(res);
216 }
217 Ok(Err(e)) => errs.push(itertools::Either::Left(e)),
218 Err(e) => errs.push(itertools::Either::Right(e)),
219 }
220 fut_req = remaining;
221 }
222
223 Err(ConcurrentRequestError {
224 errs,
225 context: format!("req path {:?}, urls {}", path, self.url.iter().join(" ")),
226 })
227 }
228
229 pub async fn get_schema_by_id(&self, id: i32) -> SrResult<ConfluentSchema> {
231 let res: GetByIdResp = self
232 .concurrent_req(Method::GET, &["schemas", "ids", &id.to_string()])
233 .await?;
234 Ok(ConfluentSchema {
235 id,
236 content: res.schema,
237 })
238 }
239
240 pub async fn get_schema_by_subject(&self, subject: &str) -> SrResult<ConfluentSchema> {
242 self.get_subject(subject).await.map(|s| s.schema)
243 }
244
245 pub async fn validate_connection(&self) -> SrResult<()> {
247 #[derive(Debug, Deserialize)]
248 struct GetConfigResp {
249 #[serde(rename = "compatibilityLevel")]
250 _compatibility_level: String,
251 }
252
253 let _: GetConfigResp = self.concurrent_req(Method::GET, &["config"]).await?;
254 Ok(())
255 }
256
257 pub async fn get_subject(&self, subject: &str) -> SrResult<Subject> {
259 let res: GetBySubjectResp = self
260 .concurrent_req(Method::GET, &["subjects", subject, "versions", "latest"])
261 .await?;
262 tracing::debug!("update schema: {:?}", res);
263 Ok(Subject {
264 schema: ConfluentSchema {
265 id: res.id,
266 content: res.schema,
267 },
268 version: res.version,
269 name: res.subject,
270 })
271 }
272
273 pub async fn get_subject_and_references(
275 &self,
276 subject: &str,
277 ) -> SrResult<(Subject, Vec<Subject>)> {
278 let mut subjects = vec![];
279 let mut visited = HashSet::new();
280 let mut queue = vec![(subject.to_owned(), "latest".to_owned())];
281 while let Some((subject, version)) = queue.pop() {
283 let res: GetBySubjectResp = self
284 .concurrent_req(Method::GET, &["subjects", &subject, "versions", &version])
285 .await?;
286 let ref_subject = Subject {
287 schema: ConfluentSchema {
288 id: res.id,
289 content: res.schema,
290 },
291 version: res.version,
292 name: res.subject.clone(),
293 };
294 subjects.push(ref_subject);
295 visited.insert(res.subject);
296 queue.extend(
297 res.references
298 .into_iter()
299 .filter(|r| !visited.contains(&r.subject))
300 .map(|r| (r.subject, r.version.to_string())),
301 );
302 }
303 let origin_subject = subjects.remove(0);
304
305 Ok((origin_subject, subjects))
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Manual smoke test: ignored by default and pointed at
    /// `http://localhost:8081`; prints the fetched subject and its references.
    #[tokio::test]
    #[ignore]
    async fn test_get_subject() {
        let registry = Url::parse("http://localhost:8081").unwrap();
        let config = SchemaRegistryConfig::default();
        let client = Client::new(vec![registry], &config).unwrap();

        let fetched = client
            .get_subject_and_references("proto_c_bin-value")
            .await
            .unwrap();
        println!("{:?}", fetched);
    }
}
332}