risingwave_connector/parser/protobuf/
parser.rs1use std::collections::HashSet;
16
17use anyhow::Context;
18use prost_reflect::{DescriptorPool, DynamicMessage, FileDescriptor, MessageDescriptor};
19use risingwave_common::catalog::Field;
20use risingwave_common::{bail, try_match_expand};
21use risingwave_connector_codec::decoder::protobuf::ProtobufAccess;
22pub use risingwave_connector_codec::decoder::protobuf::parser::{PROTOBUF_MESSAGES_AS_JSONB, *};
23
24use crate::error::ConnectorResult;
25use crate::parser::unified::AccessImpl;
26use crate::parser::utils::bytes_from_url;
27use crate::parser::{AccessBuilder, EncodingProperties, SchemaLocation};
28use crate::schema::schema_registry::{Client, WireFormatError, extract_schema_id, handle_sr_list};
29use crate::schema::{
30 ConfluentSchemaLoader, InvalidOptionError, SchemaLoader, bail_invalid_option_error,
31};
32
33#[derive(Debug)]
34pub struct ProtobufAccessBuilder {
35 wire_type: WireType,
36 message_descriptor: MessageDescriptor,
37
38 messages_as_jsonb: HashSet<String>,
41}
42
43impl AccessBuilder for ProtobufAccessBuilder {
44 #[allow(clippy::unused_async)]
45 async fn generate_accessor(
46 &mut self,
47 payload: Vec<u8>,
48 _: &crate::source::SourceMeta,
49 ) -> ConnectorResult<AccessImpl<'_>> {
50 let payload = match self.wire_type {
51 WireType::Confluent => resolve_pb_header(&payload)?,
52 WireType::None => &payload,
53 };
54
55 let message = DynamicMessage::decode(self.message_descriptor.clone(), payload)
56 .context("failed to parse message")?;
57
58 Ok(AccessImpl::Protobuf(ProtobufAccess::new(
59 message,
60 &self.messages_as_jsonb,
61 )))
62 }
63}
64
65impl ProtobufAccessBuilder {
66 pub fn new(config: ProtobufParserConfig) -> ConnectorResult<Self> {
67 let ProtobufParserConfig {
68 wire_type,
69 message_descriptor,
70 messages_as_jsonb,
71 } = config;
72
73 Ok(Self {
74 wire_type,
75 message_descriptor,
76 messages_as_jsonb,
77 })
78 }
79}
80
/// Framing of incoming protobuf payloads.
#[derive(Debug, Clone)]
enum WireType {
    /// The payload is the bare serialized message.
    None,
    /// The payload starts with a Confluent schema-registry header (schema id followed by a
    /// message-index list) that `resolve_pb_header` skips.
    Confluent,
}
88
89impl TryFrom<&SchemaLocation> for WireType {
90 type Error = InvalidOptionError;
91
92 fn try_from(value: &SchemaLocation) -> Result<Self, Self::Error> {
93 match value {
94 SchemaLocation::File { .. } => Ok(Self::None),
95 SchemaLocation::Confluent { .. } => Ok(Self::Confluent),
96 SchemaLocation::Glue { .. } => bail_invalid_option_error!(
97 "encode protobuf from aws glue schema registry not supported yet"
98 ),
99 }
100 }
101}
102
103#[derive(Debug, Clone)]
104pub struct ProtobufParserConfig {
105 wire_type: WireType,
106 pub(crate) message_descriptor: MessageDescriptor,
107 messages_as_jsonb: HashSet<String>,
108}
109
110impl ProtobufParserConfig {
111 pub async fn new(encoding_properties: EncodingProperties) -> ConnectorResult<Self> {
112 let protobuf_config = try_match_expand!(encoding_properties, EncodingProperties::Protobuf)?;
113 let message_name = &protobuf_config.message_name;
114
115 let wire_type = (&protobuf_config.schema_location).try_into()?;
116 if protobuf_config.key_message_name.is_some() {
117 bail!("protobuf key is not supported");
119 }
120 let pool = match protobuf_config.schema_location {
121 SchemaLocation::Confluent {
122 urls,
123 client_config,
124 name_strategy,
125 topic,
126 } => {
127 let url = handle_sr_list(urls.as_str())?;
128 let client = Client::new(url, &client_config)?;
129 let loader = SchemaLoader::Confluent(ConfluentSchemaLoader {
130 client,
131 name_strategy,
132 topic,
133 key_record_name: None,
134 val_record_name: Some(message_name.clone()),
135 });
136 let (_schema_id, root_file_descriptor) = loader
137 .load_val_schema::<FileDescriptor>()
138 .await
139 .context("load schema failed")?;
140 root_file_descriptor.parent_pool().clone()
141 }
142 SchemaLocation::File {
143 url,
144 aws_auth_props,
145 } => {
146 let url = handle_sr_list(url.as_str())?;
147 let url = url.first().unwrap();
148 let schema_bytes = bytes_from_url(url, aws_auth_props.as_ref()).await?;
149 DescriptorPool::decode(schema_bytes.as_slice())
150 .with_context(|| format!("cannot build descriptor pool from schema `{url}`"))?
151 }
152 SchemaLocation::Glue { .. } => bail_invalid_option_error!(
153 "encode protobuf from aws glue schema registry not supported yet"
154 ),
155 };
156
157 let message_descriptor = pool
158 .get_message_by_name(message_name)
159 .with_context(|| format!("cannot find message `{message_name}` in schema"))?;
160
161 Ok(Self {
162 message_descriptor,
163 wire_type,
164 messages_as_jsonb: protobuf_config.messages_as_jsonb,
165 })
166 }
167
168 pub fn map_to_columns(&self) -> ConnectorResult<Vec<Field>> {
170 pb_schema_to_fields(&self.message_descriptor, &self.messages_as_jsonb).map_err(|e| e.into())
171 }
172}
173
174fn decode_varint_zigzag(buffer: &[u8]) -> ConnectorResult<(i32, usize)> {
177 let mut value = 0u32;
179 let mut shift = 0;
180 let mut len = 0usize;
181
182 for &byte in buffer {
183 len += 1;
184 if len > 5 {
186 break;
187 }
188 let byte_ext = byte as u32;
190 value |= (byte_ext & 0x7F) << shift;
193 if byte_ext & 0x80 == 0 {
194 return Ok((((value >> 1) as i32) ^ -((value & 1) as i32), len));
195 }
196
197 shift += 7;
198 }
199
200 Err(WireFormatError::ParseMessageIndexes.into())
201}
202
203pub(crate) fn resolve_pb_header(payload: &[u8]) -> ConnectorResult<&[u8]> {
208 let (_, remained) = extract_schema_id(payload)?;
211 match remained.first() {
215 Some(0) => Ok(&remained[1..]),
216 Some(_) => {
217 let (index_len, mut offset) = decode_varint_zigzag(remained)?;
218 for _ in 0..index_len {
219 offset += decode_varint_zigzag(&remained[offset..])?.1;
220 }
221 Ok(&remained[offset..])
222 }
223 None => bail!("The proto payload is empty"),
224 }
225}
226
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_decode_varint_zigzag() {
        // (encoded bytes, expected value, expected encoded length)
        let cases: &[(&[u8], i32, usize)] = &[
            (&[0x02], 1, 1),
            (&[0x01], -1, 1),
            (&[0x9E, 0x03], 207, 2),
            (&[0xBF, 0x07], -480, 2),
            (&[0xFE, 0xFF, 0xFF, 0xFF, 0x0F], i32::MAX, 5),
            (&[0xFF, 0xFF, 0xFF, 0xFF, 0x0F], i32::MIN, 5),
            // Bits of the fifth byte beyond 32 bits are discarded.
            (&[0xFF, 0xFF, 0xFF, 0xFF, 0x7F], i32::MIN, 5),
            // Bytes after the terminating byte are not consumed.
            (&[0x02, 0xFF], 1, 1),
        ];
        for &(buffer, value, len) in cases {
            let decoded = decode_varint_zigzag(buffer).unwrap();
            assert_eq!(decoded, (value, len), "buffer: {buffer:02X?}");
        }
    }

    #[test]
    fn test_decode_varint_zigzag_rejects_malformed() {
        // Empty input.
        assert!(decode_varint_zigzag(&[]).is_err());
        // Input ends while the continuation bit is still set.
        assert!(decode_varint_zigzag(&[0x80]).is_err());
        assert!(decode_varint_zigzag(&[0xFF, 0xFF]).is_err());
        // No terminating byte within the five-byte limit.
        assert!(decode_varint_zigzag(&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF]).is_err());
        assert!(decode_varint_zigzag(&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]).is_err());
    }
}