risingwave_connector/schema/schema_registry/
util.rs1use std::fmt::Debug;
16use std::sync::Arc;
17
18use reqwest::Method;
19use serde::de::DeserializeOwned;
20use serde_derive::Deserialize;
21use url::{ParseError, Url};
22
23use crate::schema::{InvalidOptionError, bail_invalid_option_error};
24
25pub fn handle_sr_list(addr: &str) -> Result<Vec<Url>, InvalidOptionError> {
26 let segment = addr.split(',').collect::<Vec<&str>>();
27 let mut errs: Vec<ParseError> = Vec::with_capacity(segment.len());
28 let mut urls = Vec::with_capacity(segment.len());
29 for ele in segment {
30 match ele.parse::<Url>() {
31 Ok(url) => urls.push(url),
32 Err(e) => errs.push(e),
33 }
34 }
35 if urls.is_empty() {
36 bail_invalid_option_error!("no valid url provided, errs: {errs:?}");
37 }
38 tracing::debug!(
39 "schema registry client will use url {:?} to connect, the rest failed because: {:?}",
40 urls,
41 errs
42 );
43 Ok(urls)
44}
45
/// Errors from decoding the header that `extract_schema_id` expects at the start of a
/// payload: a magic byte `0` followed by a 4-byte big-endian schema ID.
#[derive(Debug, thiserror::Error)]
pub enum WireFormatError {
    /// The payload is empty or its first byte is not `0`.
    #[error("fail to match a magic byte of 0")]
    NoMagic,
    /// Fewer than four bytes follow the magic byte, so no schema ID can be read.
    #[error("fail to read 4-byte schema ID")]
    NoSchemaId,
    /// Message indexes could not be parsed.
    // NOTE(review): not produced by any code visible here; presumably raised by a caller that
    // decodes message indexes after the schema ID — confirm.
    #[error("failed to parse message indexes")]
    ParseMessageIndexes,
}
55
56pub(crate) fn extract_schema_id(payload: &[u8]) -> Result<(i32, &[u8]), WireFormatError> {
66 use byteorder::{BigEndian, ReadBytesExt as _};
67
68 let mut cursor = payload;
69 if !cursor.read_u8().is_ok_and(|magic| magic == 0) {
70 return Err(WireFormatError::NoMagic);
71 }
72
73 let schema_id = cursor
74 .read_i32::<BigEndian>()
75 .map_err(|_| WireFormatError::NoSchemaId)?;
76
77 Ok((schema_id, cursor))
78}
79
/// Shared settings used by `req_inner` for every schema registry request.
pub(crate) struct SchemaRegistryCtx {
    /// Basic-auth username; when `None`, no credentials are attached to requests.
    pub username: Option<String>,
    /// Basic-auth password, sent only together with `username`.
    pub password: Option<String>,
    /// HTTP client used to send every request.
    pub client: reqwest::Client,
    /// Path segments appended to the base URL of each request.
    pub path: Vec<String>,
}
86
87#[derive(Debug, thiserror::Error)]
88pub enum RequestError {
89 #[error("confluent registry send req error: {0}")]
90 Send(#[source] reqwest::Error),
91 #[error("confluent registry parse resp error: {0}")]
92 Json(#[source] reqwest::Error),
93 #[error(transparent)]
94 Unsuccessful(ErrorResp),
95}
96
97pub(crate) async fn req_inner<T>(
98 ctx: Arc<SchemaRegistryCtx>,
99 mut url: Url,
100 method: Method,
101) -> Result<T, RequestError>
102where
103 T: DeserializeOwned + Send + Sync + 'static,
104{
105 url.path_segments_mut()
106 .expect("constructor validated URL can be a base")
107 .extend(&ctx.path);
108 tracing::debug!("request to url: {}, method {}", &url, &method);
109 let mut request_builder = ctx.client.request(method, url);
110
111 if let Some(ref username) = ctx.username {
112 request_builder = request_builder.basic_auth(username, ctx.password.as_ref());
113 }
114 request(request_builder).await
115}
116
117async fn request<T>(req: reqwest::RequestBuilder) -> Result<T, RequestError>
118where
119 T: DeserializeOwned,
120{
121 let res = req.send().await.map_err(RequestError::Send)?;
122 let status = res.status();
123 if status.is_success() {
124 res.json().await.map_err(RequestError::Json)
125 } else {
126 let res = res.json().await.map_err(RequestError::Json)?;
127 Err(RequestError::Unsuccessful(res))
128 }
129}
130
/// A schema held by the schema registry: its ID and its text.
#[derive(Debug, Eq, PartialEq)]
pub struct ConfluentSchema {
    /// The schema's ID in the registry.
    pub id: i32,
    /// Schema definition text (presumably the `schema` field of a registry response — confirm
    /// against the constructors).
    pub content: String,
}
139
/// A registry subject paired with one version of its schema.
#[derive(Debug, Eq, PartialEq)]
pub struct Subject {
    /// Version number of `schema` under this subject.
    pub version: i32,
    /// Subject name.
    pub name: String,
    /// The schema for this subject at `version`.
    pub schema: ConfluentSchema,
}
150
/// A reference from one schema to another registered schema, as listed in
/// `GetBySubjectResp::references`.
#[derive(Debug, Deserialize)]
pub struct SchemaReference {
    // NOTE(review): `dead_code` is allowed presumably because this field is deserialized but never
    // read — confirm before removing the attribute.
    #[allow(dead_code)]
    /// Reference name (presumably how the referencing schema imports it — confirm).
    pub name: String,
    /// Subject of the referenced schema.
    pub subject: String,
    /// Version of the referenced schema under `subject`.
    pub version: i32,
}
163
/// Deserialized response for a schema lookup by ID (per the type name).
///
/// Only the schema text is kept. Any other JSON fields are ignored by serde's default behavior.
#[derive(Debug, Deserialize)]
pub struct GetByIdResp {
    /// Schema definition text.
    pub schema: String,
}
168
/// Deserialized response for a schema lookup by subject (per the type name).
#[derive(Debug, Deserialize)]
pub struct GetBySubjectResp {
    /// Schema ID.
    pub id: i32,
    /// Schema definition text.
    pub schema: String,
    /// Version of this schema under `subject`.
    pub version: i32,
    /// Subject name.
    pub subject: String,
    /// Schemas this one references. Defaults to empty when the field is absent from the response.
    #[serde(default)]
    pub references: Vec<SchemaReference>,
}
179
/// Error body the schema registry sends with a non-2xx response.
///
/// `request` decodes this for non-success statuses and wraps it in
/// `RequestError::Unsuccessful`.
#[derive(Debug, Deserialize, thiserror::Error)]
#[error("confluent schema registry error {error_code}: {message}")]
pub struct ErrorResp {
    /// Error code from the response body (presumably registry-specific, not the HTTP status —
    /// confirm).
    error_code: i32,
    /// Error message from the response body.
    message: String,
}
187
#[cfg(test)]
mod test {
    use super::super::handle_sr_list;

    /// Parses each expected address so it can be compared with `handle_sr_list`'s output.
    fn urls(addrs: &[&str]) -> Vec<url::Url> {
        addrs.iter().map(|addr| addr.parse().unwrap()).collect()
    }

    #[test]
    fn test_handle_sr_list() {
        // Every entry is valid: all of them are returned, in input order.
        assert_eq!(
            handle_sr_list("http://localhost:8081").unwrap(),
            urls(&["http://localhost:8081"])
        );
        assert_eq!(
            handle_sr_list("http://localhost:8081,http://localhost:8082").unwrap(),
            urls(&["http://localhost:8081", "http://localhost:8082"])
        );

        // A bare number is not a URL: it is dropped and the valid entry survives.
        assert_eq!(
            handle_sr_list("http://localhost:8081,12345").unwrap(),
            urls(&["http://localhost:8081"])
        );

        // No entry parses: the whole list is rejected.
        assert!(handle_sr_list("54321,12345").is_err());
    }
}