risingwave_connector/schema/schema_registry/client.rs

// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashSet;
16use std::fmt::Debug;
17use std::sync::Arc;
18use std::time::Duration;
19
20use futures::future::select_all;
21use itertools::Itertools;
22use reqwest::{Method, Url};
23use serde::Deserialize;
24use serde::de::DeserializeOwned;
25use thiserror_ext::AsReport as _;
26use tokio_retry::Retry;
27use tokio_retry::strategy::{ExponentialBackoff, jitter};
28
29use super::util::*;
30use crate::connector_common::ConfluentSchemaRegistryConnection;
31use crate::schema::{InvalidOptionError, invalid_option_error};
32use crate::with_options::Get;
33
/// Option key for the username sent to the schema registry.
pub const SCHEMA_REGISTRY_USERNAME: &str = "schema.registry.username";
/// Option key for the password sent to the schema registry.
pub const SCHEMA_REGISTRY_PASSWORD: &str = "schema.registry.password";
/// Option key for the path of a PEM file added as an extra root certificate.
/// The special value `ignore` (ASCII case-insensitive) disables certificate validation.
pub const SCHEMA_REGISTRY_CA_PEM_PATH: &str = "schema.registry.ca_pem_path";

// Retry tuning. Each option key is listed next to the default that is used when the
// option is absent or cannot be parsed.

/// Option key for the cap on a single retry delay, in seconds.
pub const SCHEMA_REGISTRY_MAX_DELAY_KEY: &str = "schema.registry.max.delay.sec";
const DEFAULT_MAX_DELAY_SEC: u32 = 3;

/// Option key for the value passed to `ExponentialBackoff::from_millis`.
pub const SCHEMA_REGISTRY_BACKOFF_DURATION_KEY: &str = "schema.registry.backoff.duration.ms";
const DEFAULT_BACKOFF_DURATION_MS: u64 = 100;

/// Option key for the value passed to `ExponentialBackoff::factor`.
pub const SCHEMA_REGISTRY_BACKOFF_FACTOR_KEY: &str = "schema.registry.backoff.factor";
const DEFAULT_BACKOFF_FACTOR: u64 = 2;

/// Option key for the number of retry delays taken from the backoff strategy per URL.
pub const SCHEMA_REGISTRY_RETRIES_MAX_KEY: &str = "schema.registry.retries.max";
const DEFAULT_RETRIES_MAX: usize = 3;
47
48#[derive(Debug, Clone)]
49struct SchemaRegistryRetryConfig {
50    pub max_delay_sec: u32,
51    pub backoff_duration_ms: u64,
52    pub backoff_factor: u64,
53    pub retries_max: usize,
54}
55
56impl Default for SchemaRegistryRetryConfig {
57    fn default() -> Self {
58        Self {
59            max_delay_sec: DEFAULT_MAX_DELAY_SEC,
60            backoff_duration_ms: DEFAULT_BACKOFF_DURATION_MS,
61            backoff_factor: DEFAULT_BACKOFF_FACTOR,
62            retries_max: DEFAULT_RETRIES_MAX,
63        }
64    }
65}
66
67#[derive(Debug, Clone, Default)]
68pub struct SchemaRegistryConfig {
69    username: Option<String>,
70    password: Option<String>,
71    ca_pem_path: Option<String>,
72
73    retry_config: SchemaRegistryRetryConfig,
74}
75
76impl<T: Get> From<&T> for SchemaRegistryConfig {
77    fn from(props: &T) -> Self {
78        SchemaRegistryConfig {
79            username: props.get(SCHEMA_REGISTRY_USERNAME).cloned(),
80            password: props.get(SCHEMA_REGISTRY_PASSWORD).cloned(),
81            ca_pem_path: props.get(SCHEMA_REGISTRY_CA_PEM_PATH).cloned(),
82
83            retry_config: SchemaRegistryRetryConfig {
84                max_delay_sec: props
85                    .get(SCHEMA_REGISTRY_MAX_DELAY_KEY)
86                    .and_then(|v| v.parse::<u32>().ok())
87                    .unwrap_or(DEFAULT_MAX_DELAY_SEC),
88                backoff_duration_ms: props
89                    .get(SCHEMA_REGISTRY_BACKOFF_DURATION_KEY)
90                    .and_then(|v| v.parse::<u64>().ok())
91                    .unwrap_or(DEFAULT_BACKOFF_DURATION_MS),
92                backoff_factor: props
93                    .get(SCHEMA_REGISTRY_BACKOFF_FACTOR_KEY)
94                    .and_then(|v| v.parse::<u64>().ok())
95                    .unwrap_or(DEFAULT_BACKOFF_FACTOR),
96                retries_max: props
97                    .get(SCHEMA_REGISTRY_RETRIES_MAX_KEY)
98                    .and_then(|v| v.parse::<usize>().ok())
99                    .unwrap_or(DEFAULT_RETRIES_MAX),
100            },
101        }
102    }
103}
104
105/// An client for communication with schema registry
106#[derive(Debug)]
107pub struct Client {
108    inner: reqwest::Client,
109    url: Vec<Url>,
110    username: Option<String>,
111    password: Option<String>,
112
113    retry_config: SchemaRegistryRetryConfig,
114}
115
116#[derive(Debug, thiserror::Error)]
117#[error("all request confluent registry all timeout, {context}\n{}", errs.iter().map(|e| format!("\t{}", e.as_report())).join("\n"))]
118pub struct ConcurrentRequestError {
119    errs: Vec<itertools::Either<RequestError, tokio::task::JoinError>>,
120    context: String,
121}
122
123type SrResult<T> = Result<T, ConcurrentRequestError>;
124
125#[derive(thiserror::Error, Debug)]
126pub enum SchemaRegistryClientError {
127    #[error(transparent)]
128    InvalidOption(#[from] InvalidOptionError),
129    #[error("read ca file error: {0}")]
130    ReadFile(#[source] std::io::Error),
131    #[error("parse ca file error: {0}")]
132    ParsePem(#[source] reqwest::Error),
133    #[error("build schema registry client error: {0}")]
134    Build(#[source] reqwest::Error),
135}
136
137impl TryFrom<&ConfluentSchemaRegistryConnection> for Client {
138    type Error = InvalidOptionError;
139
140    fn try_from(value: &ConfluentSchemaRegistryConnection) -> Result<Self, Self::Error> {
141        let urls = handle_sr_list(value.url.as_str())?;
142
143        Client::new(
144            urls,
145            &SchemaRegistryConfig {
146                username: value.username.clone(),
147                password: value.password.clone(),
148                ..Default::default()
149            },
150        )
151        .map_err(|e| match e {
152            SchemaRegistryClientError::InvalidOption(e) => e,
153            e => {
154                invalid_option_error!("failed to create schema registry client: {}", e.as_report())
155            }
156        })
157    }
158}
159
160impl Client {
161    pub(crate) fn new(
162        url: Vec<Url>,
163        client_config: &SchemaRegistryConfig,
164    ) -> Result<Self, SchemaRegistryClientError> {
165        let valid_urls = url
166            .iter()
167            .map(|url| (url.cannot_be_a_base(), url))
168            .filter(|(x, _)| !*x)
169            .map(|(_, url)| url.clone())
170            .collect_vec();
171        if valid_urls.is_empty() {
172            return Err(SchemaRegistryClientError::InvalidOption(
173                invalid_option_error!("non-base: {}", url.iter().join(" ")),
174            ));
175        } else {
176            tracing::debug!(
177                "schema registry client will use url {:?} to connect",
178                valid_urls
179            );
180        }
181
182        let mut client_builder = reqwest::Client::builder();
183        if let Some(ca_path) = client_config.ca_pem_path.as_ref() {
184            if ca_path.eq_ignore_ascii_case("ignore") {
185                client_builder = client_builder.danger_accept_invalid_certs(true);
186            } else {
187                client_builder = client_builder.add_root_certificate(
188                    reqwest::Certificate::from_pem(
189                        &std::fs::read(ca_path).map_err(SchemaRegistryClientError::ReadFile)?,
190                    )
191                    .map_err(SchemaRegistryClientError::ParsePem)?,
192                );
193            }
194        }
195
196        let inner = client_builder
197            .build()
198            .map_err(SchemaRegistryClientError::Build)?;
199
200        Ok(Client {
201            inner,
202            url: valid_urls,
203            username: client_config.username.clone(),
204            password: client_config.password.clone(),
205            retry_config: client_config.retry_config.clone(),
206        })
207    }
208
209    async fn concurrent_req<'a, T>(
210        &'a self,
211        method: Method,
212        path: &'a [&'a (impl AsRef<str> + ?Sized + Debug + ToString)],
213    ) -> SrResult<T>
214    where
215        T: DeserializeOwned + Send + Sync + 'static,
216    {
217        let mut fut_req = Vec::with_capacity(self.url.len());
218        let mut errs = Vec::with_capacity(self.url.len());
219        let ctx = Arc::new(SchemaRegistryCtx {
220            username: self.username.clone(),
221            password: self.password.clone(),
222            client: self.inner.clone(),
223            path: path.iter().map(|p| p.to_string()).collect_vec(),
224        });
225        tracing::debug!("retry config: {:?}", self.retry_config);
226
227        let retry_strategy = ExponentialBackoff::from_millis(self.retry_config.backoff_duration_ms)
228            .factor(self.retry_config.backoff_factor)
229            .max_delay(Duration::from_secs(self.retry_config.max_delay_sec as u64))
230            .take(self.retry_config.retries_max)
231            .map(jitter);
232
233        for url in &self.url {
234            let url_clone = url.clone();
235            let ctx_clone = ctx.clone();
236            let method_clone = method.clone();
237
238            let retry_future = Retry::spawn(retry_strategy.clone(), move || {
239                let ctx = ctx_clone.clone();
240                let url = url_clone.clone();
241                let method = method_clone.clone();
242                async move { req_inner(ctx, url, method).await }
243            });
244
245            fut_req.push(tokio::spawn(retry_future));
246        }
247
248        while !fut_req.is_empty() {
249            let (result, _index, remaining) = select_all(fut_req).await;
250            match result {
251                Ok(Ok(res)) => {
252                    let _ = remaining.iter().map(|ele| ele.abort());
253                    return Ok(res);
254                }
255                Ok(Err(e)) => errs.push(itertools::Either::Left(e)),
256                Err(e) => errs.push(itertools::Either::Right(e)),
257            }
258            fut_req = remaining;
259        }
260
261        Err(ConcurrentRequestError {
262            errs,
263            context: format!("req path {:?}, urls {}", path, self.url.iter().join(" ")),
264        })
265    }
266
267    /// get schema by id
268    pub async fn get_schema_by_id(&self, id: i32) -> SrResult<ConfluentSchema> {
269        let res: GetByIdResp = self
270            .concurrent_req(Method::GET, &["schemas", "ids", &id.to_string()])
271            .await?;
272        Ok(ConfluentSchema {
273            id,
274            content: res.schema,
275        })
276    }
277
278    /// get the latest schema of the subject
279    pub async fn get_schema_by_subject(&self, subject: &str) -> SrResult<ConfluentSchema> {
280        self.get_subject(subject).await.map(|s| s.schema)
281    }
282
283    // used for connection validate, just check if request is ok
284    pub async fn validate_connection(&self) -> SrResult<()> {
285        #[derive(Debug, Deserialize)]
286        struct GetConfigResp {
287            #[serde(rename = "compatibilityLevel")]
288            _compatibility_level: String,
289        }
290
291        let _: GetConfigResp = self.concurrent_req(Method::GET, &["config"]).await?;
292        Ok(())
293    }
294
295    /// get the latest version of the subject
296    pub async fn get_subject(&self, subject: &str) -> SrResult<Subject> {
297        let res: GetBySubjectResp = self
298            .concurrent_req(Method::GET, &["subjects", subject, "versions", "latest"])
299            .await?;
300        tracing::debug!("update schema: {:?}", res);
301        Ok(Subject {
302            schema: ConfluentSchema {
303                id: res.id,
304                content: res.schema,
305            },
306            version: res.version,
307            name: res.subject,
308        })
309    }
310
311    /// get the latest version of the subject and all it's references(deps)
312    pub async fn get_subject_and_references(
313        &self,
314        subject: &str,
315    ) -> SrResult<(Subject, Vec<Subject>)> {
316        let mut subjects = vec![];
317        let mut visited = HashSet::new();
318        let mut queue = vec![(subject.to_owned(), "latest".to_owned())];
319        // use bfs to get all references
320        while let Some((subject, version)) = queue.pop() {
321            let res: GetBySubjectResp = self
322                .concurrent_req(Method::GET, &["subjects", &subject, "versions", &version])
323                .await?;
324            let ref_subject = Subject {
325                schema: ConfluentSchema {
326                    id: res.id,
327                    content: res.schema,
328                },
329                version: res.version,
330                name: res.subject.clone(),
331            };
332            subjects.push(ref_subject);
333            visited.insert(res.subject);
334            queue.extend(
335                res.references
336                    .into_iter()
337                    .filter(|r| !visited.contains(&r.subject))
338                    .map(|r| (r.subject, r.version.to_string())),
339            );
340        }
341        let origin_subject = subjects.remove(0);
342
343        Ok((origin_subject, subjects))
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Manual check of reference resolution. It needs a schema registry on
    /// `localhost:8081` that has the `proto_c_bin-value` subject, so it is `#[ignore]`d.
    #[tokio::test]
    #[ignore]
    async fn test_get_subject() {
        let registry = Url::parse("http://localhost:8081").unwrap();
        let config = SchemaRegistryConfig::default();
        let client = Client::new(vec![registry], &config).unwrap();
        let fetched = client
            .get_subject_and_references("proto_c_bin-value")
            .await
            .unwrap();
        println!("{:?}", fetched);
    }
}