risingwave_connector/schema/schema_registry/
util.rs1use std::fmt::Debug;
16use std::sync::Arc;
17
18use reqwest::Method;
19use serde::de::DeserializeOwned;
20use serde_derive::Deserialize;
21use url::{ParseError, Url};
22
23use crate::schema::{InvalidOptionError, bail_invalid_option_error};
24
25pub fn handle_sr_list(addr: &str) -> Result<Vec<Url>, InvalidOptionError> {
26 let segment = addr.split(',').collect::<Vec<&str>>();
27 let mut errs: Vec<ParseError> = Vec::with_capacity(segment.len());
28 let mut urls = Vec::with_capacity(segment.len());
29 for ele in segment {
30 match ele.parse::<Url>() {
31 Ok(url) => urls.push(url),
32 Err(e) => errs.push(e),
33 }
34 }
35 if urls.is_empty() {
36 bail_invalid_option_error!("no valid url provided, errs: {errs:?}");
37 }
38 tracing::debug!(
39 "schema registry client will use url {:?} to connect, the rest failed because: {:?}",
40 urls,
41 errs
42 );
43 Ok(urls)
44}
45
/// Errors from decoding the wire-format header (magic byte + schema ID) that prefixes a
/// schema-registry-encoded payload.
#[derive(Debug, thiserror::Error)]
pub enum WireFormatError {
    /// The payload is empty or its first byte is not the magic byte `0`
    /// (as checked by `extract_schema_id`).
    #[error("fail to match a magic byte of 0")]
    NoMagic,
    /// Fewer than 4 bytes follow the magic byte, so the big-endian schema ID is missing.
    #[error("fail to read 4-byte schema ID")]
    NoSchemaId,
    /// The message-index section of the header could not be parsed.
    // NOTE(review): not produced anywhere in this file; presumably raised by a protobuf
    // decoder elsewhere — confirm.
    #[error("failed to parse message indexes")]
    ParseMessageIndexes,
}
55
56pub(crate) fn extract_schema_id(payload: &[u8]) -> Result<(i32, &[u8]), WireFormatError> {
66 use byteorder::{BigEndian, ReadBytesExt as _};
67
68 let mut cursor = payload;
69 if !cursor.read_u8().is_ok_and(|magic| magic == 0) {
70 return Err(WireFormatError::NoMagic);
71 }
72
73 let schema_id = cursor
74 .read_i32::<BigEndian>()
75 .map_err(|_| WireFormatError::NoSchemaId)?;
76
77 Ok((schema_id, cursor))
78}
79
/// Everything `req_inner` needs to issue one schema registry request: the HTTP client,
/// optional basic-auth credentials, and the path segments to request.
// NOTE(review): there is no `Debug` derive; presumably deliberate so `password` cannot leak
// through `{:?}` logging. Keep it that way, or add a redacting `Debug` impl.
pub(crate) struct SchemaRegistryCtx {
    /// Basic-auth user name. When `None`, `req_inner` applies no basic auth.
    pub username: Option<String>,
    /// Basic-auth password. Only sent when `username` is also set.
    pub password: Option<String>,
    /// HTTP client the request is built from.
    pub client: reqwest::Client,
    /// Path segments that replace the whole path of the target URL in `req_inner`.
    pub path: Vec<String>,
}
86
/// Errors returned when issuing a schema registry request via `req_inner`.
#[derive(Debug, thiserror::Error)]
pub enum RequestError {
    /// Sending the request failed (`RequestBuilder::send` returned an error).
    // NOTE(review): `{0}` repeats the `#[source]` error in `Display`, so error-chain
    // reporters may print the reqwest error twice. Consider dropping `{0}` here and in `Json`.
    #[error("confluent registry send req error: {0}")]
    Send(#[source] reqwest::Error),
    /// Reading or deserializing the response body failed.
    #[error("confluent registry parse resp error: {0}")]
    Json(#[source] reqwest::Error),
    /// The registry answered with a non-success HTTP status.
    #[error(transparent)]
    Unsuccessful(ErrorResp),
}
96
97pub(crate) async fn req_inner<T>(
98 ctx: Arc<SchemaRegistryCtx>,
99 mut url: Url,
100 method: Method,
101) -> Result<T, RequestError>
102where
103 T: DeserializeOwned + Send + Sync + 'static,
104{
105 url.path_segments_mut()
106 .expect("constructor validated URL can be a base")
107 .clear()
108 .extend(&ctx.path);
109 tracing::debug!("request to url: {}, method {}", &url, &method);
110 let mut request_builder = ctx.client.request(method, url);
111
112 if let Some(ref username) = ctx.username {
113 request_builder = request_builder.basic_auth(username, ctx.password.as_ref());
114 }
115 request(request_builder).await
116}
117
118async fn request<T>(req: reqwest::RequestBuilder) -> Result<T, RequestError>
119where
120 T: DeserializeOwned,
121{
122 let res = req.send().await.map_err(RequestError::Send)?;
123 let status = res.status();
124 if status.is_success() {
125 res.json().await.map_err(RequestError::Json)
126 } else {
127 let res = res.json().await.map_err(RequestError::Json)?;
128 Err(RequestError::Unsuccessful(res))
129 }
130}
131
/// A schema fetched from the schema registry.
#[derive(Debug, Eq, PartialEq)]
pub struct ConfluentSchema {
    /// Schema ID. Presumably the same ID carried in the wire-format header read by
    /// `extract_schema_id` — confirm.
    pub id: i32,
    /// Schema definition text as returned by the registry.
    pub content: String,
}
140
/// One version of a registry subject, together with the schema registered under it.
#[derive(Debug, Eq, PartialEq)]
pub struct Subject {
    /// Version number of this subject entry.
    pub version: i32,
    /// Subject name.
    pub name: String,
    /// Schema registered for this subject version.
    pub schema: ConfluentSchema,
}
151
/// A reference from one schema to another, identified by subject and version, as listed
/// in a registry response.
#[derive(Debug, Deserialize)]
pub struct SchemaReference {
    /// Name the referencing schema uses for this reference.
    // `name` is deserialized but, presumably, never read in this crate, hence the allow.
    // NOTE(review): confirm; drop the allow once the field is used (or drop the field).
    #[allow(dead_code)]
    pub name: String,
    /// Subject of the referenced schema.
    pub subject: String,
    /// Version of the referenced subject.
    pub version: i32,
}
164
/// JSON response for a fetch-schema-by-ID request (per the type name). Only `schema` is
/// deserialized; other fields in the response are ignored.
#[derive(Debug, Deserialize)]
pub struct GetByIdResp {
    /// Schema definition text.
    pub schema: String,
}
169
/// JSON response for a fetch-by-subject request (per the type name).
#[derive(Debug, Deserialize)]
pub struct GetBySubjectResp {
    /// Schema ID.
    pub id: i32,
    /// Schema definition text.
    pub schema: String,
    /// Subject version.
    pub version: i32,
    /// Subject name.
    pub subject: String,
    /// Schemas this schema references. An absent field deserializes to an empty `Vec`.
    #[serde(default)]
    pub references: Vec<SchemaReference>,
}
180
/// Error object in the body of a non-success schema registry response. Its `Display` form
/// is used directly by `RequestError::Unsuccessful` (`#[error(transparent)]`).
#[derive(Debug, Deserialize, thiserror::Error)]
#[error("confluent schema registry error {error_code}: {message}")]
pub struct ErrorResp {
    /// Registry error code.
    error_code: i32,
    /// Human-readable error message.
    message: String,
}
188
#[cfg(test)]
mod test {
    use super::super::handle_sr_list;

    /// Parses fixture strings into the `Url` type returned by `handle_sr_list`.
    fn parse_all(raw: &[&str]) -> Vec<url::Url> {
        raw.iter().map(|s| s.parse().unwrap()).collect()
    }

    #[test]
    fn test_handle_sr_list() {
        // Each input has at least one parseable entry; unparseable entries are dropped.
        let accepted: [(&str, &[&str]); 3] = [
            ("http://localhost:8081", &["http://localhost:8081"]),
            (
                "http://localhost:8081,http://localhost:8082",
                &["http://localhost:8081", "http://localhost:8082"],
            ),
            ("http://localhost:8081,12345", &["http://localhost:8081"]),
        ];
        for (input, expected) in accepted {
            assert_eq!(handle_sr_list(input).unwrap(), parse_all(expected));
        }

        // When nothing parses, the whole list is rejected.
        assert!(handle_sr_list("54321,12345").is_err());
    }
}