risingwave_connector/schema/schema_registry/
client.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::fmt::Debug;
17use std::sync::Arc;
18use std::time::Duration;
19
20use futures::future::select_all;
21use itertools::Itertools;
22use reqwest::{Method, Url};
23use serde::Deserialize;
24use serde::de::DeserializeOwned;
25use thiserror_ext::AsReport as _;
26use tokio_retry::Retry;
27use tokio_retry::strategy::{ExponentialBackoff, jitter};
28
29use super::util::*;
30use crate::connector_common::ConfluentSchemaRegistryConnection;
31use crate::schema::{InvalidOptionError, invalid_option_error};
32use crate::with_options::Get;
33
/// Option key for the schema registry username.
pub const SCHEMA_REGISTRY_USERNAME: &str = "schema.registry.username";
/// Option key for the schema registry password.
pub const SCHEMA_REGISTRY_PASSWORD: &str = "schema.registry.password";

/// Option key for the upper bound on the delay between retries, in seconds (parsed as `u32`).
pub const SCHEMA_REGISTRY_MAX_DELAY_KEY: &str = "schema.registry.max.delay.sec";
/// Option key for the base backoff duration, in milliseconds (parsed as `u64`).
pub const SCHEMA_REGISTRY_BACKOFF_DURATION_KEY: &str = "schema.registry.backoff.duration.ms";
/// Option key for the backoff factor (parsed as `u64`).
pub const SCHEMA_REGISTRY_BACKOFF_FACTOR_KEY: &str = "schema.registry.backoff.factor";
/// Option key for the maximum number of retries per registry URL (parsed as `usize`).
pub const SCHEMA_REGISTRY_RETRIES_MAX_KEY: &str = "schema.registry.retries.max";
41
/// Default cap on the delay between two retries, in seconds.
const DEFAULT_MAX_DELAY_SEC: u32 = 3;
/// Default base duration handed to the exponential backoff, in milliseconds.
const DEFAULT_BACKOFF_DURATION_MS: u64 = 100;
/// Default factor handed to the exponential backoff.
const DEFAULT_BACKOFF_FACTOR: u64 = 2;
/// Default number of retries attempted per registry URL.
const DEFAULT_RETRIES_MAX: usize = 3;

/// Retry/backoff knobs applied to every request sent to the schema registry.
#[derive(Debug, Clone)]
struct SchemaRegistryRetryConfig {
    /// Upper bound on the delay between retries, in seconds.
    pub max_delay_sec: u32,
    /// Base duration of the exponential backoff, in milliseconds.
    pub backoff_duration_ms: u64,
    /// Factor of the exponential backoff.
    pub backoff_factor: u64,
    /// Maximum number of retries before a request is reported as failed.
    pub retries_max: usize,
}

impl Default for SchemaRegistryRetryConfig {
    /// Uses the `DEFAULT_*` constants declared above for every field.
    fn default() -> Self {
        SchemaRegistryRetryConfig {
            retries_max: DEFAULT_RETRIES_MAX,
            backoff_factor: DEFAULT_BACKOFF_FACTOR,
            backoff_duration_ms: DEFAULT_BACKOFF_DURATION_MS,
            max_delay_sec: DEFAULT_MAX_DELAY_SEC,
        }
    }
}
65
66#[derive(Debug, Clone, Default)]
67pub struct SchemaRegistryConfig {
68    username: Option<String>,
69    password: Option<String>,
70
71    retry_config: SchemaRegistryRetryConfig,
72}
73
74impl<T: Get> From<&T> for SchemaRegistryConfig {
75    fn from(props: &T) -> Self {
76        SchemaRegistryConfig {
77            username: props.get(SCHEMA_REGISTRY_USERNAME).cloned(),
78            password: props.get(SCHEMA_REGISTRY_PASSWORD).cloned(),
79
80            retry_config: SchemaRegistryRetryConfig {
81                max_delay_sec: props
82                    .get(SCHEMA_REGISTRY_MAX_DELAY_KEY)
83                    .and_then(|v| v.parse::<u32>().ok())
84                    .unwrap_or(DEFAULT_MAX_DELAY_SEC),
85                backoff_duration_ms: props
86                    .get(SCHEMA_REGISTRY_BACKOFF_DURATION_KEY)
87                    .and_then(|v| v.parse::<u64>().ok())
88                    .unwrap_or(DEFAULT_BACKOFF_DURATION_MS),
89                backoff_factor: props
90                    .get(SCHEMA_REGISTRY_BACKOFF_FACTOR_KEY)
91                    .and_then(|v| v.parse::<u64>().ok())
92                    .unwrap_or(DEFAULT_BACKOFF_FACTOR),
93                retries_max: props
94                    .get(SCHEMA_REGISTRY_RETRIES_MAX_KEY)
95                    .and_then(|v| v.parse::<usize>().ok())
96                    .unwrap_or(DEFAULT_RETRIES_MAX),
97            },
98        }
99    }
100}
101
/// A client for communicating with one or more schema registry endpoints.
#[derive(Debug)]
pub struct Client {
    /// Underlying HTTP client; a clone of it is handed to every request.
    inner: reqwest::Client,
    /// Registry URLs that passed the `cannot_be_a_base` check in `Client::new`;
    /// each request is sent to all of them.
    url: Vec<Url>,
    /// Registry username forwarded with each request, if any.
    username: Option<String>,
    /// Registry password forwarded with each request, if any.
    password: Option<String>,

    /// Retry/backoff settings applied per registry URL.
    retry_config: SchemaRegistryRetryConfig,
}
112
/// Error returned when a request failed against every configured registry URL.
///
/// The rendered message lists the request context followed by one
/// tab-indented report line per collected error.
#[derive(Debug, thiserror::Error)]
#[error("all request confluent registry all timeout, {context}\n{}", errs.iter().map(|e| format!("\t{}", e.as_report())).join("\n"))]
pub struct ConcurrentRequestError {
    /// One entry per failed task: `Left` is the request error returned after
    /// retries, `Right` is the join error of the spawned task itself.
    errs: Vec<itertools::Either<RequestError, tokio::task::JoinError>>,
    /// Human-readable description of the request (path and URLs).
    context: String,
}

/// Result alias used by all schema registry client calls.
type SrResult<T> = Result<T, ConcurrentRequestError>;
121
122impl TryFrom<&ConfluentSchemaRegistryConnection> for Client {
123    type Error = InvalidOptionError;
124
125    fn try_from(value: &ConfluentSchemaRegistryConnection) -> Result<Self, Self::Error> {
126        let urls = handle_sr_list(value.url.as_str())?;
127
128        Client::new(
129            urls,
130            &SchemaRegistryConfig {
131                username: value.username.clone(),
132                password: value.password.clone(),
133                ..Default::default()
134            },
135        )
136    }
137}
138
139impl Client {
140    pub(crate) fn new(
141        url: Vec<Url>,
142        client_config: &SchemaRegistryConfig,
143    ) -> Result<Self, InvalidOptionError> {
144        let valid_urls = url
145            .iter()
146            .map(|url| (url.cannot_be_a_base(), url))
147            .filter(|(x, _)| !*x)
148            .map(|(_, url)| url.clone())
149            .collect_vec();
150        if valid_urls.is_empty() {
151            return Err(invalid_option_error!("non-base: {}", url.iter().join(" ")));
152        } else {
153            tracing::debug!(
154                "schema registry client will use url {:?} to connect",
155                valid_urls
156            );
157        }
158
159        // `unwrap` as the builder is not affected by any input right now
160        let inner = reqwest::Client::builder().build().unwrap();
161
162        Ok(Client {
163            inner,
164            url: valid_urls,
165            username: client_config.username.clone(),
166            password: client_config.password.clone(),
167            retry_config: client_config.retry_config.clone(),
168        })
169    }
170
171    async fn concurrent_req<'a, T>(
172        &'a self,
173        method: Method,
174        path: &'a [&'a (impl AsRef<str> + ?Sized + Debug + ToString)],
175    ) -> SrResult<T>
176    where
177        T: DeserializeOwned + Send + Sync + 'static,
178    {
179        let mut fut_req = Vec::with_capacity(self.url.len());
180        let mut errs = Vec::with_capacity(self.url.len());
181        let ctx = Arc::new(SchemaRegistryCtx {
182            username: self.username.clone(),
183            password: self.password.clone(),
184            client: self.inner.clone(),
185            path: path.iter().map(|p| p.to_string()).collect_vec(),
186        });
187        tracing::debug!("retry config: {:?}", self.retry_config);
188
189        let retry_strategy = ExponentialBackoff::from_millis(self.retry_config.backoff_duration_ms)
190            .factor(self.retry_config.backoff_factor)
191            .max_delay(Duration::from_secs(self.retry_config.max_delay_sec as u64))
192            .take(self.retry_config.retries_max)
193            .map(jitter);
194
195        for url in &self.url {
196            let url_clone = url.clone();
197            let ctx_clone = ctx.clone();
198            let method_clone = method.clone();
199
200            let retry_future = Retry::spawn(retry_strategy.clone(), move || {
201                let ctx = ctx_clone.clone();
202                let url = url_clone.clone();
203                let method = method_clone.clone();
204                async move { req_inner(ctx, url, method).await }
205            });
206
207            fut_req.push(tokio::spawn(retry_future));
208        }
209
210        while !fut_req.is_empty() {
211            let (result, _index, remaining) = select_all(fut_req).await;
212            match result {
213                Ok(Ok(res)) => {
214                    let _ = remaining.iter().map(|ele| ele.abort());
215                    return Ok(res);
216                }
217                Ok(Err(e)) => errs.push(itertools::Either::Left(e)),
218                Err(e) => errs.push(itertools::Either::Right(e)),
219            }
220            fut_req = remaining;
221        }
222
223        Err(ConcurrentRequestError {
224            errs,
225            context: format!("req path {:?}, urls {}", path, self.url.iter().join(" ")),
226        })
227    }
228
229    /// get schema by id
230    pub async fn get_schema_by_id(&self, id: i32) -> SrResult<ConfluentSchema> {
231        let res: GetByIdResp = self
232            .concurrent_req(Method::GET, &["schemas", "ids", &id.to_string()])
233            .await?;
234        Ok(ConfluentSchema {
235            id,
236            content: res.schema,
237        })
238    }
239
240    /// get the latest schema of the subject
241    pub async fn get_schema_by_subject(&self, subject: &str) -> SrResult<ConfluentSchema> {
242        self.get_subject(subject).await.map(|s| s.schema)
243    }
244
245    // used for connection validate, just check if request is ok
246    pub async fn validate_connection(&self) -> SrResult<()> {
247        #[derive(Debug, Deserialize)]
248        struct GetConfigResp {
249            #[serde(rename = "compatibilityLevel")]
250            _compatibility_level: String,
251        }
252
253        let _: GetConfigResp = self.concurrent_req(Method::GET, &["config"]).await?;
254        Ok(())
255    }
256
257    /// get the latest version of the subject
258    pub async fn get_subject(&self, subject: &str) -> SrResult<Subject> {
259        let res: GetBySubjectResp = self
260            .concurrent_req(Method::GET, &["subjects", subject, "versions", "latest"])
261            .await?;
262        tracing::debug!("update schema: {:?}", res);
263        Ok(Subject {
264            schema: ConfluentSchema {
265                id: res.id,
266                content: res.schema,
267            },
268            version: res.version,
269            name: res.subject,
270        })
271    }
272
273    /// get the latest version of the subject and all it's references(deps)
274    pub async fn get_subject_and_references(
275        &self,
276        subject: &str,
277    ) -> SrResult<(Subject, Vec<Subject>)> {
278        let mut subjects = vec![];
279        let mut visited = HashSet::new();
280        let mut queue = vec![(subject.to_owned(), "latest".to_owned())];
281        // use bfs to get all references
282        while let Some((subject, version)) = queue.pop() {
283            let res: GetBySubjectResp = self
284                .concurrent_req(Method::GET, &["subjects", &subject, "versions", &version])
285                .await?;
286            let ref_subject = Subject {
287                schema: ConfluentSchema {
288                    id: res.id,
289                    content: res.schema,
290                },
291                version: res.version,
292                name: res.subject.clone(),
293            };
294            subjects.push(ref_subject);
295            visited.insert(res.subject);
296            queue.extend(
297                res.references
298                    .into_iter()
299                    .filter(|r| !visited.contains(&r.subject))
300                    .map(|r| (r.subject, r.version.to_string())),
301            );
302        }
303        let origin_subject = subjects.remove(0);
304
305        Ok((origin_subject, subjects))
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Manual smoke test: it talks to a registry at `http://localhost:8081`,
    /// so it is ignored unless run explicitly.
    #[tokio::test]
    #[ignore]
    async fn test_get_subject() {
        let registry = Url::parse("http://localhost:8081").unwrap();
        // No credentials and default retry settings.
        let config = SchemaRegistryConfig::default();
        let client = Client::new(vec![registry], &config).unwrap();

        let subject = client
            .get_subject_and_references("proto_c_bin-value")
            .await
            .unwrap();
        println!("{:?}", subject);
    }
}